CTC Loss

Functionality to calculate the CTC (Connectionist Temporal Classification) loss.

calculate_ctc(probabilities, y, prob_lengths, y_lengths, blank_idx)

Calculates the CTC loss from the raw model outputs (also called emissions) and the labels.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| probabilities | Tensor | Output of the model, before any softmax operation. Shape [batch, #vocab, time] | required |
| y | Tensor | Tensor containing the corresponding labels. Shape [batch] | required |
| prob_lengths | Tensor | Lengths of each element in the input. Shape [batch] | required |
| y_lengths | Tensor | Lengths of each element in the output. Should NOT be normalized. | required |
| blank_idx | int | Index of the blank token in the vocab. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Loss tensor that can be backpropagated. |
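
A minimal usage sketch follows, to make the expected shapes concrete. It restates the documented function inline so that it runs with only PyTorch installed. The tensor sizes and label values are invented for illustration. The parameter table gives the shape of `y` as [batch]; this sketch instead assumes a padded 2D label tensor of shape [batch, max label length], which is one of the target layouts that `torch.nn.functional.ctc_loss` accepts.

```python
import torch
from torch import Tensor
from torch.nn.functional import ctc_loss, log_softmax


def calculate_ctc(probabilities: Tensor, y: Tensor, prob_lengths: Tensor,
                  y_lengths: Tensor, blank_idx: int) -> Tensor:
    # Same steps as the documented function: move to (time, batch, #vocab),
    # apply log-softmax over the vocab axis, then call ctc_loss.
    logprobs = log_softmax(probabilities.permute(2, 0, 1), dim=2)
    return ctc_loss(logprobs, y, prob_lengths.long(), y_lengths,
                    blank=blank_idx, reduction="mean", zero_infinity=True)


torch.manual_seed(0)
batch, vocab, time = 2, 5, 10  # illustrative sizes, not taken from the source

# Raw model outputs (no softmax applied), shape [batch, #vocab, time]
probabilities = torch.randn(batch, vocab, time, requires_grad=True)

# Assumed layout: padded label indices, shape [batch, max label length].
# Index 0 is the blank here, so real labels use 1..4; the trailing 0 in the
# second row is padding and is ignored because y_lengths[1] == 2.
y = torch.tensor([[1, 2, 3], [4, 2, 0]])

prob_lengths = torch.tensor([10, 8])  # valid time steps per sample
y_lengths = torch.tensor([3, 2])      # absolute label counts, not fractions

loss = calculate_ctc(probabilities, y, prob_lengths, y_lengths, blank_idx=0)
loss.backward()  # gradients flow back to the raw model outputs
```

The result is a scalar tensor, and after `backward()` the gradient on `probabilities` has the same [batch, #vocab, time] shape as the input.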

Source code in thunder/ctc_loss.py
from torch import Tensor
from torch.nn.functional import ctc_loss, log_softmax


def calculate_ctc(
    probabilities: Tensor,
    y: Tensor,
    prob_lengths: Tensor,
    y_lengths: Tensor,
    blank_idx: int,
) -> Tensor:
    """Calculates the ctc loss based on model probabilities (also called emissions) and
    labels.

    Args:
        probabilities: Output of the model, before any softmax operation. Shape [batch, #vocab, time]
        y: Tensor containing the corresponding labels. Shape [batch]
        prob_lengths: Lengths of each element in the input. Shape [batch]
        y_lengths: Lengths of each element in the output. Should NOT be normalized.
        blank_idx: Index of the blank token in the vocab.

    Returns:
        Loss tensor that can be backpropagated.
    """
    # Change from (batch, #vocab, time) to (time, batch, #vocab)
    probabilities = probabilities.permute(2, 0, 1)
    logprobs = log_softmax(probabilities, dim=2)

    return ctc_loss(
        logprobs,
        y,
        prob_lengths.long(),
        y_lengths,
        blank=blank_idx,
        reduction="mean",
        zero_infinity=True,
    )
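
The two keyword arguments passed to `ctc_loss` above shape the returned value, so a small sketch of what they do may help. With `reduction="mean"`, PyTorch divides each sample's loss by its target length and then averages over the batch. With `zero_infinity=True`, a loss that would be infinite (typically because the input is too short to align with the target) is replaced by zero. The tensors below are made up for illustration and are fed to `torch.nn.functional.ctc_loss` directly, already in (time, batch, #vocab) layout.

```python
import torch
from torch.nn.functional import ctc_loss, log_softmax

torch.manual_seed(0)
logits = torch.randn(6, 2, 4)              # (time, batch, #vocab), illustrative
logprobs = log_softmax(logits, dim=2)
targets = torch.tensor([[1, 2], [3, 0]])   # padded labels, blank index is 0
input_lengths = torch.tensor([6, 5])
target_lengths = torch.tensor([2, 1])

# reduction="mean": per-sample loss / target length, then mean over the batch.
per_sample = ctc_loss(logprobs, targets, input_lengths, target_lengths,
                      blank=0, reduction="none", zero_infinity=True)
manual_mean = (per_sample / target_lengths).mean()
builtin_mean = ctc_loss(logprobs, targets, input_lengths, target_lengths,
                        blank=0, reduction="mean", zero_infinity=True)

# zero_infinity: one time step cannot emit two labels, so the first sample
# has no valid alignment when its input length is 1.
too_short = torch.tensor([1, 1])
raw = ctc_loss(logprobs, targets, too_short, target_lengths,
               blank=0, reduction="none", zero_infinity=False)
zeroed = ctc_loss(logprobs, targets, too_short, target_lengths,
                  blank=0, reduction="none", zero_infinity=True)
```

Here `manual_mean` matches `builtin_mean`, and the first entry of `raw` is infinite while the first entry of `zeroed` is zero. Zeroing keeps one unalignable sample from turning the whole batch loss into infinity, at the cost of that sample contributing nothing to the gradient.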