Compatibility

Helper functions to load the Citrinet model from the original NeMo checkpoint files.

CitrinetCheckpoint (BaseCheckpoint)

Trained model weight checkpoints. Used by `download_checkpoint` and `load_citrinet_checkpoint`.

Note

Possible values are `stt_en_citrinet_256`, `stt_en_citrinet_512`, `stt_en_citrinet_1024`, `stt_es_citrinet_512`

Source code in thunder/citrinet/compatibility.py
class CitrinetCheckpoint(BaseCheckpoint):
    """Trained model weight checkpoints.
    Used by [`download_checkpoint`][thunder.utils.download_checkpoint] and
    [`load_citrinet_checkpoint`][thunder.citrinet.compatibility.load_citrinet_checkpoint].

    Note:
        Possible values are `stt_en_citrinet_256`, `stt_en_citrinet_512`, `stt_en_citrinet_1024`, `stt_es_citrinet_512`
    """
    stt_en_citrinet_256 = "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_citrinet_256/versions/1.0.0rc1/files/stt_en_citrinet_256.nemo"
    stt_en_citrinet_512 = "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_citrinet_512/versions/1.0.0rc1/files/stt_en_citrinet_512.nemo"
    stt_en_citrinet_1024 = "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_citrinet_1024/versions/1.0.0rc1/files/stt_en_citrinet_1024.nemo"
    stt_es_citrinet_512 = "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_es_citrinet_512/versions/1.0.0/files/stt_es_citrinet_512.nemo"
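
A quick sketch of how a checkpoint member is typically used (the save folder is a placeholder; `download_checkpoint` is the helper from `thunder.utils` referenced above):

```python
from thunder.citrinet.compatibility import CitrinetCheckpoint
from thunder.utils import download_checkpoint

# Pick a pretrained checkpoint by enum member...
ckpt = CitrinetCheckpoint.stt_en_citrinet_256
# ...and download the corresponding .nemo file locally.
nemo_path = download_checkpoint(ckpt, "./checkpoints")
```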

fix_vocab(vocab_tokens)

Transform the NeMo vocab tokens back to the sentencepiece style with the ▁ prefix.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `vocab_tokens` | `List[str]` | List of tokens in the vocabulary | *required* |

Returns:

| Type | Description |
| --- | --- |
| `List[str]` | New list of tokens with the new prefix |

Source code in thunder/citrinet/compatibility.py
def fix_vocab(vocab_tokens: List[str]) -> List[str]:
    """Transform the NeMo vocab tokens back to the sentencepiece style
    with the ▁ prefix

    Args:
        vocab_tokens: List of tokens in the vocabulary

    Returns:
        New list of tokens with the new prefix
    """
    out_tokens = []
    for token in vocab_tokens:
        if token.startswith("##"):
            out_tokens.append(token[2:])
        else:
            out_tokens.append("▁" + token)
    return out_tokens
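
For example (illustrative tokens only), word-initial tokens gain the sentencepiece `▁` prefix while `##`-marked continuation tokens lose their marker:

```python
>>> fix_vocab(["hello", "##ing", "world"])
['▁hello', 'ing', '▁world']
```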

load_citrinet_checkpoint(checkpoint, save_folder=None, augment_params=None)

Load a Citrinet model from the original NeMo checkpoint.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `checkpoint` | `Union[str, CitrinetCheckpoint]` | Path to a local .nemo file, or a checkpoint to be downloaded locally and loaded. | *required* |
| `save_folder` | `str` | Path to save the checkpoint when downloading it. Ignored if you pass a .nemo file as the first argument. | `None` |
| `augment_params` | `AugmentParams` | Optional augmentation parameters, forwarded to `load_components_from_citrinet_config`. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `BaseCTCModule` | The model loaded from the checkpoint |

Source code in thunder/citrinet/compatibility.py
def load_citrinet_checkpoint(
    checkpoint: Union[str, CitrinetCheckpoint],
    save_folder: str = None,
    augment_params: AugmentParams = None,
) -> BaseCTCModule:
    """Load a Citrinet model from the original NeMo checkpoint.

    Args:
        checkpoint: Path to a local .nemo file, or a checkpoint to be downloaded locally and loaded.
        save_folder: Path to save the checkpoint when downloading it. Ignored if you pass a .nemo file as the first argument.
        augment_params: Optional augmentation parameters, forwarded to load_components_from_citrinet_config.

    Returns:
        The model loaded from the checkpoint
    """
    if isinstance(checkpoint, CitrinetCheckpoint):
        nemo_filepath = download_checkpoint(checkpoint, save_folder)
    else:
        nemo_filepath = checkpoint

    with TemporaryDirectory() as extract_folder:
        extract_archive(str(nemo_filepath), extract_folder)
        extract_path = Path(extract_folder)
        config_path = extract_path / "model_config.yaml"
        sentencepiece_path = str(extract_path / "tokenizer.model")
        (
            encoder,
            audio_transform,
            text_transform,
        ) = load_components_from_citrinet_config(
            config_path, sentencepiece_path, augment_params
        )

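        # All checkpoints supported here end in a 640-channel epilogue block,
        # so the decoder input size (and encoder_final_dimension) is fixed at 640.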
        decoder = conv1d_decoder(640, num_classes=text_transform.num_tokens)

        weights_path = extract_path / "model_weights.ckpt"
        load_quartznet_weights(encoder, decoder, str(weights_path))
        module = BaseCTCModule(
            encoder,
            decoder,
            audio_transform,
            text_transform,
            encoder_final_dimension=640,
        )
        return module.eval()
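
A minimal usage sketch covering both ways of specifying the checkpoint (the local path is a placeholder):

```python
from thunder.citrinet.compatibility import (
    CitrinetCheckpoint,
    load_citrinet_checkpoint,
)

# Download a pretrained checkpoint and build the module from it...
module = load_citrinet_checkpoint(
    CitrinetCheckpoint.stt_en_citrinet_256, save_folder="./checkpoints"
)

# ...or load a .nemo file already on disk.
module = load_citrinet_checkpoint("path/to/stt_en_citrinet_256.nemo")
```

The returned `BaseCTCModule` is already in eval mode.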

load_components_from_citrinet_config(config_path, sentencepiece_path, augment_params=None)

Read the important parameters from the config stored inside the .nemo checkpoint.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config_path` | `Union[str, Path]` | Path to the .yaml file, usually called model_config.yaml | *required* |
| `sentencepiece_path` | `Union[str, Path]` | Path to the sentencepiece model used to tokenize, usually called tokenizer.model | *required* |
| `augment_params` | `AugmentParams` | Optional augmentation parameters; `dropout` is applied to the encoder and the remaining entries to the audio transform. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `Tuple[nn.Module, nn.Module, BatchTextTransformer]` | A tuple containing, in this order: the encoder, the audio transform, and the text transform |

Source code in thunder/citrinet/compatibility.py
def load_components_from_citrinet_config(
    config_path: Union[str, Path],
    sentencepiece_path: Union[str, Path],
    augment_params: AugmentParams = None,
) -> Tuple[nn.Module, nn.Module, BatchTextTransformer]:
    """Read the important parameters from the config stored inside the .nemo
    checkpoint.

    Args:
        config_path: Path to the .yaml file, usually called model_config.yaml
        sentencepiece_path: Path to the sentencepiece model used to tokenize, usually called tokenizer.model
        augment_params: Optional augmentation parameters; dropout is applied to the encoder and the remaining entries to the audio transform.

    Returns:
        A tuple containing, in this order: the encoder, the audio transform, and the text transform
    """
    augment_params = augment_params or {}

    conf = OmegaConf.load(config_path)
    encoder_params = conf["encoder"]
    quartznet_conf = OmegaConf.to_container(encoder_params["jasper"])

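    # The first and last entries of the jasper config are the prologue and
    # epilogue blocks; only the middle blocks parameterize CitrinetEncoder below.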
    body_config = quartznet_conf[1:-1]

    filters = [cfg["filters"] for cfg in body_config]
    kernel_sizes = [cfg["kernel"][0] for cfg in body_config]
    strides = [cfg["stride"][0] for cfg in body_config]
    encoder_cfg = {
        "filters": filters,
        "kernel_sizes": kernel_sizes,
        "strides": strides,
        "dropout": augment_params.pop("dropout", 0.0),
    }
    preprocess = conf["preprocessor"]

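    # window_size and window_stride are given in seconds in the NeMo config;
    # convert them to sample counts for the filterbank transform.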
    preprocess_cfg = {
        "sample_rate": preprocess["sample_rate"],
        "n_window_size": int(preprocess["window_size"] * preprocess["sample_rate"]),
        "n_window_stride": int(preprocess["window_stride"] * preprocess["sample_rate"]),
        "n_fft": preprocess["n_fft"],
        "nfilt": preprocess["features"],
        "dither": preprocess["dither"],
        **augment_params,
    }

    labels = conf["labels"] if "labels" in conf else conf["decoder"]["vocabulary"]

    encoder = CitrinetEncoder(**encoder_cfg)
    text_transform = BatchTextTransformer(
        tokens=fix_vocab(labels),
        sentencepiece_model=sentencepiece_path,
    )
    audio_transform = FilterbankFeatures(**preprocess_cfg)

    return (
        encoder,
        audio_transform,
        text_transform,
    )
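
A usage sketch assuming you have already extracted a .nemo archive (paths are placeholders; `load_citrinet_checkpoint` performs this extraction for you in a temporary directory):

```python
from thunder.citrinet.compatibility import load_components_from_citrinet_config

encoder, audio_transform, text_transform = load_components_from_citrinet_config(
    "extracted/model_config.yaml",
    "extracted/tokenizer.model",
)
```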