Compatibility
Helper functions to load the Citrinet model from original Nemo released checkpoint files.
CitrinetCheckpoint (BaseCheckpoint)
Trained model weight checkpoints.
Used by download_checkpoint and load_citrinet_checkpoint.
Note
Possible values are stt_en_citrinet_256, stt_en_citrinet_512, stt_en_citrinet_1024, stt_es_citrinet_512
Source code in thunder/citrinet/compatibility.py
class CitrinetCheckpoint(BaseCheckpoint):
    """Trained Citrinet model weight checkpoints.

    Each member maps a checkpoint name to the URL of the corresponding
    `.nemo` file.

    Used by [`download_checkpoint`][thunder.utils.download_checkpoint] and
    [`load_citrinet_checkpoint`][thunder.citrinet.compatibility.load_citrinet_checkpoint].

    Note:
        Possible values are `stt_en_citrinet_256`, `stt_en_citrinet_512`,
        `stt_en_citrinet_1024`, `stt_es_citrinet_512`
    """

    # The three `stt_en_*` entries point at version 1.0.0rc1 of their model.
    stt_en_citrinet_256 = "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_citrinet_256/versions/1.0.0rc1/files/stt_en_citrinet_256.nemo"
    stt_en_citrinet_512 = "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_citrinet_512/versions/1.0.0rc1/files/stt_en_citrinet_512.nemo"
    stt_en_citrinet_1024 = "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_citrinet_1024/versions/1.0.0rc1/files/stt_en_citrinet_1024.nemo"
    # The `stt_es_*` entry points at version 1.0.0.
    stt_es_citrinet_512 = "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_es_citrinet_512/versions/1.0.0/files/stt_es_citrinet_512.nemo"
fix_vocab(vocab_tokens)
Transform the nemo vocab tokens back to the sentencepiece style with the _ prefix
Parameters:
Name | Type | Description | Default |
---|---|---|---|
vocab_tokens |
List[str] |
List of tokens in the vocabulary |
required |
Returns:
Type | Description |
---|---|
List[str] |
New list of tokens with the new prefix |
Source code in thunder/citrinet/compatibility.py
def fix_vocab(vocab_tokens: List[str]) -> List[str]:
    """Convert the nemo vocab tokens back to the sentencepiece style.

    A token that starts with `##` has that two-character marker removed;
    every other token gets the `▁` prefix prepended.

    Args:
        vocab_tokens: List of tokens in the vocabulary

    Returns:
        New list of tokens with the new prefix, in the same order as the input
    """
    return [
        token[2:] if token.startswith("##") else "▁" + token
        for token in vocab_tokens
    ]
load_citrinet_checkpoint(checkpoint, save_folder=None, augment_params=None)
Load from the original nemo checkpoint.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
checkpoint |
Union[str, thunder.citrinet.compatibility.CitrinetCheckpoint] |
Path to local .nemo file or checkpoint to be downloaded locally and loaded. |
required |
save_folder |
str |
Path to save the checkpoint when downloading it. Ignored if you pass a .nemo file as the first argument. |
None |
Returns:
Type | Description |
---|---|
BaseCTCModule |
The model loaded from the checkpoint |
Source code in thunder/citrinet/compatibility.py
def load_citrinet_checkpoint(
    checkpoint: Union[str, CitrinetCheckpoint],
    save_folder: str = None,
    augment_params: AugmentParams = None,
) -> BaseCTCModule:
    """Build a model from an original nemo checkpoint.

    Args:
        checkpoint: Path to a local .nemo file, or a `CitrinetCheckpoint` member to be downloaded locally and loaded.
        save_folder: Path to save the checkpoint when downloading it. Ignored if you pass a .nemo file as the first argument.
        augment_params: Forwarded unchanged to `load_components_from_citrinet_config`.

    Returns:
        The model loaded from the checkpoint, switched to eval mode
    """
    # Same value is handed to the decoder (first argument) and to the module as
    # `encoder_final_dimension`; keep both in sync through a single name.
    final_dimension = 640

    # Only enum members trigger a download; anything else is used as a path as-is.
    archive_path = (
        download_checkpoint(checkpoint, save_folder)
        if isinstance(checkpoint, CitrinetCheckpoint)
        else checkpoint
    )

    # The archive contents are only needed while the model is being assembled,
    # so everything that reads from the extracted files stays inside this block.
    with TemporaryDirectory() as tmp_dir:
        extract_archive(str(archive_path), tmp_dir)
        root = Path(tmp_dir)

        components = load_components_from_citrinet_config(
            root / "model_config.yaml",
            str(root / "tokenizer.model"),
            augment_params,
        )
        encoder, audio_transform, text_transform = components

        decoder = conv1d_decoder(
            final_dimension, num_classes=text_transform.num_tokens
        )
        load_quartznet_weights(
            encoder, decoder, str(root / "model_weights.ckpt")
        )

        model = BaseCTCModule(
            encoder,
            decoder,
            audio_transform,
            text_transform,
            encoder_final_dimension=final_dimension,
        )
        return model.eval()
