
Transform

Functionality to transform the audio input in the way the Quartznet model expects.

DitherAudio (Module)

Source code in thunder/quartznet/transform.py
class DitherAudio(nn.Module):
    def __init__(self, dither: float = 1e-5):
        """Add some dithering to the audio tensor.

        Note:
            From Wikipedia: Dither is an intentionally applied
            form of noise used to randomize quantization error.

        Args:
            dither: Amount of dither to add.
        """
        super().__init__()
        self.dither = dither

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor of shape (batch, time)
        """
        if self.training:
            return x + (self.dither * torch.randn_like(x))
        else:
            return x

__init__(self, dither=1e-05) special

Add some dithering to the audio tensor.

Note

From Wikipedia: Dither is an intentionally applied form of noise used to randomize quantization error.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| dither | float | Amount of dither to add. | 1e-05 |
Source code in thunder/quartznet/transform.py
def __init__(self, dither: float = 1e-5):
    """Add some dithering to the audio tensor.

    Note:
        From Wikipedia: Dither is an intentionally applied
        form of noise used to randomize quantization error.

    Args:
        dither: Amount of dither to add.
    """
    super().__init__()
    self.dither = dither

forward(self, x)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | Tensor of shape (batch, time) | required |
Source code in thunder/quartznet/transform.py
@torch.no_grad()
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Args:
        x: Tensor of shape (batch, time)
    """
    if self.training:
        return x + (self.dither * torch.randn_like(x))
    else:
        return x
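As the forward pass shows, dithering is only applied in training mode; in eval mode the input passes through unchanged. A minimal usage sketch, assuming the import path matches the `thunder/quartznet/transform.py` location noted above:

```python
import torch

from thunder.quartznet.transform import DitherAudio

dither = DitherAudio(dither=1e-5)
waveform = torch.randn(2, 16000)  # (batch, time)

dither.train()
noisy = dither(waveform)  # small Gaussian noise added to the signal

dither.eval()
clean = dither(waveform)  # identity in eval mode
assert torch.equal(clean, waveform)
```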

FeatureBatchNormalizer (Module)

Source code in thunder/quartznet/transform.py
class FeatureBatchNormalizer(nn.Module):
    def __init__(self):
        """Normalize batch at the feature dimension."""
        super().__init__()
        self.div_guard = 1e-5

    def forward(
        self, x: torch.Tensor, lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: Tensor of shape (batch, features, time)
            lengths: corresponding length of each element in the input tensor.
        """
        # https://github.com/pytorch/pytorch/issues/45208
        # https://github.com/pytorch/pytorch/issues/44768
        with torch.no_grad():
            mask = lengths_to_mask(lengths, x.shape[-1])
            return (
                normalize_tensor(x, mask.unsqueeze(1), div_guard=self.div_guard),
                lengths,
            )

__init__(self) special

Normalize batch at the feature dimension.

Source code in thunder/quartznet/transform.py
def __init__(self):
    """Normalize batch at the feature dimension."""
    super().__init__()
    self.div_guard = 1e-5

forward(self, x, lengths)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | Tensor of shape (batch, features, time) | required |
| lengths | Tensor | Corresponding length of each element in the input tensor. | required |
Source code in thunder/quartznet/transform.py
def forward(
    self, x: torch.Tensor, lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Args:
        x: Tensor of shape (batch, features, time)
        lengths: corresponding length of each element in the input tensor.
    """
    # https://github.com/pytorch/pytorch/issues/45208
    # https://github.com/pytorch/pytorch/issues/44768
    with torch.no_grad():
        mask = lengths_to_mask(lengths, x.shape[-1])
        return (
            normalize_tensor(x, mask.unsqueeze(1), div_guard=self.div_guard),
            lengths,
        )
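A minimal sketch of the normalizer in use (same import-path assumption as above); the mask built from `lengths` keeps padded frames from contributing to the statistics:

```python
import torch

from thunder.quartznet.transform import FeatureBatchNormalizer

norm = FeatureBatchNormalizer()
feats = torch.randn(2, 64, 100)    # (batch, features, time)
lengths = torch.tensor([100, 80])  # valid frames per example

out, out_lengths = norm(feats, lengths)
print(out.shape, out_lengths)  # torch.Size([2, 64, 100]) tensor([100, 80])
```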

MelScale (Module)

Source code in thunder/quartznet/transform.py
class MelScale(nn.Module):
    def __init__(
        self, sample_rate: int, n_fft: int, nfilt: int, log_scale: bool = True
    ):
        """Convert a spectrogram to Mel scale, following the default
        formula of librosa instead of the one used by torchaudio.
        Also converts to log scale.

        Args:
            sample_rate: Sampling rate of the signal
            n_fft: Number of Fourier features
            nfilt: Number of output mel filters to use
            log_scale: Controls whether a log scale is also applied to the output.
        """
        super().__init__()

        filterbanks = (
            melscale_fbanks(
                int(1 + n_fft // 2),
                n_mels=nfilt,
                sample_rate=sample_rate,
                f_min=0,
                f_max=sample_rate / 2,
                norm="slaney",
                mel_scale="slaney",
            )
            .transpose(0, 1)
            .unsqueeze(0)
        )
        self.register_buffer("fb", filterbanks)
        self.log_scale = log_scale

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor of shape (batch, features, time)
        """
        # dot with filterbank energies
        x = torch.matmul(self.fb.to(x.dtype), x)
        # log features
        # We want to avoid taking the log of zero
        if self.log_scale:
            x = torch.log(x + 2**-24)
        return x

__init__(self, sample_rate, n_fft, nfilt, log_scale=True) special

Convert a spectrogram to Mel scale, following the default formula of librosa instead of the one used by torchaudio. Also converts to log scale.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| sample_rate | int | Sampling rate of the signal | required |
| n_fft | int | Number of Fourier features | required |
| nfilt | int | Number of output mel filters to use | required |
| log_scale | bool | Controls whether a log scale is also applied to the output. | True |
Source code in thunder/quartznet/transform.py
def __init__(
    self, sample_rate: int, n_fft: int, nfilt: int, log_scale: bool = True
):
    """Convert a spectrogram to Mel scale, following the default
    formula of librosa instead of the one used by torchaudio.
    Also converts to log scale.

    Args:
        sample_rate: Sampling rate of the signal
        n_fft: Number of Fourier features
        nfilt: Number of output mel filters to use
        log_scale: Controls whether a log scale is also applied to the output.
    """
    super().__init__()

    filterbanks = (
        melscale_fbanks(
            int(1 + n_fft // 2),
            n_mels=nfilt,
            sample_rate=sample_rate,
            f_min=0,
            f_max=sample_rate / 2,
            norm="slaney",
            mel_scale="slaney",
        )
        .transpose(0, 1)
        .unsqueeze(0)
    )
    self.register_buffer("fb", filterbanks)
    self.log_scale = log_scale

forward(self, x)

Parameters:

Name Type Description Default
x Tensor

Tensor of shape (batch, features, time)

required
Source code in thunder/quartznet/transform.py
@torch.no_grad()
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Args:
        x: Tensor of shape (batch, features, time)
    """
    # dot with filterbank energies
    x = torch.matmul(self.fb.to(x.dtype), x)
    # log features
    # We want to avoid taking the log of zero
    if self.log_scale:
        x = torch.log(x + 2**-24)
    return x
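A minimal sketch converting a power spectrum to log-mel features (same import-path assumption); the input must have `1 + n_fft // 2` frequency bins to match the filterbank:

```python
import torch

from thunder.quartznet.transform import MelScale

mel = MelScale(sample_rate=16000, n_fft=512, nfilt=64)
power_spec = torch.rand(2, 257, 100)  # (batch, 1 + n_fft // 2, time)

log_mel = mel(power_spec)
print(log_mel.shape)  # torch.Size([2, 64, 100])
```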

PowerSpectrum (Module)

Source code in thunder/quartznet/transform.py
class PowerSpectrum(nn.Module):
    def __init__(
        self,
        n_window_size: int = 320,
        n_window_stride: int = 160,
        n_fft: Optional[int] = None,
    ):
        """Calculates the power spectrum of the audio signal, following the same
        method as used in NeMo.

        Args:
            n_window_size: Number of elements in the window size.
            n_window_stride: Number of elements in the window stride.
            n_fft: Number of Fourier features.

        Raises:
            ValueError: Raised when incompatible parameters are passed.
        """
        super().__init__()
        if n_window_size <= 0 or n_window_stride <= 0:
            raise ValueError(
                f"{self} got an invalid value for either n_window_size or "
                f"n_window_stride. Both must be positive ints."
            )
        self.win_length = n_window_size
        self.hop_length = n_window_stride
        self.n_fft = n_fft or 2 ** math.ceil(math.log2(self.win_length))

        window_tensor = torch.hann_window(self.win_length, periodic=False)
        self.register_buffer("window", window_tensor)
        # Stored as an attribute so that torch.stft can be swapped for the
        # patched version before scripting. That way it works correctly when
        # the export target doesn't support FFT, e.g. mobile or ONNX.
        self.stft_func = torch.stft

    def get_sequence_length(self, lengths: torch.Tensor) -> torch.Tensor:
        seq_len = torch.floor(lengths / self.hop_length) + 1
        return seq_len.to(dtype=torch.long)

    @torch.no_grad()
    def forward(
        self, x: torch.Tensor, lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: Tensor of shape (batch, time)
            lengths: corresponding length of each element in the input tensor.
        """
        x = self.stft_func(
            x,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            center=True,
            window=self.window.to(dtype=torch.float),
            return_complex=False,
        )

        # torch returns real, imag; so convert to magnitude
        x = torch.sqrt(x.pow(2).sum(-1))
        # get power spectrum
        x = x.pow(2.0)
        return x, self.get_sequence_length(lengths)

__init__(self, n_window_size=320, n_window_stride=160, n_fft=None) special

Calculates the power spectrum of the audio signal, following the same method as used in NeMo.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| n_window_size | int | Number of elements in the window size. | 320 |
| n_window_stride | int | Number of elements in the window stride. | 160 |
| n_fft | Optional[int] | Number of Fourier features. | None |

Exceptions:

| Type | Description |
| --- | --- |
| ValueError | Raised when incompatible parameters are passed. |

Source code in thunder/quartznet/transform.py
def __init__(
    self,
    n_window_size: int = 320,
    n_window_stride: int = 160,
    n_fft: Optional[int] = None,
):
    """Calculates the power spectrum of the audio signal, following the same
    method as used in NeMo.

    Args:
        n_window_size: Number of elements in the window size.
        n_window_stride: Number of elements in the window stride.
        n_fft: Number of Fourier features.

    Raises:
        ValueError: Raised when incompatible parameters are passed.
    """
    super().__init__()
    if n_window_size <= 0 or n_window_stride <= 0:
        raise ValueError(
            f"{self} got an invalid value for either n_window_size or "
            f"n_window_stride. Both must be positive ints."
        )
    self.win_length = n_window_size
    self.hop_length = n_window_stride
    self.n_fft = n_fft or 2 ** math.ceil(math.log2(self.win_length))

    window_tensor = torch.hann_window(self.win_length, periodic=False)
    self.register_buffer("window", window_tensor)
    # Stored as an attribute so that torch.stft can be swapped for the
    # patched version before scripting. That way it works correctly when
    # the export target doesn't support FFT, e.g. mobile or ONNX.
    self.stft_func = torch.stft

forward(self, x, lengths)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | Tensor of shape (batch, time) | required |
| lengths | Tensor | Corresponding length of each element in the input tensor. | required |
Source code in thunder/quartznet/transform.py
@torch.no_grad()
def forward(
    self, x: torch.Tensor, lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Args:
        x: Tensor of shape (batch, time)
        lengths: corresponding length of each element in the input tensor.
    """
    x = self.stft_func(
        x,
        n_fft=self.n_fft,
        hop_length=self.hop_length,
        win_length=self.win_length,
        center=True,
        window=self.window.to(dtype=torch.float),
        return_complex=False,
    )

    # torch returns real, imag; so convert to magnitude
    x = torch.sqrt(x.pow(2).sum(-1))
    # get power spectrum
    x = x.pow(2.0)
    return x, self.get_sequence_length(lengths)
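A minimal sketch (same import-path assumption). With `center=True` and the defaults above, one second of 16 kHz audio yields `1 + 16000 // 160 = 101` frames and `1 + n_fft // 2 = 257` frequency bins:

```python
import torch

from thunder.quartznet.transform import PowerSpectrum

spec = PowerSpectrum(n_window_size=320, n_window_stride=160, n_fft=512)
waveform = torch.randn(2, 16000)
lengths = torch.tensor([16000, 12000])

x, out_lengths = spec(waveform, lengths)
print(x.shape)      # torch.Size([2, 257, 101])
print(out_lengths)  # tensor([101, 76])
```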

PreEmphasisFilter (Module)

Source code in thunder/quartznet/transform.py
class PreEmphasisFilter(nn.Module):
    def __init__(self, preemph: float = 0.97):
        """Applies preemphasis filtering to the audio signal.
        This is a classic signal processing function that emphasises
        the high-frequency portion of the signal relative to the
        low-frequency portion. It applies a FIR filter of the form:

        `y[n] = x[n] - preemph * x[n-1]`

        Args:
            preemph: Filter control factor.
        """
        super().__init__()
        self.preemph = preemph

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor of shape (batch, time)
        """
        return torch.cat(
            (x[:, 0].unsqueeze(1), x[:, 1:] - self.preemph * x[:, :-1]), dim=1
        )

__init__(self, preemph=0.97) special

Applies preemphasis filtering to the audio signal. This is a classic signal processing function that emphasises the high-frequency portion of the signal relative to the low-frequency portion. It applies a FIR filter of the form:

y[n] = x[n] - preemph * x[n-1]

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| preemph | float | Filter control factor. | 0.97 |
Source code in thunder/quartznet/transform.py
def __init__(self, preemph: float = 0.97):
    """Applies preemphasis filtering to the audio signal.
    This is a classic signal processing function that emphasises
    the high-frequency portion of the signal relative to the
    low-frequency portion. It applies a FIR filter of the form:

    `y[n] = x[n] - preemph * x[n-1]`

    Args:
        preemph: Filter control factor.
    """
    super().__init__()
    self.preemph = preemph

forward(self, x)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | Tensor of shape (batch, time) | required |
Source code in thunder/quartznet/transform.py
@torch.no_grad()
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Args:
        x: Tensor of shape (batch, time)
    """
    return torch.cat(
        (x[:, 0].unsqueeze(1), x[:, 1:] - self.preemph * x[:, :-1]), dim=1
    )
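A minimal sketch checking the filter against the formula above (same import-path assumption); the first sample passes through unchanged:

```python
import torch

from thunder.quartznet.transform import PreEmphasisFilter

pre = PreEmphasisFilter(preemph=0.97)
x = torch.randn(1, 8)

y = pre(x)
# y[n] = x[n] - 0.97 * x[n-1] for n >= 1; y[0] = x[0]
assert torch.equal(y[:, 0], x[:, 0])
assert torch.allclose(y[:, 1:], x[:, 1:] - 0.97 * x[:, :-1])
```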

FilterbankFeatures(sample_rate=16000, n_window_size=320, n_window_stride=160, n_fft=512, preemph=0.97, nfilt=64, dither=1e-05, num_cutout_masks=0, num_time_masks=0, num_freq_masks=0, mask_time_width=50, mask_freq_width=20)

Creates the Filterbank features used in the Quartznet model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| sample_rate | int | Sampling rate of the signal. | 16000 |
| n_window_size | int | Number of elements in the window size. | 320 |
| n_window_stride | int | Number of elements in the window stride. | 160 |
| n_fft | int | Number of Fourier features. | 512 |
| preemph | float | Preemphasis filtering control factor. | 0.97 |
| nfilt | int | Number of output mel filters to use. | 64 |
| dither | float | Amount of dither to add. | 1e-05 |
| num_cutout_masks | int | Number of rectangular cutout masks (SpecCutout). | 0 |
| num_time_masks | int | Number of time masks (SpecAugment). | 0 |
| num_freq_masks | int | Number of frequency masks (SpecAugment). | 0 |
| mask_time_width | int | Width of the masks along the time axis. | 50 |
| mask_freq_width | int | Width of the masks along the frequency axis. | 20 |

Exceptions:

| Type | Description |
| --- | --- |
| ValueError | Raised when both cutout and SpecAugment masks are requested. |

Returns:

| Type | Description |
| --- | --- |
| Module | Module that computes the features from the raw audio tensor. |

Source code in thunder/quartznet/transform.py
def FilterbankFeatures(
    sample_rate: int = 16000,
    n_window_size: int = 320,
    n_window_stride: int = 160,
    n_fft: int = 512,
    preemph: float = 0.97,
    nfilt: int = 64,
    dither: float = 1e-5,
    num_cutout_masks: int = 0,
    num_time_masks: int = 0,
    num_freq_masks: int = 0,
    mask_time_width: int = 50,
    mask_freq_width: int = 20,
) -> nn.Module:
    """Creates the Filterbank features used in the Quartznet model.

    Args:
        sample_rate: Sampling rate of the signal.
        n_window_size: Number of elements in the window size.
        n_window_stride: Number of elements in the window stride.
        n_fft: Number of Fourier features.
        preemph: Preemphasis filtering control factor.
        nfilt: Number of output mel filters to use.
        dither: Amount of dither to add.
        num_cutout_masks: Number of rectangular cutout masks (SpecCutout).
        num_time_masks: Number of time masks (SpecAugment).
        num_freq_masks: Number of frequency masks (SpecAugment).
        mask_time_width: Width of the masks along the time axis.
        mask_freq_width: Width of the masks along the frequency axis.

    Raises:
        ValueError: Raised when both cutout and SpecAugment masks are requested.

    Returns:
        Module that computes the features from the raw audio tensor.
    """
    if num_cutout_masks > 0 and (num_freq_masks + num_time_masks > 0):
        raise ValueError("Cutout and SpecAugment can't be used at the same time.")

    base_modules = [
        Masked(DitherAudio(dither=dither), PreEmphasisFilter(preemph=preemph)),
        PowerSpectrum(
            n_window_size=n_window_size,
            n_window_stride=n_window_stride,
            n_fft=n_fft,
        ),
        Masked(MelScale(sample_rate=sample_rate, n_fft=n_fft, nfilt=nfilt)),
        FeatureBatchNormalizer(),
    ]

    if num_cutout_masks > 0:
        base_modules.append(
            Masked(
                SpecCutout(
                    rect_masks=num_cutout_masks,
                    time_width=mask_time_width,
                    freq_width=mask_freq_width,
                )
            )
        )
    if num_freq_masks + num_time_masks > 0:
        base_modules.append(
            Masked(
                SpecAugment(
                    time_masks=num_time_masks,
                    freq_masks=num_freq_masks,
                    time_width=mask_time_width,
                    freq_width=mask_freq_width,
                )
            )
        )

    return MultiSequential(*base_modules)
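A minimal end-to-end sketch (same import-path assumption; it also assumes `MultiSequential` threads the `(tensor, lengths)` pair through each stage, as the modules above suggest):

```python
import torch

from thunder.quartznet.transform import FilterbankFeatures

features = FilterbankFeatures()   # Quartznet defaults: 16 kHz, 64 mel filters
waveform = torch.randn(2, 16000)  # raw audio, (batch, time)
lengths = torch.tensor([16000, 12000])

feats, feat_lengths = features(waveform, lengths)
print(feats.shape)  # torch.Size([2, 64, 101])
```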

patch_stft(filterbank)

This function patches the FilterbankFeatures module to use a convolution-based STFT instead of torch.stft. That makes it possible to export to ONNX and to run the scripted model directly on ARM CPUs, e.g. inside mobile applications.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| filterbank | Module | The FilterbankFeatures layer to be patched. | required |

Returns:

| Type | Description |
| --- | --- |
| Module | Layer with the STFT operation patched. |

Source code in thunder/quartznet/transform.py
def patch_stft(filterbank: nn.Module) -> nn.Module:
    """This function applies a patch to the FilterbankFeatures to use instead a convolution
    layer based stft. That makes possible to export to onnx and use the scripted model
    directly on arm cpu's, inside mobile applications.

    Args:
        filterbank: the FilterbankFeatures layer to be patched

    Returns:
        Layer with the stft operation patched.
    """
    filterbank[1].stft_func = convolution_stft
    return filterbank
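A hypothetical export flow (same import-path assumption; whether the patched module scripts and runs cleanly depends on the export target and torch version):

```python
import torch

from thunder.quartznet.transform import FilterbankFeatures, patch_stft

features = patch_stft(FilterbankFeatures())  # swap torch.stft for the conv-based stft
scripted = torch.jit.script(features)        # scriptable without native FFT support
scripted.save("filterbank.pt")               # e.g. for a mobile runtime
```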