Skip to content

Compatibility

Helper functions to load the Quartznet model from the checkpoint files originally released by NeMo.

QuartznetCheckpoint (BaseCheckpoint)

Trained model weight checkpoints. Used by download_checkpoint and load_quartznet_checkpoint.

Note

Possible values are QuartzNet15x5Base_En, QuartzNet15x5Base_Zh, QuartzNet5x5LS_En, QuartzNet15x5NR_En, stt_ca_quartznet15x5, stt_it_quartznet15x5, stt_fr_quartznet15x5, stt_es_quartznet15x5, stt_de_quartznet15x5, stt_pl_quartznet15x5, stt_ru_quartznet15x5, stt_en_quartznet15x5, stt_zh_quartznet15x5

Source code in thunder/quartznet/compatibility.py
class QuartznetCheckpoint(BaseCheckpoint):
    """Trained model weight checkpoints.
    Used by [`download_checkpoint`][thunder.utils.download_checkpoint] and
    [`load_quartznet_checkpoint`][thunder.quartznet.compatibility.load_quartznet_checkpoint].

    Each name below is bound to the URL of the matching `.nemo` file.

    Note:
        Possible values are `QuartzNet15x5Base_En`, `QuartzNet15x5Base_Zh`, `QuartzNet5x5LS_En`, `QuartzNet15x5NR_En`,
        `stt_ca_quartznet15x5`, `stt_it_quartznet15x5`, `stt_fr_quartznet15x5`, `stt_es_quartznet15x5`,
        `stt_de_quartznet15x5`, `stt_pl_quartznet15x5`, `stt_ru_quartznet15x5`, `stt_en_quartznet15x5`,
        `stt_zh_quartznet15x5`
    """
    # Checkpoints hosted under the `nemospeechmodels` collection, version 1.0.0a5.
    QuartzNet15x5Base_En = "https://api.ngc.nvidia.com/v2/models/nvidia/nemospeechmodels/versions/1.0.0a5/files/QuartzNet15x5Base-En.nemo"
    QuartzNet15x5Base_Zh = "https://api.ngc.nvidia.com/v2/models/nvidia/nemospeechmodels/versions/1.0.0a5/files/QuartzNet15x5Base-Zh.nemo"
    QuartzNet5x5LS_En = "https://api.ngc.nvidia.com/v2/models/nvidia/nemospeechmodels/versions/1.0.0a5/files/QuartzNet5x5LS-En.nemo"
    QuartzNet15x5NR_En = "https://api.ngc.nvidia.com/v2/models/nvidia/nemospeechmodels/versions/1.0.0a5/files/QuartzNet15x5NR-En.nemo"

    # Checkpoints hosted one model per entry under `nemo/stt_<xx>_quartznet15x5`,
    # version 1.0.0rc1 (the `<xx>` part is presumably a language code -- confirm
    # against the NGC catalog).
    stt_ca_quartznet15x5 = "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_ca_quartznet15x5/versions/1.0.0rc1/files/stt_ca_quartznet15x5.nemo"
    stt_it_quartznet15x5 = "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_it_quartznet15x5/versions/1.0.0rc1/files/stt_it_quartznet15x5.nemo"
    stt_fr_quartznet15x5 = "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_fr_quartznet15x5/versions/1.0.0rc1/files/stt_fr_quartznet15x5.nemo"
    stt_es_quartznet15x5 = "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_es_quartznet15x5/versions/1.0.0rc1/files/stt_es_quartznet15x5.nemo"
    stt_de_quartznet15x5 = "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_de_quartznet15x5/versions/1.0.0rc1/files/stt_de_quartznet15x5.nemo"
    stt_pl_quartznet15x5 = "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_pl_quartznet15x5/versions/1.0.0rc1/files/stt_pl_quartznet15x5.nemo"
    stt_ru_quartznet15x5 = "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_ru_quartznet15x5/versions/1.0.0rc1/files/stt_ru_quartznet15x5.nemo"
    stt_en_quartznet15x5 = "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_quartznet15x5/versions/1.0.0rc1/files/stt_en_quartznet15x5.nemo"
    stt_zh_quartznet15x5 = "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_zh_quartznet15x5/versions/1.0.0rc1/files/stt_zh_quartznet15x5.nemo"

