Registry
Functionality to register multiple checkpoints and provide a unified loading interface.
load_pretrained(checkpoint_name, **load_kwargs)
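Load data from any registered checkpoint. Names containing "/" are treated as Hugging Face models; any extra keyword arguments (**load_kwargs) are forwarded to the underlying loading function.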
Parameters:
Name | Type | Description | Default |
---|---|---|---|
checkpoint_name | Union[str, thunder.utils.BaseCheckpoint] | Base checkpoint name, like "QuartzNet5x5LS_En" or "facebook/wav2vec2-large-960h" | required |
Returns:
Type | Description |
---|---|
BaseCTCModule | Object containing the checkpoint data (encoder, decoder, transforms and additional data). |
Source code in thunder/registry.py
def load_pretrained(
    checkpoint_name: Union[str, BaseCheckpoint], **load_kwargs
) -> BaseCTCModule:
    """Load data from any registered checkpoint

    Args:
        checkpoint_name: Base checkpoint name, like "QuartzNet5x5LS_En" or "facebook/wav2vec2-large-960h"

    Returns:
        Object containing the checkpoint data (encoder, decoder, transforms and additional data).
    """
    if isinstance(checkpoint_name, BaseCheckpoint):
        checkpoint_name = checkpoint_name.name
    # Special case when dealing with any huggingface model
    if "/" in checkpoint_name:
        model_data = load_huggingface_checkpoint(checkpoint_name, **load_kwargs)
    else:
        load_fn = CHECKPOINT_REGISTRY[checkpoint_name]
        model_data = load_fn(**load_kwargs)
    return model_data
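The dispatch logic in the source above can be tried out in isolation. The following is a minimal sketch, not the library's code: `DemoCheckpoint`, `REGISTRY`, `load_from_hub` and `load_pretrained_sketch` are hypothetical stand-ins that return plain dictionaries instead of a `BaseCTCModule`, and the checkpoint class is assumed to behave like a standard Python `Enum`.

```python
from enum import Enum
from typing import Callable, Dict, Union


class DemoCheckpoint(Enum):
    """Hypothetical stand-in for a checkpoint enum (assumed Enum-like)."""

    TinyModel = "tiny_model.ckpt"


# Hypothetical registry: maps a checkpoint name to a zero-argument-style loader.
REGISTRY: Dict[str, Callable[..., dict]] = {
    "TinyModel": lambda **kwargs: {"source": "registry", "name": "TinyModel", **kwargs},
}


def load_from_hub(name: str, **kwargs) -> dict:
    """Hypothetical stand-in for the Hugging Face loading branch."""
    return {"source": "hub", "name": name, **kwargs}


def load_pretrained_sketch(checkpoint_name: Union[str, Enum], **load_kwargs) -> dict:
    # Enum members are reduced to their name so both input forms share one path.
    if isinstance(checkpoint_name, Enum):
        checkpoint_name = checkpoint_name.name
    # Names containing "/" are routed to the hub loader.
    if "/" in checkpoint_name:
        return load_from_hub(checkpoint_name, **load_kwargs)
    # Everything else is looked up in the registry; extra kwargs are forwarded.
    return REGISTRY[checkpoint_name](**load_kwargs)


by_enum = load_pretrained_sketch(DemoCheckpoint.TinyModel)
by_name = load_pretrained_sketch("TinyModel")
by_hub = load_pretrained_sketch("some-org/some-model", revision="main")
```

In this sketch, a name without a slash that was never registered raises `KeyError`, because the lookup is a plain dictionary access. The listed source indexes `CHECKPOINT_REGISTRY` the same way, so it presumably behaves similarly.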
register_checkpoint_enum(checkpoints, load_function)
Register all variations of some checkpoint enum with the corresponding loading function
Parameters:
Name | Type | Description | Default |
---|---|---|---|
checkpoints | Type[thunder.utils.BaseCheckpoint] | Base checkpoint class | required |
load_function | Callable[..., thunder.module.BaseCTCModule] | Function to load the checkpoint; must receive one instance of `checkpoints` as first argument | required |
Source code in thunder/registry.py
def register_checkpoint_enum(
    checkpoints: Type[BaseCheckpoint], load_function: CHECKPOINT_LOAD_FUNC_TYPE
):
    """Register all variations of some checkpoint enum with the corresponding loading function

    Args:
        checkpoints: Base checkpoint class
        load_function: function to load the checkpoint,
            must receive one instance of `checkpoints` as first argument"""
    for checkpoint in checkpoints:
        CHECKPOINT_REGISTRY.update(
            {checkpoint.name: partial(load_function, checkpoint)}
        )
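The registration step relies on `functools.partial` to bind each enum member as the first argument of the loading function, so that the registry later only needs the keyword arguments. The sketch below illustrates that pattern on its own; `DemoCheckpoint`, `load_demo`, `REGISTRY` and `register_enum_sketch` are hypothetical names, the file names are placeholders, and the loader returns a dictionary rather than a real model.

```python
from enum import Enum
from functools import partial
from typing import Callable, Dict, Type

# Hypothetical registry, standing in for the module-level one in the library.
REGISTRY: Dict[str, Callable[..., dict]] = {}


class DemoCheckpoint(Enum):
    """Hypothetical checkpoint enum; the values are placeholder file names."""

    Small = "small.ckpt"
    Large = "large.ckpt"


def load_demo(checkpoint: DemoCheckpoint, **load_kwargs) -> dict:
    """Hypothetical loader: receives the enum member as its first argument."""
    return {"variant": checkpoint.name, "file": checkpoint.value, **load_kwargs}


def register_enum_sketch(
    checkpoints: Type[Enum], load_function: Callable[..., dict]
) -> None:
    # One registry entry per enum member, keyed by the member's name.
    for checkpoint in checkpoints:
        REGISTRY[checkpoint.name] = partial(load_function, checkpoint)


register_enum_sketch(DemoCheckpoint, load_demo)

# The stored partial already knows which member to load.
loaded = REGISTRY["Large"](map_location="cpu")
```

Because every member is registered under its `name`, two enums that share a member name would overwrite each other's entries in a registry like this one.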