Skip to content

Compatibility

Helper functions to load huggingface speech recognition models.

load_huggingface_checkpoint(model_name, **model_kwargs)

Load a huggingface model and convert it to a thunder BaseCTCModule.

Parameters:

Name Type Description Default
model_name str

huggingface identifier of the model, like "facebook/wav2vec2-large-960h"

required
model_kwargs Dict[str, Any]

extra keyword arguments to be passed to AutoModelForCTC.from_pretrained

{}

Returns:

Type Description
BaseCTCModule

Thunder module containing the huggingface model.

Source code in thunder/huggingface/compatibility.py
def load_huggingface_checkpoint(
    model_name: str, **model_kwargs: Dict[str, Any]
) -> BaseCTCModule:
    """Build a thunder [`BaseCTCModule`][thunder.module.BaseCTCModule] from a huggingface checkpoint.

    The CTC model and its feature extractor are always loaded. The tokenizer
    is treated as optional: if loading it (or building the decoder from it)
    raises `OSError` or `KeyError`, a `UserWarning` is emitted and the
    returned module has `decoder=None` and `text_transform=None`.

    Args:
        model_name: huggingface identifier of the model, like "facebook/wav2vec2-large-960h"
        model_kwargs: extra keyword arguments forwarded to `AutoModelForCTC.from_pretrained`

    Returns:
        Thunder module wrapping the huggingface model, already switched to eval mode.
    """
    hf_model = AutoModelForCTC.from_pretrained(model_name, **model_kwargs)
    extractor = AutoFeatureExtractor.from_pretrained(model_name)

    # Encoder-only checkpoints ship without a tokenizer. That is not fatal
    # here, but the user has to supply the missing pieces before training,
    # so the failure is downgraded to a warning.
    try:
        text_transform = _tok_to_transform(AutoTokenizer.from_pretrained(model_name))
        decoder = linear_decoder(
            hf_model.base_model.config.hidden_size,
            text_transform.num_tokens,
            decoder_dropout=0.0,
        )
        if hasattr(hf_model, "lm_head"):
            # Reuse the checkpoint's output head weights instead of a random
            # init. Index 2 is presumably the linear layer built by
            # linear_decoder -- confirm against its definition.
            output_layer = decoder[2]
            output_layer.load_state_dict(hf_model.lm_head.state_dict())
    except (OSError, KeyError):
        missing_tokenizer = UserWarning(
            "Huggingface model is missing the tokenizer! decoder and text_transform were not initialized"
        )
        warn(missing_tokenizer)
        # Reset both: text_transform may already have been built when the
        # failure happened while constructing the decoder.
        text_transform, decoder = None, None

    # The encoder adapter and the audio preprocessing share the same masking flag.
    use_mask = extractor.return_attention_mask
    encoder = _HuggingFaceEncoderAdapt(hf_model.base_model, mask_input=use_mask)
    audio_transform = Wav2Vec2Preprocess(mask_input=use_mask)

    thunder_module = BaseCTCModule(
        encoder=encoder,
        decoder=decoder,
        text_transform=text_transform,
        audio_transform=audio_transform,
        encoder_final_dimension=hf_model.base_model.config.hidden_size,
    )
    return thunder_module.eval()

prepare_scriptable_wav2vec(module, quantized=False)

Converts thunder module containing a wav2vec2 model to be scriptable.

Parameters:

Name Type Description Default
module BaseCTCModule

Module containing wav2vec2

required
quantized bool

If true, also performs quantization of the model

False

Returns:

Type Description
BaseCTCModule

Modified module, ready for module.to_torchscript() to be called on it.

Source code in thunder/huggingface/compatibility.py
def prepare_scriptable_wav2vec(
    module: BaseCTCModule, quantized: bool = False
) -> BaseCTCModule:
    """Converts thunder module containing a wav2vec2 model to be scriptable.

    The module is modified in place: its encoder is replaced by the result of
    `import_huggingface_model(module.encoder.original_encoder)` and its
    decoder is rebuilt without its first layer. The same object is returned.

    Args:
        module: Module containing wav2vec2
        quantized: If true, also performs quantization of the model

    Returns:
        Modified module ready to call module.to_torchscript()

    Raises:
        TypeError: if `module.decoder` is None, e.g. when the module was
            loaded from a checkpoint that has no tokenizer.
    """
    # Validate before touching the module. Previously a missing decoder only
    # surfaced at the very end as "'NoneType' object is not subscriptable",
    # after the encoder had already been swapped, leaving the module
    # half-converted. The exception type (TypeError) is kept.
    if module.decoder is None:
        raise TypeError(
            "module.decoder is None; set a decoder on the module before "
            "calling prepare_scriptable_wav2vec"
        )

    imported = import_huggingface_model(module.encoder.original_encoder)
    if quantized:
        # Let the positional convolution prepare itself for scripting before
        # quantizing; the call's return value is not used.
        imported.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
        # Dynamic int8 quantization restricted to the Linear layers.
        imported = torch.quantization.quantize_dynamic(
            imported, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8
        )
    module.encoder = imported
    # Drop the first decoder layer; presumably the imported encoder's output
    # no longer needs it -- confirm against linear_decoder.
    module.decoder = nn.Sequential(*module.decoder[1:])
    return module