load_components_from_quartznet_config(config_path, augment_params=None)

Read the important parameters from the config stored inside the .nemo checkpoint.

Parameters:

Name Type Description Default
config_path Union[str, pathlib.Path]

Path to the .yaml file, usually called model_config.yaml

required

Returns:

Type Description
Tuple[torch.nn.modules.module.Module, torch.nn.modules.module.Module, thunder.text_processing.transform.BatchTextTransformer]

A tuple containing, in this order, the encoder, the audio transform and the text transform

Source code in thunder/quartznet/compatibility.py
def load_components_from_quartznet_config(
    config_path: Union[str, Path],
    augment_params: AugmentParams = None,
) -> Tuple[nn.Module, nn.Module, BatchTextTransformer]:
    """Read the important parameters from the config stored inside the .nemo
    checkpoint.

    Args:
        config_path: Path to the .yaml file, usually called model_config.yaml
        augment_params: Optional mapping of extra settings. Its `dropout` entry
            (0.0 when absent) is handed to the encoder; every other entry is
            forwarded as a keyword argument to the audio transform. The mapping
            supplied by the caller is left unmodified.

    Returns:
        A tuple containing, in this order, the encoder, the audio transform and the text transform
    """
    # Work on a private copy: `dropout` is popped below, and popping it from the
    # caller's own dict would silently drop the setting if that dict is reused.
    augment_params = dict(augment_params) if augment_params else {}

    conf = OmegaConf.load(config_path)
    encoder_params = conf["encoder"]["params"]
    quartznet_conf = OmegaConf.to_container(encoder_params["jasper"])

    # Only the middle entries are read; the first entry and the last two are
    # skipped (presumably fixed blocks that QuartznetEncoder builds on its own
    # -- TODO confirm against the encoder implementation).
    body_config = quartznet_conf[1:-2]

    filters = [cfg["filters"] for cfg in body_config]
    # `kernel` is stored as a sequence in the config; only its first item is used.
    kernel_sizes = [cfg["kernel"][0] for cfg in body_config]
    encoder_cfg = {
        "filters": filters,
        "kernel_sizes": kernel_sizes,
        "dropout": augment_params.pop("dropout", 0.0),
    }
    preprocess = conf["preprocessor"]["params"]

    sample_rate = preprocess["sample_rate"]
    preprocess_cfg = {
        "sample_rate": sample_rate,
        # window_size / window_stride are multiplied by the sample rate to get
        # integer sample counts (so they are presumably expressed in seconds).
        "n_window_size": int(preprocess["window_size"] * sample_rate),
        "n_window_stride": int(preprocess["window_stride"] * sample_rate),
        "n_fft": preprocess["n_fft"],
        "nfilt": preprocess["features"],
        "dither": preprocess["dither"],
        # Whatever is left after removing `dropout` goes to the audio transform.
        **augment_params,
    }

    # The vocabulary is read from the top-level `labels` key when the config has
    # one, and from the decoder parameters otherwise.
    labels = (
        conf["labels"] if "labels" in conf else conf["decoder"]["params"]["vocabulary"]
    )

    audio_transform = FilterbankFeatures(**preprocess_cfg)
    encoder = QuartznetEncoder(**encoder_cfg)
    text_transform = BatchTextTransformer(
        tokens=OmegaConf.to_container(labels),
    )

    return (
        encoder,
        audio_transform,
        text_transform,
    )

load_quartznet_checkpoint(checkpoint, save_folder=None, augment_params=None)

Load from the original nemo checkpoint.

Parameters:

Name Type Description Default
checkpoint Union[str, thunder.quartznet.compatibility.QuartznetCheckpoint]

Path to local .nemo file or checkpoint to be downloaded locally and loaded.

required
save_folder str

Path to save the checkpoint when downloading it. Ignored if you pass a .nemo file as the first argument.

None

Returns:

Type Description
BaseCTCModule

