Transform
Functionality to transform the audio input in the same way that the Quartznet model expects it.
DitherAudio (Module)
Source code in thunder/quartznet/transform.py
class DitherAudio(nn.Module):
    def __init__(self, dither: float = 1e-5):
        """Add some dithering to the audio tensor.

        Note:
            From Wikipedia: dither is an intentionally applied
            form of noise used to randomize quantization error.

        Args:
            dither: Amount of dither to add. It scales the standard-normal
                noise that is summed to the input.
        """
        super().__init__()
        self.dither = dither

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add gaussian noise to the input while in training mode.

        Args:
            x: Tensor of shape (batch, time)

        Returns:
            The input plus scaled noise when the module is in training
            mode, otherwise the input tensor itself, untouched.
        """
        if self.training:
            return x + (self.dither * torch.randn_like(x))
        # Outside of training no noise is added, so the output is deterministic.
        return x
__init__(self, dither=1e-05)
special
Add some dithering to the audio tensor.
Note
From wikipedia: Dither is an intentionally applied form of noise used to randomize quantization error.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dither |
float |
Amount of dither to add. |
1e-05 |
Source code in thunder/quartznet/transform.py
def __init__(self, dither: float = 1e-5):
    """Add some dithering to the audio tensor.

    Note:
        From Wikipedia: dither is an intentionally applied
        form of noise used to randomize quantization error.

    Args:
        dither: Amount of dither to add.
    """
    super().__init__()
    self.dither = dither
forward(self, x)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x |
Tensor |
Tensor of shape (batch, time) |
required |
Source code in thunder/quartznet/transform.py
@torch.no_grad()
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Add gaussian noise to the input while in training mode.

    Args:
        x: Tensor of shape (batch, time)

    Returns:
        The input plus scaled noise when the module is in training
        mode, otherwise the input tensor itself, untouched.
    """
    if self.training:
        return x + (self.dither * torch.randn_like(x))
    # Outside of training no noise is added, so the output is deterministic.
    return x
