Compatibility
Helper functions to load huggingface speech recognition models.
load_huggingface_checkpoint(model_name, **model_kwargs)
Load huggingface model and convert to thunder BaseCTCModule
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_name |
str |
huggingface identifier of the model, like "facebook/wav2vec2-large-960h" |
required |
model_kwargs |
Dict[str, Any] |
extra keyword arguments to be passed to `AutoModelForCTC.from_pretrained` |
{} |
Returns:
Type | Description |
---|---|
BaseCTCModule |
Thunder module containing the huggingface model. |
Source code in thunder/huggingface/compatibility.py
def load_huggingface_checkpoint(
    model_name: str, **model_kwargs: Any
) -> BaseCTCModule:
    """Load huggingface model and convert to thunder [`BaseCTCModule`][thunder.module.BaseCTCModule]

    The model weights and the feature extractor are always loaded. The
    tokenizer is optional: if loading it (or building the decoder from it)
    raises `OSError` or `KeyError`, a `UserWarning` is emitted and the
    returned module is built with `decoder=None` and `text_transform=None`.

    Args:
        model_name: huggingface identifier of the model, like "facebook/wav2vec2-large-960h"
        model_kwargs: extra keyword arguments to be passed to `AutoModelForCTC.from_pretrained`

    Returns:
        Thunder module containing the huggingface model, as returned by `module.eval()`.
    """
    model = AutoModelForCTC.from_pretrained(model_name, **model_kwargs)
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)

    # Both values are used in more than one place below; read them once.
    hidden_size = model.base_model.config.hidden_size
    mask_input = feature_extractor.return_attention_mask

    # Some models only contain the encoder, and no tokenizer
    # In that case we need to warn the user to fix it before training
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        text_transform = _tok_to_transform(tokenizer)
        decoder = linear_decoder(
            hidden_size,
            text_transform.num_tokens,
            decoder_dropout=0.0,
        )
        if hasattr(model, "lm_head"):
            # Reuse the pretrained output head weights.
            # NOTE(review): index 2 is presumably the linear projection
            # inside `linear_decoder` -- confirm against its definition.
            decoder[2].load_state_dict(model.lm_head.state_dict())
    except (OSError, KeyError):
        warn(
            UserWarning(
                "Huggingface model is missing the tokenizer! decoder and text_transform were not initialized"
            )
        )
        text_transform = None
        decoder = None

    module = BaseCTCModule(
        encoder=_HuggingFaceEncoderAdapt(
            model.base_model,
            mask_input=mask_input,
        ),
        decoder=decoder,
        text_transform=text_transform,
        audio_transform=Wav2Vec2Preprocess(
            mask_input=mask_input,
        ),
        encoder_final_dimension=hidden_size,
    )
    return module.eval()
prepare_scriptable_wav2vec(module, quantized=False)
Converts thunder module containing a wav2vec2 model to be scriptable.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
module |
BaseCTCModule |
Module containing wav2vec2 |
required |
quantized |
bool |
If true, also performs quantization of the model |
False |
Returns:
Type | Description |
---|---|
BaseCTCModule |
Modified module ready to call module.to_torchscript() |
Source code in thunder/huggingface/compatibility.py
def prepare_scriptable_wav2vec(
    module: BaseCTCModule, quantized: bool = False
) -> BaseCTCModule:
    """Converts thunder module containing a wav2vec2 model to be scriptable.

    The module is modified in place: its encoder is replaced by the result of
    `import_huggingface_model` applied to the wrapped original encoder, and
    its decoder is rebuilt without its first layer.

    Args:
        module: Module containing wav2vec2
        quantized: If true, also performs quantization of the model

    Returns:
        Modified module ready to call module.to_torchscript()
    """
    scriptable_encoder = import_huggingface_model(module.encoder.original_encoder)

    if quantized:
        # The positional convolution is prepared for scripting before the
        # dynamic quantization pass over the linear layers.
        positional_conv = scriptable_encoder.encoder.transformer.pos_conv_embed
        positional_conv.__prepare_scriptable__()
        scriptable_encoder = torch.quantization.quantize_dynamic(
            scriptable_encoder,
            qconfig_spec={torch.nn.Linear},
            dtype=torch.qint8,
        )

    module.encoder = scriptable_encoder
    # Drop the first decoder layer.
    # NOTE(review): presumably it is not needed with the imported encoder's
    # output -- confirm against `linear_decoder`.
    trailing_layers = module.decoder[1:]
    module.decoder = nn.Sequential(*trailing_layers)
    return module