The model loaded from the checkpoint

Source code in thunder/quartznet/compatibility.py
def load_quartznet_checkpoint(
    checkpoint: Union[str, QuartznetCheckpoint],
    save_folder: str = None,
    augment_params: AugmentParams = None,
) -> BaseCTCModule:
    """Build a Quartznet model from an original nemo checkpoint.

    Args:
        checkpoint: Either the path to a local .nemo file, or a QuartznetCheckpoint
            to be downloaded locally and then loaded.
        save_folder: Path to save the checkpoint when downloading it. Ignored if you pass a .nemo file as the first argument.
        augment_params: Passed through unchanged to load_components_from_quartznet_config.

    Returns:
        The model loaded from the checkpoint, returned from its `.eval()` call
    """
    # Width shared by the decoder input and the module's final encoder dimension.
    encoder_final_dimension = 1024

    # Only known checkpoints are downloaded; anything else is used as a local path.
    nemo_filepath = (
        download_checkpoint(checkpoint, save_folder)
        if isinstance(checkpoint, QuartznetCheckpoint)
        else Path(checkpoint)
    )

    # The archive is unpacked into a throwaway folder that is removed on exit,
    # so everything read from it must be loaded before leaving the `with` block.
    with TemporaryDirectory() as extract_folder:
        extract_archive(str(nemo_filepath), extract_folder)
        extracted = Path(extract_folder)

        components = load_components_from_quartznet_config(
            extracted / "model_config.yaml", augment_params
        )
        encoder, audio_transform, text_transform = components

        decoder = conv1d_decoder(encoder_final_dimension, text_transform.num_tokens)
        load_quartznet_weights(encoder, decoder, str(extracted / "model_weights.ckpt"))

        model = BaseCTCModule(
            encoder,
            decoder,
            audio_transform,
            text_transform,
            encoder_final_dimension=encoder_final_dimension,
        )
        return model.eval()

load_quartznet_weights(encoder, decoder, weights_path)

Load Quartznet model weights from data present inside .nemo file

Parameters:

Name Type Description Default
encoder Module

Encoder module to load the weights into

required
decoder Module

Decoder module to load the weights into

required
weights_path str

Path to the pytorch weights checkpoint

required
Source code in thunder/quartznet/compatibility.py
def load_quartznet_weights(encoder: nn.Module, decoder: nn.Module, weights_path: str):
    """Load Quartznet model weights from data present inside .nemo file

    The checkpoint keys are renamed to match the plain modules, then loaded
    with `strict=True`, so any missing or unexpected key makes
    `load_state_dict` raise instead of being ignored.

    Args:
        encoder: Encoder module to load the weights into
        decoder: Decoder module to load the weights into
        weights_path: Path to the pytorch weights checkpoint
    """
    # Load onto the CPU so checkpoints saved from CUDA tensors also open on
    # machines without a GPU; load_state_dict copies the values into the
    # modules' existing parameters afterwards.
    # NOTE(security): torch.load unpickles the file -- only use checkpoints
    # from a trusted source.
    weights = torch.load(weights_path, map_location="cpu")

    def fix_encoder_name(x: str) -> str:
        # Drops every "encoder." occurrence (not only a leading one) and
        # flattens ".res.0" into ".res".
        x = x.replace("encoder.", "").replace(".res.0", ".res")
        # Add another abstraction layer if it's not a masked conv
        # This is caused by the new Masked wrapper
        if ".conv" not in x:
            parts = x.split(".")
            x = ".".join(parts[:3] + ["layer", "0"] + parts[3:])
        return x

    # We remove the 'encoder.' and 'decoder.' prefix from the weights to enable
    # compatibility to load with plain nn.Modules created by reading the config
    encoder_weights = {
        fix_encoder_name(k): v for k, v in weights.items() if "encoder" in k
    }
    encoder.load_state_dict(encoder_weights, strict=True)

    decoder_weights = {
        k.replace("decoder.decoder_layers.0.", ""): v
        for k, v in weights.items()
        if "decoder" in k
    }
    decoder.load_state_dict(decoder_weights, strict=True)