Blocks
Basic building blocks to create the Citrinet model
CitrinetBlock (Module)
Source code in thunder/citrinet/blocks.py
```python
class CitrinetBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        repeat: int = 5,
        kernel_size: _size_1_t = (11,),
        stride: _size_1_t = (1,),
        dilation: _size_1_t = (1,),
        dropout: float = 0.0,
        residual: bool = True,
        separable: bool = False,
    ):
        """Citrinet block. This is a refactoring of the JasperBlock present in the NeMo toolkit,
        but simplified to only support the new Citrinet model. The biggest change is that
        the dense residual used in Jasper is not supported here.

        Args:
            in_channels: Number of input channels
            out_channels: Number of output channels
            repeat: Repetitions inside block.
            kernel_size: Kernel size.
            stride: Stride of each repetition.
            dilation: Dilation of each repetition.
            dropout: Dropout used before each activation.
            residual: Controls the use of residual connection.
            separable: Controls the use of separable convolutions.
        """
        super().__init__()

        padding_val = get_same_padding(kernel_size[0], 1, dilation[0])

        inplanes_loop = in_channels
        conv = []

        for _ in range(repeat - 1):
            conv.extend(
                _get_conv_bn_layer(
                    inplanes_loop,
                    out_channels,
                    kernel_size=kernel_size,
                    stride=(1,),  # Only stride the last one
                    dilation=dilation,
                    padding=padding_val,
                    separable=separable,
                    bias=False,
                )
            )

            conv.extend(_get_act_dropout_layer(drop_prob=dropout))

            inplanes_loop = out_channels

        padding_val = get_same_padding(kernel_size[0], stride[0], dilation[0])

        conv.extend(
            _get_conv_bn_layer(
                inplanes_loop,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                padding=padding_val,
                separable=separable,
                bias=False,
            )
        )

        conv.append(Masked(SqueezeExcite(out_channels, reduction_ratio=8)))

        self.mconv = MultiSequential(*conv)

        if residual:
            stride_residual = stride if stride[0] == 1 else stride[0]

            self.res = MultiSequential(
                *_get_conv_bn_layer(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=stride_residual,
                    bias=False,
                )
            )
        else:
            self.res = None

        self.mout = MultiSequential(*_get_act_dropout_layer(drop_prob=dropout))

    def forward(
        self, x: torch.Tensor, lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: Tensor of shape (batch, features, time) where #features == inplanes
            lengths: corresponding length of each element in the input tensor.

        Returns:
            Result of applying the block on the input, and corresponding output lengths
        """
        # compute forward convolutions
        out, lengths_out = self.mconv(x, lengths)

        # compute the residuals
        if self.res is not None:
            res_out, _ = self.res(x, lengths)
            out = out + res_out

        # compute the output
        out, lengths_out = self.mout(out, lengths_out)
        return out, lengths_out
```
__init__(self, in_channels, out_channels, repeat=5, kernel_size=(11,), stride=(1,), dilation=(1,), dropout=0.0, residual=True, separable=False)
Citrinet block. This is a refactoring of the JasperBlock present in the NeMo toolkit, but simplified to only support the new Citrinet model. The biggest change is that the dense residual used in Jasper is not supported here.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| in_channels | int | Number of input channels | required |
| out_channels | int | Number of output channels | required |
| repeat | int | Repetitions inside block. | 5 |
| kernel_size | Union[int, Tuple[int]] | Kernel size. | (11,) |
| stride | Union[int, Tuple[int]] | Stride of each repetition. | (1,) |
| dilation | Union[int, Tuple[int]] | Dilation of each repetition. | (1,) |
| dropout | float | Dropout used before each activation. | 0.0 |
| residual | bool | Controls the use of residual connection. | True |
| separable | bool | Controls the use of separable convolutions. | False |
Source code in thunder/citrinet/blocks.py
```python
def __init__(
    self,
    in_channels: int,
    out_channels: int,
    repeat: int = 5,
    kernel_size: _size_1_t = (11,),
    stride: _size_1_t = (1,),
    dilation: _size_1_t = (1,),
    dropout: float = 0.0,
    residual: bool = True,
    separable: bool = False,
):
    """Citrinet block. This is a refactoring of the JasperBlock present in the NeMo toolkit,
    but simplified to only support the new Citrinet model. The biggest change is that
    the dense residual used in Jasper is not supported here.

    Args:
        in_channels: Number of input channels
        out_channels: Number of output channels
        repeat: Repetitions inside block.
        kernel_size: Kernel size.
        stride: Stride of each repetition.
        dilation: Dilation of each repetition.
        dropout: Dropout used before each activation.
        residual: Controls the use of residual connection.
        separable: Controls the use of separable convolutions.
    """
    super().__init__()

    padding_val = get_same_padding(kernel_size[0], 1, dilation[0])

    inplanes_loop = in_channels
    conv = []

    for _ in range(repeat - 1):
        conv.extend(
            _get_conv_bn_layer(
                inplanes_loop,
                out_channels,
                kernel_size=kernel_size,
                stride=(1,),  # Only stride the last one
                dilation=dilation,
                padding=padding_val,
                separable=separable,
                bias=False,
            )
        )

        conv.extend(_get_act_dropout_layer(drop_prob=dropout))

        inplanes_loop = out_channels

    padding_val = get_same_padding(kernel_size[0], stride[0], dilation[0])

    conv.extend(
        _get_conv_bn_layer(
            inplanes_loop,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding_val,
            separable=separable,
            bias=False,
        )
    )

    conv.append(Masked(SqueezeExcite(out_channels, reduction_ratio=8)))

    self.mconv = MultiSequential(*conv)

    if residual:
        stride_residual = stride if stride[0] == 1 else stride[0]

        self.res = MultiSequential(
            *_get_conv_bn_layer(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=stride_residual,
                bias=False,
            )
        )
    else:
        self.res = None

    self.mout = MultiSequential(*_get_act_dropout_layer(drop_prob=dropout))
```
forward(self, x, lengths)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Tensor of shape (batch, features, time) where #features == inplanes | required |
| lengths | Tensor | corresponding length of each element in the input tensor. | required |

Returns:

| Type | Description |
|---|---|
| Tuple[torch.Tensor, torch.Tensor] | Result of applying the block on the input, and corresponding output lengths |
Source code in thunder/citrinet/blocks.py
```python
def forward(
    self, x: torch.Tensor, lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Args:
        x: Tensor of shape (batch, features, time) where #features == inplanes
        lengths: corresponding length of each element in the input tensor.

    Returns:
        Result of applying the block on the input, and corresponding output lengths
    """
    # compute forward convolutions
    out, lengths_out = self.mconv(x, lengths)

    # compute the residuals
    if self.res is not None:
        res_out, _ = self.res(x, lengths)
        out = out + res_out

    # compute the output
    out, lengths_out = self.mout(out, lengths_out)
    return out, lengths_out
```
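For reference, a minimal usage sketch of the block. This assumes the class is importable from thunder/citrinet/blocks.py as shown above, and that lengths is an integer tensor holding the number of valid frames per batch element; the specific shapes below are illustrative:

```python
import torch
from thunder.citrinet.blocks import CitrinetBlock

# Build a block that keeps the time dimension (default stride of 1)
block = CitrinetBlock(in_channels=256, out_channels=256, separable=True)

x = torch.randn(4, 256, 100)               # (batch, features, time)
lengths = torch.tensor([100, 90, 80, 70])  # valid frames per batch element

out, out_lengths = block(x, lengths)
print(out.shape)  # expected: torch.Size([4, 256, 100]), time preserved since stride == 1
```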
SqueezeExcite (Module)
Source code in thunder/citrinet/blocks.py
```python
class SqueezeExcite(nn.Module):
    def __init__(
        self,
        channels: int,
        reduction_ratio: int,
    ):
        """
        Squeeze-and-Excitation sub-module.

        Args:
            channels: Input number of channels.
            reduction_ratio: Reduction ratio for "squeeze" layer.
        """
        super().__init__()
        self.pool = nn.AdaptiveAvgPool1d(1)  # global average pooling over the full time axis
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction_ratio, bias=False),
            nn.ReLU(True),
            nn.Linear(channels // reduction_ratio, channels, bias=False),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor of shape [batch, channels, time]

        Returns:
            Tensor of shape [batch, channels, time]
        """
        y = self.pool(x)  # [B, C, 1]
        y = y.transpose(1, -1)  # [B, 1, C]
        y = self.fc(y)  # [B, 1, C]
        y = y.transpose(1, -1)  # [B, C, 1]
        y = torch.sigmoid(y)
        return x * y
```
__init__(self, channels, reduction_ratio)
Squeeze-and-Excitation sub-module.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| channels | int | Input number of channels. | required |
| reduction_ratio | int | Reduction ratio for "squeeze" layer. | required |
Source code in thunder/citrinet/blocks.py
```python
def __init__(
    self,
    channels: int,
    reduction_ratio: int,
):
    """
    Squeeze-and-Excitation sub-module.

    Args:
        channels: Input number of channels.
        reduction_ratio: Reduction ratio for "squeeze" layer.
    """
    super().__init__()
    self.pool = nn.AdaptiveAvgPool1d(1)  # global average pooling over the full time axis
    self.fc = nn.Sequential(
        nn.Linear(channels, channels // reduction_ratio, bias=False),
        nn.ReLU(True),
        nn.Linear(channels // reduction_ratio, channels, bias=False),
    )
```
forward(self, x)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Tensor of shape [batch, channels, time] | required |

Returns:

| Type | Description |
|---|---|
| Tensor | Tensor of shape [batch, channels, time] |
Source code in thunder/citrinet/blocks.py
```python
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Args:
        x: Tensor of shape [batch, channels, time]

    Returns:
        Tensor of shape [batch, channels, time]
    """
    y = self.pool(x)  # [B, C, 1]
    y = y.transpose(1, -1)  # [B, 1, C]
    y = self.fc(y)  # [B, 1, C]
    y = y.transpose(1, -1)  # [B, C, 1]
    y = torch.sigmoid(y)
    return x * y
```
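A minimal sketch of the channel-reweighting behavior, assuming the class is importable as shown on this page; the shapes are illustrative:

```python
import torch
from thunder.citrinet.blocks import SqueezeExcite

se = SqueezeExcite(channels=256, reduction_ratio=8)

x = torch.randn(2, 256, 50)  # [batch, channels, time]
y = se(x)

assert y.shape == x.shape  # the module only rescales channels, so the shape is preserved
```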
CitrinetEncoder(filters, kernel_sizes, strides, feat_in=80, dropout=0.0)
Basic Citrinet encoder setup.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| filters | List[int] | List of filter sizes used to create the encoder blocks. | required |
| kernel_sizes | List[int] | List of kernel sizes corresponding to each filter size. | required |
| strides | List[int] | List of strides corresponding to each filter size. | required |
| feat_in | int | Number of input features to the model. | 80 |
| dropout | float | Dropout rate used inside each encoder block. | 0.0 |

Returns:

| Type | Description |
|---|---|
| Module | Pytorch model corresponding to the encoder. |
Source code in thunder/citrinet/blocks.py
```python
def CitrinetEncoder(
    filters: List[int],
    kernel_sizes: List[int],
    strides: List[int],
    feat_in: int = 80,
    dropout: float = 0.0,
) -> nn.Module:
    """Basic Citrinet encoder setup.

    Args:
        filters: List of filter sizes used to create the encoder blocks.
        kernel_sizes: List of kernel sizes corresponding to each filter size.
        strides: List of strides corresponding to each filter size.
        feat_in: Number of input features to the model.
        dropout: Dropout rate used inside each encoder block.

    Returns:
        Pytorch model corresponding to the encoder.
    """
    return MultiSequential(
        stem(feat_in),
        *body(filters, kernel_sizes, strides, dropout),
    )
```
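For illustration, a hedged sketch of building and running a small encoder. The filter, kernel, and stride values here are arbitrary and not the published Citrinet configuration:

```python
import torch
from thunder.citrinet.blocks import CitrinetEncoder

encoder = CitrinetEncoder(
    filters=[256, 256, 512],
    kernel_sizes=[11, 13, 15],
    strides=[1, 2, 1],
    feat_in=80,
)

x = torch.randn(1, 80, 200)  # (batch, mel features, time)
lengths = torch.tensor([200])

out, out_lengths = encoder(x, lengths)
print(out.shape)  # expected: (1, 640, ~100); the final body block always outputs 640 channels
```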
body(filters, kernel_size, strides, dropout=0.0)
Creates the body of the Citrinet model, that is, the middle part of the network.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| filters | List[int] | List of filters inside each block in the body. | required |
| kernel_size | List[int] | Corresponding list of kernel sizes for each block. Should have the same length as the first argument. | required |
| strides | List[int] | Corresponding list of strides for each block. Should have the same length as the first argument. | required |
| dropout | float | Dropout rate used inside each block. | 0.0 |

Returns:

| Type | Description |
|---|---|
| List[thunder.citrinet.blocks.CitrinetBlock] | List of layers that form the body of the network. |
Source code in thunder/citrinet/blocks.py
```python
def body(
    filters: List[int],
    kernel_size: List[int],
    strides: List[int],
    dropout: float = 0.0,
) -> List[CitrinetBlock]:
    """Creates the body of the Citrinet model, that is, the middle part of the network.

    Args:
        filters: List of filters inside each block in the body.
        kernel_size: Corresponding list of kernel sizes for each block. Should have the same length as the first argument.
        strides: Corresponding list of strides for each block. Should have the same length as the first argument.
        dropout: Dropout rate used inside each block.

    Returns:
        List of layers that form the body of the network.
    """
    layers = []
    f_in = 256
    for f, k, s in zip(filters, kernel_size, strides):
        layers.append(
            CitrinetBlock(
                f_in, f, kernel_size=(k,), stride=(s,), separable=True, dropout=dropout
            )
        )
        f_in = f
    layers.append(
        CitrinetBlock(
            f_in,
            640,
            repeat=1,
            kernel_size=(41,),
            residual=False,
            separable=True,
            dropout=dropout,
        )
    )
    return layers
```
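A small sketch of what body returns; the argument values are illustrative:

```python
from thunder.citrinet.blocks import body

blocks = body(filters=[256, 512], kernel_size=[13, 15], strides=[1, 2])
print(len(blocks))  # 3: one block per filter, plus the fixed 640-channel tail block
```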
stem(feat_in)
Creates the Citrinet stem, that is, the first block of the model, which processes the input directly.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| feat_in | int | Number of input features | required |

Returns:

| Type | Description |
|---|---|
| CitrinetBlock | Citrinet stem block |
Source code in thunder/citrinet/blocks.py
```python
def stem(feat_in: int) -> CitrinetBlock:
    """Creates the Citrinet stem, that is, the first block of the model, which processes the input directly.

    Args:
        feat_in: Number of input features

    Returns:
        Citrinet stem block
    """
    return CitrinetBlock(
        feat_in,
        256,
        repeat=1,
        kernel_size=(5,),
        residual=False,
        separable=True,
    )
```
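For completeness, a minimal sketch of the stem on its own, assuming 80 input features (the encoder default):

```python
import torch
from thunder.citrinet.blocks import stem

first_block = stem(feat_in=80)  # maps 80 input features to 256 channels, no residual

x = torch.randn(2, 80, 120)
lengths = torch.tensor([120, 100])
out, out_lengths = first_block(x, lengths)
print(out.shape)  # expected: torch.Size([2, 256, 120])
```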