load_components_from_citrinet_config(config_path, sentencepiece_path, augment_params=None)
Read the important parameters from the config stored inside the .nemo checkpoint.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config_path |
Union[str, pathlib.Path] |
Path to the .yaml file, usually called model_config.yaml |
required |
sentencepiece_path |
Union[str, pathlib.Path] |
Path to the sentencepiece model used to tokenize, usually called tokenizer.model |
required |
Returns:
Type | Description |
---|---|
Tuple[torch.nn.modules.module.Module, torch.nn.modules.module.Module, thunder.text_processing.transform.BatchTextTransformer] |
A tuple containing, in this order, the encoder, the audio transform and the text transform |
Source code in thunder/citrinet/compatibility.py
def load_components_from_citrinet_config(
    config_path: Union[str, Path],
    sentencepiece_path: Union[str, Path],
    augment_params: AugmentParams = None,
) -> Tuple[nn.Module, nn.Module, BatchTextTransformer]:
    """Read the important parameters from the config stored inside the .nemo
    checkpoint.

    Args:
        config_path: Path to the .yaml file, usually called model_config.yaml
        sentencepiece_path: Path to the sentencepiece model used to tokenize, usually called tokenizer.model
        augment_params: Optional mapping of extra parameters. The `dropout` entry
            (default 0.0) is passed to the encoder; every remaining entry is
            forwarded as a keyword argument to the audio transform. The mapping
            passed in is not modified.

    Returns:
        A tuple containing, in this order, the encoder, the audio transform and the text transform
    """
    # Work on a shallow copy: "dropout" is popped below, and popping from the
    # caller's own mapping would silently drop that setting on a second call.
    augment_params = dict(augment_params) if augment_params else {}

    conf = OmegaConf.load(config_path)
    encoder_params = conf["encoder"]
    quartznet_conf = OmegaConf.to_container(encoder_params["jasper"])

    # Only the entries between the first and the last one are used here.
    # NOTE(review): presumably the first/last blocks are built by the encoder
    # itself - confirm against CitrinetEncoder.
    body_config = quartznet_conf[1:-1]

    # "kernel" and "stride" are stored as sequences; only their first element is used.
    filters = [cfg["filters"] for cfg in body_config]
    kernel_sizes = [cfg["kernel"][0] for cfg in body_config]
    strides = [cfg["stride"][0] for cfg in body_config]

    encoder_cfg = {
        "filters": filters,
        "kernel_sizes": kernel_sizes,
        "strides": strides,
        "dropout": augment_params.pop("dropout", 0.0),
    }

    preprocess = conf["preprocessor"]
    sample_rate = preprocess["sample_rate"]
    preprocess_cfg = {
        "sample_rate": sample_rate,
        # Window size/stride are multiplied by the sample rate and truncated to int.
        "n_window_size": int(preprocess["window_size"] * sample_rate),
        "n_window_stride": int(preprocess["window_stride"] * sample_rate),
        "n_fft": preprocess["n_fft"],
        "nfilt": preprocess["features"],
        "dither": preprocess["dither"],
        # Remaining user-supplied entries come last, so they override the values above.
        **augment_params,
    }

    # Prefer the top-level "labels" entry; fall back to the decoder vocabulary.
    labels = conf["labels"] if "labels" in conf else conf["decoder"]["vocabulary"]

    encoder = CitrinetEncoder(**encoder_cfg)
    text_transform = BatchTextTransformer(
        tokens=fix_vocab(labels),
        sentencepiece_model=sentencepiece_path,
    )
    audio_transform = FilterbankFeatures(**preprocess_cfg)

    return (
        encoder,
        audio_transform,
        text_transform,
    )