Skip to content


Building blocks that can be shared across all models.

Masked (Module)

Wrapper to mix normal modules with others that take 2 inputs

Source code in thunder/
class Masked(nn.Module):
    """Wrapper to mix normal modules with others that take 2 inputs.

    The wrapped layers are applied only to ``audio``; ``audio_lengths`` is
    returned unchanged, so single-input layers can be used where a
    ``(audio, audio_lengths)`` interface is required.
    """

    def __init__(self, *layers: nn.Module):
        """
        Args:
            *layers: modules applied, in order, to the audio tensor.
        """
        # nn.Module.__init__ must run before any submodule is assigned;
        # otherwise PyTorch raises
        # "cannot assign module before Module.__init__() call".
        super().__init__()
        self.layer = nn.Sequential(*layers)

    def forward(
        self, audio: torch.Tensor, audio_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the wrapped layers to ``audio`` and pass ``audio_lengths`` through.

        Args:
            audio: input tensor fed to the wrapped layers.
            audio_lengths: returned as-is.

        Returns:
            Tuple of the transformed audio and the untouched lengths.
        """
        return self.layer(audio), audio_lengths

forward(self, audio, audio_lengths)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note: Although the recipe for the forward pass needs to be defined within this function, one should call the `Module` instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

Source code in thunder/
def forward(
    self, audio: torch.Tensor, audio_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run ``self.layer`` on the audio and hand the lengths back untouched."""
    transformed = self.layer(audio)
    return transformed, audio_lengths

MultiSequential (Sequential)

nn.Sequential equivalent with 2 inputs/outputs

Source code in thunder/
class MultiSequential(nn.Sequential):
    """nn.Sequential equivalent with 2 inputs/outputs.

    Each child is called as ``child(audio, audio_lengths)`` and must return a
    new ``(audio, audio_lengths)`` pair, which is then fed to the next child.
    """

    def forward(
        self, audio: torch.Tensor, audio_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        signal, lengths = audio, audio_lengths
        for stage in self.children():
            # Unpacking enforces that every stage returns exactly two values.
            signal, lengths = stage(signal, lengths)
        return signal, lengths

forward(self, audio, audio_lengths)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note: Although the recipe for the forward pass needs to be defined within this function, one should call the `Module` instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

Source code in thunder/
def forward(
    self, audio: torch.Tensor, audio_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Thread the ``(audio, audio_lengths)`` pair through every child in order."""
    signal, lengths = audio, audio_lengths
    for stage in self.children():
        signal, lengths = stage(signal, lengths)
    return signal, lengths

SwapLastDimension (Module)

Layer that swaps the last two dimensions of the data.

Source code in thunder/
class SwapLastDimension(nn.Module):
    """Layer that swaps the last two dimensions of the data.

    Holds no parameters; ``forward`` returns the input transposed over its
    final two axes.
    """

    def forward(self, x: Tensor) -> Tensor:
        # Transposing (-2, -1) is the same operation as (-1, -2).
        return x.transpose(-2, -1)

forward(self, x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note: Although the recipe for the forward pass needs to be defined within this function, one should call the `Module` instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

Source code in thunder/
def forward(self, x: Tensor) -> Tensor:
    """Return ``x`` with its final two axes exchanged."""
    return x.transpose(-2, -1)

conv1d_decoder(decoder_input_channels, num_classes)

Decoder that uses one conv1d layer


Parameters:

- num_classes (int): Number of output classes of the model. It's the size of the vocabulary, excluding the blank symbol.
- decoder_input_channels (int): Number of input channels of the decoder. That is, the number of channels of the features created by the encoder.

Returns:

- nn.Module: PyTorch model of the decoder.

Source code in thunder/
def conv1d_decoder(decoder_input_channels: int, num_classes: int) -> nn.Module:
    """Decoder that uses one conv1d layer

    Args:
        num_classes: Number of output classes of the model. It's the size of the vocabulary, excluding the blank symbol.
        decoder_input_channels: Number of input channels of the decoder. That is the number of channels of the features created by the encoder.

    Returns:
        Pytorch model of the decoder
    """
    # NOTE(review): this listing is truncated. The argument list and closing
    # parenthesis of the nn.Conv1d call below are missing, so the block does
    # not parse as shown. Restore them from the original source rather than
    # guessing, in particular whether the output channel count accounts for
    # the blank symbol that `num_classes` excludes.
    decoder = nn.Conv1d(
    # Xavier/Glorot-uniform initialisation of the convolution weights (gain 1.0).
    nn.init.xavier_uniform_(decoder.weight, gain=1.0)
    return decoder

convolution_stft(input_data, n_fft=1024, hop_length=512, win_length=1024, window=torch.hann_window(1024, periodic=False), center=True, return_complex=False)

Implements the STFT operation using the convolution method. This is one way to make it possible to export code that uses this operation to ONNX and ARM-based environments. The signature should follow that of torch.stft, making it possible to just swap the two. The code is based on an external implementation whose reference is missing from this page.

Source code in thunder/
def convolution_stft(
    input_data: torch.Tensor,
    n_fft: int = 1024,
    hop_length: int = 512,
    win_length: int = 1024,
    window: torch.Tensor = torch.hann_window(1024, periodic=False),
    center: bool = True,
    return_complex: bool = False,
) -> torch.Tensor:
    """Compute an STFT with a 1-D convolution instead of ``torch.stft``.

    This is one way to make code that uses this operation exportable to ONNX
    and to ARM-based environments. The signature is meant to follow
    ``torch.stft`` so the two can be swapped. (The reference implementation
    this code is based on is missing from this listing.)

    Args:
        input_data: batch of signals. It is viewed as
            ``(input_data.shape[0], 1, input_data.shape[-1])``.
        n_fft: FFT size; must be >= ``win_length`` (asserted).
        hop_length: stride of the convolution, i.e. the step between frames.
        win_length: expected length of ``window``. The window is zero-padded
            by ``n_fft - win_length`` samples, split across both sides, so it
            spans ``n_fft`` samples.
        window: window function. The default tensor is created once, when the
            function is defined, and shared by every call that omits it.
        center: NOTE(review): not read by any visible line — confirm.
        return_complex: NOTE(review): not read by any visible line — confirm.

    Returns:
        Real and imaginary parts stacked on the last dimension. Presumably of
        shape ``(batch, n_fft // 2 + 1, num_frames, 2)`` — TODO confirm.
    """
    assert n_fft >= win_length
    # Amount of padding added to each side of the signal.
    pad_amount = int(n_fft / 2)
    # NOTE(review): the right-hand side of this assignment is missing from this
    # listing, so the line does not parse as shown; restore it from the source.
    window =

    # _fourier_matrix is defined elsewhere; presumably an n_fft-point complex
    # DFT matrix on input_data's device — confirm.
    fourier_basis = _fourier_matrix(n_fft, device=input_data.device)

    # One-sided spectrum: keep frequency bins 0 .. n_fft / 2.
    cutoff = int((n_fft / 2 + 1))
    # Stack real rows then imaginary rows into a single real-valued basis with
    # n_fft columns (real parts first, imaginary parts second).
    fourier_basis = torch.stack(
        [torch.real(fourier_basis[:cutoff, :]), torch.imag(fourier_basis[:cutoff, :])]
    ).reshape(-1, n_fft)
    # Insert a singleton in-channel axis so each row is a conv1d kernel.
    forward_basis = fourier_basis[:, None, :].float()

    # Zero-pad the window on both sides so its length becomes n_fft
    # (this only holds when len(window) == win_length).
    window_pad = (n_fft - win_length) // 2
    window_pad2 = n_fft - (window_pad + win_length)
    fft_window = torch.nn.functional.pad(window, [window_pad, window_pad2])
    # window the bases
    forward_basis *= fft_window
    forward_basis = forward_basis.float()

    num_batches = input_data.shape[0]
    num_samples = input_data.shape[-1]

    # similar to librosa, reflect-pad the input
    input_data = input_data.view(num_batches, 1, num_samples)

    # NOTE(review): the tensor argument, padding mode and closing parenthesis
    # of this F.pad call are missing from this listing. No visible line reads
    # `center`, so padding does not appear to depend on it — confirm.
    input_data = F.pad(
        (pad_amount, pad_amount, 0, 0),
    input_data = input_data.squeeze(1)

    # Each output channel correlates the signal with one windowed basis row;
    # the stride of hop_length produces one output column per frame.
    # NOTE(review): the closing parenthesis of this call is missing from this listing.
    forward_transform = F.conv1d(
        input_data, forward_basis, stride=hop_length, padding=0

    # Same value as `cutoff` computed above; recomputed here.
    cutoff = int((n_fft / 2) + 1)
    # First `cutoff` channels are real parts, the rest imaginary parts,
    # matching the order used when the basis was stacked.
    real_part = forward_transform[:, :cutoff, :]
    imag_part = forward_transform[:, cutoff:, :]
    # Real/imaginary pairs on the last axis. `return_complex` is not read by
    # any visible line, so a complex tensor is never produced here.
    return torch.stack((real_part, imag_part), dim=-1)

get_same_padding(kernel_size, stride, dilation)

Calculates the padding size to obtain same padding. Same padding means that the output will have the shape input_shape / stride. That means, for stride = 1 the output shape is the same as the input, and stride = 2 gives an output that is half of the input shape.


Parameters:

- kernel_size (int): Convolution kernel size. Only tested to be correct with odd values.
- stride (int): Convolution stride.
- dilation (int): Convolution dilation.

Raises:

- ValueError: Only stride or dilation may be greater than 1.

Returns:

- int: Padding value to obtain same padding.

Source code in thunder/
def get_same_padding(kernel_size: int, stride: int, dilation: int) -> int:
    """Calculates the padding size to obtain same padding.

    Same padding means that the output will have the shape
    input_shape / stride. That means, for stride = 1 the output shape is the
    same as the input, and stride = 2 gives an output that is half of the
    input shape.

    Args:
        kernel_size: convolution kernel size. Only tested to be correct with odd values.
        stride: convolution stride
        dilation: convolution dilation

    Raises:
        ValueError: Only stride or dilation may be greater than 1

    Returns:
        padding value to obtain same padding.
    """
    if stride > 1 and dilation > 1:
        raise ValueError("Only stride OR dilation may be greater than 1")
    if dilation > 1:
        # A dilated kernel spans dilation * (kernel_size - 1) + 1 input samples.
        return (dilation * (kernel_size - 1) + 1) // 2
    return kernel_size // 2

lengths_to_mask(lengths, max_length)

Convert from integer lengths of each element to mask representation


Parameters:

- lengths (torch.Tensor): Lengths of each element in the batch.
- max_length (int): Maximum length expected. Can be greater than lengths.max().

Returns:

- torch.Tensor: Corresponding boolean mask indicating the valid region of the tensor.

Source code in thunder/
def lengths_to_mask(lengths: torch.Tensor, max_length: int) -> torch.Tensor:
    """Convert from integer lengths of each element to mask representation

    Args:
        lengths: lengths of each element in the batch
        max_length: maximum length expected. Can be greater than lengths.max()

    Returns:
        Corresponding boolean mask indicating the valid region of the tensor,
        of shape ``(lengths.shape[0], max_length)``. Position ``j`` of row ``i``
        is True when ``j < lengths[i]``.
    """
    # Cast first so float lengths compare as integers.
    lengths = lengths.type(torch.long)
    mask = torch.arange(max_length, device=lengths.device).expand(
        lengths.shape[0], max_length
    ) < lengths.unsqueeze(1)
    return mask

linear_decoder(decoder_input_channels, num_classes, decoder_dropout)

Decoder that uses a linear layer with dropout


Parameters:

- decoder_dropout (float): Amount of dropout to be used in the decoder.
- decoder_input_channels (int): Number of input channels of the decoder. That is, the number of channels of the features created by the encoder.
- num_classes (int): Number of output classes of the model. It's the size of the vocabulary, excluding the blank symbol.

Returns:

- nn.Module: Module that represents the decoder.

Source code in thunder/
def linear_decoder(
    decoder_input_channels: int, num_classes: int, decoder_dropout: float
) -> nn.Module:
    """Decoder that uses a linear layer with dropout

    Args:
        decoder_dropout: Amount of dropout to be used in the decoder
        decoder_input_channels: Number of input channels of the decoder. That is the number of channels of the features created by the encoder.
        num_classes: Number of output classes of the model. It's the size of the vocabulary, excluding the blank symbol.

    Returns:
        Module that represents the decoder.
    """

    # SwapLastDimension is necessary to
    # change from (batch, time, #vocab) to (batch, #vocab, time)
    # that is expected by the rest of the library
    # NOTE(review): this listing is truncated. The nn.Sequential call below is
    # missing its remaining layers and closing parenthesis, and
    # `decoder_dropout` is not referenced by any visible line. Restore the
    # layer list from the original source before use.
    return nn.Sequential(
        nn.Linear(decoder_input_channels, num_classes),

normalize_tensor(input_values, mask=None, div_guard=1e-07, dim=-1)

Normalize tensor values, optionally using some mask to define the valid region.


Parameters:

- input_values (torch.Tensor): Input tensor to be normalized.
- mask (Optional[torch.Tensor], default None): Optional mask describing the valid elements.
- div_guard (float, default 1e-07): Value used to prevent division by zero when normalizing.
- dim (int, default -1): Dimension used to calculate the mean and variance.

Returns:

- torch.Tensor: Normalized tensor.

Source code in thunder/
def normalize_tensor(
    input_values: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    div_guard: float = 1e-7,
    dim: int = -1,
) -> torch.Tensor:
    """Normalize tensor values, optionally using some mask to define the valid region.

    Args:
        input_values: input tensor to be normalized
        mask: Optional mask describing the valid elements. Element counts are
            taken from ``mask`` along ``dim``, so it presumably has the same
            shape as ``input_values``.
        div_guard: value used to prevent division by zero when normalizing.
        dim: dimension used to calculate the mean and variance.

    Returns:
        Normalized tensor.

        With a mask, the mean and (population) std are computed over the
        valid elements only, ``div_guard`` is added to the std, and invalid
        positions are set to zero.

        Without a mask, ``Tensor.var``'s default (unbiased) estimator is used
        and ``div_guard`` is added to the variance.

        In both cases the statistics are detached, so autograd treats them as
        constants.
    """
    # Vectorized implementation of (x - x.mean()) / x.std() considering only the valid mask
    if mask is not None:
        invalid = ~mask.type(torch.bool)
        # Zero the elements outside the mask so they do not contribute to the sum
        input_values = torch.masked_fill(input_values, invalid, 0.0)
        # Number of valid elements
        num_elements = mask.sum(dim=dim, keepdim=True).detach()
        # Mean is sum over number of elements
        x_mean = input_values.sum(dim=dim, keepdim=True).detach() / num_elements
        # Centre, then zero the invalid positions again. After subtracting the
        # mean they would hold -mean and wrongly add mean**2 each to the
        # squared-difference sum below.
        centered = torch.masked_fill(input_values - x_mean, invalid, 0.0)
        # std numerator: sum of squared differences to the mean (valid elements only)
        numerator = centered.pow(2).sum(dim=dim, keepdim=True).detach()
        x_std = (numerator / num_elements).sqrt()
        # using the div_guard to prevent division by zero
        normalized = centered / (x_std + div_guard)
        # Cleaning elements outside of valid mask (also clears NaNs from rows
        # that have no valid element at all)
        return torch.masked_fill(normalized, invalid, 0.0)

    mean = input_values.mean(dim=dim, keepdim=True).detach()
    std = (input_values.var(dim=dim, keepdim=True).detach() + div_guard).sqrt()
    return (input_values - mean) / std