CTC Loss
Functionality to calculate the ctc loss.
calculate_ctc(probabilities, y, prob_lengths, y_lengths, blank_idx)
Calculates the ctc loss based on model probabilities (also called emissions) and labels.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
probabilities |
Tensor |
Output of the model, before any softmax operation. Shape [batch, #vocab, time] |
required |
y |
Tensor |
Tensor containing the corresponding labels. Shape [batch] |
required |
prob_lengths |
Tensor |
Lengths of each element in the input. Shape [batch] |
required |
y_lengths |
Tensor |
Lenghts of each element in the output. Should NOT be normalized. |
required |
blank_idx |
int |
Index of the blank token in the vocab. |
required |
Returns:
Type | Description |
---|---|
Tensor |
Loss tensor that can be backpropagated. |
Source code in thunder/ctc_loss.py
def calculate_ctc(
probabilities: Tensor,
y: Tensor,
prob_lengths: Tensor,
y_lengths: Tensor,
blank_idx: int,
) -> Tensor:
"""Calculates the ctc loss based on model probabilities (also called emissions) and
labels.
Args:
probabilities: Output of the model, before any softmax operation. Shape [batch, #vocab, time]
y: Tensor containing the corresponding labels. Shape [batch]
prob_lengths: Lengths of each element in the input. Shape [batch]
y_lengths: Lenghts of each element in the output. Should NOT be normalized.
blank_idx: Index of the blank token in the vocab.
Returns:
Loss tensor that can be backpropagated.
"""
# Change from (batch, #vocab, time) to (time, batch, #vocab)
probabilities = probabilities.permute(2, 0, 1)
logprobs = log_softmax(probabilities, dim=2)
return ctc_loss(
logprobs,
y,
prob_lengths.long(),
y_lengths,
blank=blank_idx,
reduction="mean",
zero_infinity=True,
)