FeatureBatchNormalizer (Module)
Source code in thunder/quartznet/transform.py
class FeatureBatchNormalizer(nn.Module):
def __init__(self):
"""Normalize batch at the feature dimension."""
super().__init__()
# Guard value forwarded to normalize_tensor as `div_guard` in forward.
self.div_guard = 1e-5
def forward(
self, x: torch.Tensor, lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Normalize the input using a mask built from the lengths.
Args:
x: Tensor of shape (batch, features, time)
lengths: corresponding length of each element in the input tensor.
Returns:
Tuple with the output of normalize_tensor and the lengths, unchanged.
"""
# NOTE(review): the indentation of this listing was lost, so it is not
# visible here whether the return below sits inside the no_grad block;
# confirm against thunder/quartznet/transform.py before editing.
# https://github.com/pytorch/pytorch/issues/45208
# https://github.com/pytorch/pytorch/issues/44768
with torch.no_grad():
# The mask is sized to the last dimension of x.
mask = lengths_to_mask(lengths, x.shape[-1])
return (
# unsqueeze(1) inserts a singleton dimension at position 1 of the mask.
normalize_tensor(x, mask.unsqueeze(1), div_guard=self.div_guard),
lengths,
)
__init__(self)
special
Normalize batch at the feature dimension.
Source code in thunder/quartznet/transform.py
def __init__(self):
"""Normalize batch at the feature dimension."""
super().__init__()
self.div_guard = 1e-5
forward(self, x, lengths)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x |
Tensor |
Tensor of shape (batch, features, time) |
required |
lengths |
Tensor |
corresponding length of each element in the input tensor. |
required |
Source code in thunder/quartznet/transform.py
def forward(
self, x: torch.Tensor, lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Normalize the input using a mask built from the lengths.
Args:
x: Tensor of shape (batch, features, time)
lengths: corresponding length of each element in the input tensor.
Returns:
Tuple with the output of normalize_tensor and the lengths, unchanged.
"""
# NOTE(review): the indentation of this listing was lost, so it is not
# visible here whether the return below sits inside the no_grad block;
# confirm against thunder/quartznet/transform.py before editing.
# https://github.com/pytorch/pytorch/issues/45208
# https://github.com/pytorch/pytorch/issues/44768
with torch.no_grad():
# The mask is sized to the last dimension of x.
mask = lengths_to_mask(lengths, x.shape[-1])
return (
# unsqueeze(1) inserts a singleton dimension at position 1 of the mask.
normalize_tensor(x, mask.unsqueeze(1), div_guard=self.div_guard),
lengths,
)
MelScale (Module)
Source code in thunder/quartznet/transform.py
class MelScale(nn.Module):
    def __init__(
        self, sample_rate: int, n_fft: int, nfilt: int, log_scale: bool = True
    ):
        """Convert a spectrogram to Mel scale, following the default
        formula of librosa instead of the one used by torchaudio.
        Optionally also converts the result to log scale.

        Args:
            sample_rate: Sampling rate of the signal
            n_fft: Number of fourier features
            nfilt: Number of output mel filters to use
            log_scale: Controls if the output should also be applied a log scale.
        """
        super().__init__()
        filterbanks = (
            melscale_fbanks(
                int(1 + n_fft // 2),
                n_mels=nfilt,
                sample_rate=sample_rate,
                f_min=0,
                f_max=sample_rate / 2,
                norm="slaney",
                mel_scale="slaney",
            )
            # Swap the two filterbank axes and add a leading singleton axis so
            # the matrix can be matmul-ed against the input in forward.
            # Presumably ends up as (1, nfilt, n_fft // 2 + 1) - TODO confirm
            # against the melscale_fbanks return shape.
            .transpose(0, 1)
            .unsqueeze(0)
        )
        # Registered as a buffer (not a parameter) under the name "fb".
        self.register_buffer("fb", filterbanks)
        self.log_scale = log_scale

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project the spectrogram onto the mel filterbank.

        Args:
            x: Tensor of shape (batch, features, time)

        Returns:
            The filterbank output, with a log applied on top of it
            when ``log_scale`` is enabled.
        """
        # dot with filterbank energies, cast to the dtype of the input
        x = torch.matmul(self.fb.to(x.dtype), x)
        if self.log_scale:
            # log features; the small offset avoids taking the log of zero
            x = torch.log(x + 2**-24)
        return x
__init__(self, sample_rate, n_fft, nfilt, log_scale=True)
special
Convert a spectrogram to Mel scale, following the default formula of librosa instead of the one used by torchaudio. Optionally also converts the result to log scale (controlled by log_scale).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
sample_rate |
int |
Sampling rate of the signal |
required |
n_fft |
int |
Number of fourier features |
required |
nfilt |
int |
Number of output mel filters to use |
required |
log_scale |
bool |
Controls if the output should also be applied a log scale. |
True |
Source code in thunder/quartznet/transform.py
def __init__(
    self, sample_rate: int, n_fft: int, nfilt: int, log_scale: bool = True
):
    """Convert a spectrogram to Mel scale, following the default
    formula of librosa instead of the one used by torchaudio.
    Optionally also converts the result to log scale.

    Args:
        sample_rate: Sampling rate of the signal
        n_fft: Number of fourier features
        nfilt: Number of output mel filters to use
        log_scale: Controls if the output should also be applied a log scale.
    """
    super().__init__()
    filterbanks = (
        melscale_fbanks(
            int(1 + n_fft // 2),
            n_mels=nfilt,
            sample_rate=sample_rate,
            f_min=0,
            f_max=sample_rate / 2,
            norm="slaney",
            mel_scale="slaney",
        )
        # Swap the two filterbank axes and add a leading singleton axis so
        # the matrix can be matmul-ed against the input in forward.
        .transpose(0, 1)
        .unsqueeze(0)
    )
    # Registered as a buffer (not a parameter) under the name "fb".
    self.register_buffer("fb", filterbanks)
    self.log_scale = log_scale
forward(self, x)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x |
Tensor |
Tensor of shape (batch, features, time) |
required |
Source code in thunder/quartznet/transform.py
@torch.no_grad()
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Project the spectrogram onto the mel filterbank.

    Args:
        x: Tensor of shape (batch, features, time)

    Returns:
        The filterbank output, with a log applied on top of it
        when ``log_scale`` is enabled.
    """
    # dot with filterbank energies, cast to the dtype of the input
    x = torch.matmul(self.fb.to(x.dtype), x)
    if self.log_scale:
        # log features; the small offset avoids taking the log of zero
        x = torch.log(x + 2**-24)
    return x
PowerSpectrum (Module)
Source code in thunder/quartznet/transform.py
class PowerSpectrum(nn.Module):
    def __init__(
        self,
        n_window_size: int = 320,
        n_window_stride: int = 160,
        n_fft: Optional[int] = None,
    ):
        """Calculates the power spectrum of the audio signal, following the same
        method as used in NEMO.

        Args:
            n_window_size: Number of elements in the window size.
            n_window_stride: Number of elements in the window stride.
            n_fft: Number of fourier features. When omitted (or zero), the
                smallest power of two that is >= n_window_size is used.

        Raises:
            ValueError: Raised when incompatible parameters are passed.
        """
        super().__init__()
        if n_window_size <= 0 or n_window_stride <= 0:
            raise ValueError(
                f"{self} got an invalid value for either n_window_size or "
                "n_window_stride. Both must be positive ints."
            )
        self.win_length = n_window_size
        self.hop_length = n_window_stride
        # Falls back to the next power of two of the window length.
        self.n_fft = n_fft or 2 ** math.ceil(math.log2(self.win_length))

        window_tensor = torch.hann_window(self.win_length, periodic=False)
        self.register_buffer("window", window_tensor)
        # This way so that the torch.stft can be changed to the patched version
        # before scripting. That way it works correctly when the export option
        # doesnt support fft, like mobile or onnx.
        self.stft_func = torch.stft

    def get_sequence_length(self, lengths: torch.Tensor) -> torch.Tensor:
        """Compute the output length, in frames, of each element.

        Args:
            lengths: Length of each element of the batch, before the stft.

        Returns:
            Long tensor with ``floor(lengths / hop_length) + 1``.
        """
        seq_len = torch.floor(lengths / self.hop_length) + 1
        return seq_len.to(dtype=torch.long)

    @torch.no_grad()
    def forward(
        self, x: torch.Tensor, lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the power spectrum of a batch of audio signals.

        Args:
            x: Tensor of shape (batch, time)
            lengths: corresponding length of each element in the input tensor.

        Returns:
            Tuple with the power spectrum and the lengths converted to
            number of frames by ``get_sequence_length``.
        """
        # NOTE(review): recent PyTorch releases deprecate return_complex=False
        # for torch.stft - confirm before upgrading; the patched stft_func may
        # rely on this keyword as well.
        x = self.stft_func(
            x,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            center=True,
            window=self.window.to(dtype=torch.float),
            return_complex=False,
        )

        # torch returns real, imag; so convert to magnitude
        x = torch.sqrt(x.pow(2).sum(-1))
        # get power spectrum
        x = x.pow(2.0)
        return x, self.get_sequence_length(lengths)
__init__(self, n_window_size=320, n_window_stride=160, n_fft=None)
special
Calculates the power spectrum of the audio signal, following the same method as used in NEMO.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
n_window_size |
int |
Number of elements in the window size. |
320 |
n_window_stride |
int |
Number of elements in the window stride. |
160 |
n_fft |
Optional[int] |
Number of fourier features. |
None |
Exceptions:
Type | Description |
---|---|
ValueError |
Raised when incompatible parameters are passed. |
Source code in thunder/quartznet/transform.py
def __init__(
    self,
    n_window_size: int = 320,
    n_window_stride: int = 160,
    n_fft: Optional[int] = None,
):
    """Calculates the power spectrum of the audio signal, following the same
    method as used in NEMO.

    Args:
        n_window_size: Number of elements in the window size.
        n_window_stride: Number of elements in the window stride.
        n_fft: Number of fourier features. When omitted (or zero), the
            smallest power of two that is >= n_window_size is used.

    Raises:
        ValueError: Raised when incompatible parameters are passed.
    """
    super().__init__()
    if n_window_size <= 0 or n_window_stride <= 0:
        raise ValueError(
            f"{self} got an invalid value for either n_window_size or "
            "n_window_stride. Both must be positive ints."
        )
    self.win_length = n_window_size
    self.hop_length = n_window_stride
    # Falls back to the next power of two of the window length.
    self.n_fft = n_fft or 2 ** math.ceil(math.log2(self.win_length))

    window_tensor = torch.hann_window(self.win_length, periodic=False)
    self.register_buffer("window", window_tensor)
    # This way so that the torch.stft can be changed to the patched version
    # before scripting. That way it works correctly when the export option
    # doesnt support fft, like mobile or onnx.
    self.stft_func = torch.stft
forward(self, x, lengths)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x |
Tensor |
Tensor of shape (batch, time) |
required |
Source code in thunder/quartznet/transform.py
@torch.no_grad()
def forward(
    self, x: torch.Tensor, lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute the power spectrum of a batch of audio signals.

    Args:
        x: Tensor of shape (batch, time)
        lengths: corresponding length of each element in the input tensor.

    Returns:
        Tuple with the power spectrum and the lengths converted to
        number of frames by ``get_sequence_length``.
    """
    x = self.stft_func(
        x,
        n_fft=self.n_fft,
        hop_length=self.hop_length,
        win_length=self.win_length,
        center=True,
        window=self.window.to(dtype=torch.float),
        return_complex=False,
    )

    # torch returns real, imag; so convert to magnitude
    x = torch.sqrt(x.pow(2).sum(-1))
    # get power spectrum
    x = x.pow(2.0)
    return x, self.get_sequence_length(lengths)
PreEmphasisFilter (Module)
Source code in thunder/quartznet/transform.py
class PreEmphasisFilter(nn.Module):
    def __init__(self, preemph: float = 0.97):
        """Applies preemphasis filtering to the audio signal.
        This is a classic signal processing function to emphasise
        the high frequency portion of the content compared to the
        low frequency. It applies a FIR filter of the form:

        `y[n] = x[n] - preemph * x[n-1]`

        Args:
            preemph: Filter control factor.
        """
        super().__init__()
        self.preemph = preemph

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Filter the signal along the time dimension.

        Args:
            x: Tensor of shape (batch, time)

        Returns:
            Tensor with the same shape as the input. The first sample of
            each row is copied as is, since it has no previous sample.
        """
        # x[:, :1] keeps the first sample as a (batch, 1) slice and, unlike
        # x[:, 0], also works when the time dimension is empty.
        return torch.cat((x[:, :1], x[:, 1:] - self.preemph * x[:, :-1]), dim=1)
__init__(self, preemph=0.97)
special
Applies preemphasis filtering to the audio signal. This is a classic signal processing function to emphasise the high frequency portion of the content compared to the low frequency. It applies a FIR filter of the form:
y[n] = x[n] - preemph * x[n-1]
Parameters:
Name | Type | Description | Default |
---|---|---|---|
preemph |
float |
Filter control factor. |
0.97 |
Source code in thunder/quartznet/transform.py
def __init__(self, preemph: float = 0.97):
    """Applies preemphasis filtering to the audio signal.
    This is a classic signal processing function to emphasise
    the high frequency portion of the content compared to the
    low frequency. It applies a FIR filter of the form:

    `y[n] = x[n] - preemph * x[n-1]`

    Args:
        preemph: Filter control factor.
    """
    super().__init__()
    self.preemph = preemph
forward(self, x)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x |
Tensor |
Tensor of shape (batch, time) |
required |
Source code in thunder/quartznet/transform.py
@torch.no_grad()
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Filter the signal along the time dimension.

    Args:
        x: Tensor of shape (batch, time)

    Returns:
        Tensor with the same shape as the input. The first sample of
        each row is copied as is, since it has no previous sample.
    """
    # x[:, :1] keeps the first sample as a (batch, 1) slice and, unlike
    # x[:, 0], also works when the time dimension is empty.
    return torch.cat((x[:, :1], x[:, 1:] - self.preemph * x[:, :-1]), dim=1)
FilterbankFeatures(sample_rate=16000, n_window_size=320, n_window_stride=160, n_fft=512, preemph=0.97, nfilt=64, dither=1e-05, num_cutout_masks=0, num_time_masks=0, num_freq_masks=0, mask_time_width=50, mask_freq_width=20)
Creates the Filterbank features used in the Quartznet model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
sample_rate |
int |
Sampling rate of the signal. |
16000 |
n_window_size |
int |
Number of elements in the window size. |
320 |
n_window_stride |
int |
Number of elements in the window stride. |
160 |
n_fft |
int |
Number of fourier features. |
512 |
preemph |
float |
Preemphasis filtering control factor. |
0.97 |
nfilt |
int |
Number of output mel filters to use. |
64 |
dither |
float |
Amount of dither to add. |
1e-05 |
Returns:
Type | Description |
---|---|
Module |
Module that computes the features based on raw audio tensor. |
Source code in thunder/quartznet/transform.py
def FilterbankFeatures(
    sample_rate: int = 16000,
    n_window_size: int = 320,
    n_window_stride: int = 160,
    n_fft: int = 512,
    preemph: float = 0.97,
    nfilt: int = 64,
    dither: float = 1e-5,
    num_cutout_masks: int = 0,
    num_time_masks: int = 0,
    num_freq_masks: int = 0,
    mask_time_width: int = 50,
    mask_freq_width: int = 20,
) -> nn.Module:
    """Creates the Filterbank features used in the Quartznet model.

    Args:
        sample_rate: Sampling rate of the signal.
        n_window_size: Number of elements in the window size.
        n_window_stride: Number of elements in the window stride.
        n_fft: Number of fourier features.
        preemph: Preemphasis filtering control factor.
        nfilt: Number of output mel filters to use.
        dither: Amount of dither to add.
        num_cutout_masks: Passed to SpecCutout as ``rect_masks``. The cutout
            layer is only added when this is greater than zero.
        num_time_masks: Passed to SpecAugment as ``time_masks``.
        num_freq_masks: Passed to SpecAugment as ``freq_masks``. The
            SpecAugment layer is only added when the sum of time and
            frequency masks is greater than zero.
        mask_time_width: Passed as ``time_width`` to the augmentation layer.
        mask_freq_width: Passed as ``freq_width`` to the augmentation layer.

    Returns:
        Module that computes the features based on raw audio tensor.

    Raises:
        ValueError: Raised when both Cutout and SpecAugment are requested.
    """
    use_spec_augment = num_freq_masks + num_time_masks > 0
    if num_cutout_masks > 0 and use_spec_augment:
        raise ValueError("Cutout and SpecAugment can't be used at the same time.")

    base_modules = [
        Masked(DitherAudio(dither=dither), PreEmphasisFilter(preemph=preemph)),
        PowerSpectrum(
            n_window_size=n_window_size,
            n_window_stride=n_window_stride,
            n_fft=n_fft,
        ),
        Masked(MelScale(sample_rate=sample_rate, n_fft=n_fft, nfilt=nfilt)),
        FeatureBatchNormalizer(),
    ]

    # The guard above makes the two augmentations mutually exclusive, so at
    # most one extra layer is appended after the normalizer.
    if num_cutout_masks > 0:
        base_modules.append(
            Masked(
                SpecCutout(
                    rect_masks=num_cutout_masks,
                    time_width=mask_time_width,
                    freq_width=mask_freq_width,
                )
            )
        )
    elif use_spec_augment:
        base_modules.append(
            Masked(
                SpecAugment(
                    time_masks=num_time_masks,
                    freq_masks=num_freq_masks,
                    time_width=mask_time_width,
                    freq_width=mask_freq_width,
                )
            )
        )
    return MultiSequential(*base_modules)
patch_stft(filterbank)
This function patches the FilterbankFeatures to use a convolution-layer-based STFT instead. That makes it possible to export to ONNX and use the scripted model directly on ARM CPUs, inside mobile applications.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
filterbank |
Module |
the FilterbankFeatures layer to be patched |
required |
Returns:
Type | Description |
---|---|
Module |
Layer with the stft operation patched. |
Source code in thunder/quartznet/transform.py
def patch_stft(filterbank: nn.Module) -> nn.Module:
    """Patch the FilterbankFeatures to use a convolution-layer-based stft instead.

    That makes it possible to export to onnx and use the scripted model
    directly on arm cpus, inside mobile applications.

    Args:
        filterbank: the FilterbankFeatures layer to be patched

    Returns:
        Layer with the stft operation patched. It is the same object that
        was passed in, modified in place.
    """
    # The layer at index 1 is the one whose stft_func gets replaced.
    # Presumably that is the PowerSpectrum layer - TODO confirm this still
    # matches the layer order built by FilterbankFeatures.
    filterbank[1].stft_func = convolution_stft
    return filterbank