Skip to content

Quick reference guide

How to load a Quartznet/Citrinet .nemo file?

from thunder.quartznet.compatibility import load_quartznet_checkpoint
from thunder.citrinet.compatibility import load_citrinet_checkpoint

module = load_quartznet_checkpoint("/path/to/quartznet.nemo")
module = load_citrinet_checkpoint("/path/to/citrinet.nemo")

How to export models with special restrictions?

Case 1: Using Quartznet or Citrinet on platforms that doesnt support FFT (android, onnx):

from thunder.registry import load_pretrained
from thunder.quartznet.transform import patch_stft
import torch

module = load_pretrained("QuartzNet5x5LS_En")
module.audio_transform = patch_stft(module.audio_transform)
module.to_torchscript("model_ready_for_inference.pt")

Case 2: Wav2vec 2.0 using torchscript

from thunder.registry import load_pretrained
from thunder.huggingface.compatibility import prepare_scriptable_wav2vec

module = load_pretrained("facebook/wav2vec2-large-960h")
module = prepare_scriptable_wav2vec(module)
module.to_torchscript("model_ready_for_inference.pt")

What if I want the probabilities instead of the captions during inference?

Instead of model.predict(audio), use just model(audio)

import torch
import torchaudio

model = torch.jit.load("model_ready_for_inference.pt")
loader = torch.jit.load("audio_loader.pt")
# Open audio
audio = loader("audio_file.wav")
probs = model(audio)
# If you also want the transcriptions:
transcriptions = model.text_transform.decode_prediction(probs.argmax(1))

How to finetune a model if I already have the nemo manifests prepared?

import pytorch_lightning as pl

from thunder.data.datamodule import ManifestDatamodule
from thunder.registry import load_pretrained
from thunder.callbacks import FinetuneEncoderDecoder

dm = ManifestDatamodule(
    train_manifest="/path/to/train_manifest.json",
    val_manifest="/path/to/val_manifest.json",
    test_manifest="/path/to/test_manifest.json",
)

model = load_pretrained("QuartzNet5x5LS_En")

trainer = pl.Trainer(
    accelerator='gpu',
    devices=-1, # Use all gpus
    max_epochs=10,
    callbacks=[FinetuneEncoderDecoder(unfreeze_encoder_at_epoch=1)],
)

trainer.fit(model=model, datamodule=dm)

How to get the tokens from my dataset?

from thunder.text_processing.tokenizer import char_tokenizer, get_most_frequent_tokens

my_datamodule = CustomDatamodule(...)
my_datamodule.prepare_data()
my_datamodule.setup(None)

train_corpus = " ".join(my_datamodule.train_dataset.all_outputs())
tokens = get_most_frequent_tokens(train_corpus, char_tokenizer)