Skip to content

Registry

Functionality to register multiple checkpoints and provide a unified interface for loading them.

load_pretrained(checkpoint_name, **load_kwargs)

Load data from any registered checkpoint

Parameters:

Name Type Description Default
checkpoint_name Union[str, thunder.utils.BaseCheckpoint]

Base checkpoint name, like "QuartzNet5x5LS_En" or "facebook/wav2vec2-large-960h"

required

Returns:

Type Description
BaseCTCModule

Object containing the checkpoint data (encoder, decoder, transforms and additional data).

Source code in thunder/registry.py
def load_pretrained(
    checkpoint_name: Union[str, BaseCheckpoint], **load_kwargs
) -> BaseCTCModule:
    """Load data from any registered checkpoint

    Args:
        checkpoint_name: Base checkpoint name, like "QuartzNet5x5LS_En" or "facebook/wav2vec2-large-960h"
        load_kwargs: Extra keyword arguments forwarded unchanged to the underlying loading function.

    Returns:
        Object containing the checkpoint data (encoder, decoder, transforms and additional data).

    Raises:
        KeyError: If `checkpoint_name` contains no "/" and is not present in the checkpoint registry.
    """
    # Checkpoint instances are resolved to their registry key (the member name).
    if isinstance(checkpoint_name, BaseCheckpoint):
        checkpoint_name = checkpoint_name.name
    # Special case when dealing with any huggingface model: names containing
    # a "/" bypass the registry entirely.
    if "/" in checkpoint_name:
        return load_huggingface_checkpoint(checkpoint_name, **load_kwargs)
    try:
        load_fn = CHECKPOINT_REGISTRY[checkpoint_name]
    except KeyError:
        # Same exception type as the plain lookup, but with an actionable
        # message listing the names that would have been accepted.
        available = ", ".join(sorted(map(str, CHECKPOINT_REGISTRY)))
        raise KeyError(
            f"Unknown checkpoint {checkpoint_name!r}. "
            f"Registered checkpoints: {available or 'none'}"
        ) from None
    return load_fn(**load_kwargs)

register_checkpoint_enum(checkpoints, load_function)

Register all variations of some checkpoint enum with the corresponding loading function

Parameters:

Name Type Description Default
checkpoints Type[thunder.utils.BaseCheckpoint]

Base checkpoint class

required
load_function Callable[..., thunder.module.BaseCTCModule]

Function that loads the checkpoint; it must accept an instance of `checkpoints` as its first argument.

required
Source code in thunder/registry.py
def register_checkpoint_enum(
    checkpoints: Type[BaseCheckpoint], load_function: CHECKPOINT_LOAD_FUNC_TYPE
):
    """Register every member of a checkpoint enum under its name.

    Each member is stored in the checkpoint registry, keyed by its `name`,
    as a callable that invokes `load_function` with that member already
    bound as the first positional argument.

    Args:
        checkpoints: Base checkpoint class whose members are iterated and registered.
        load_function: function to load the checkpoint,
            must receive one instance of `checkpoints` as first argument
    """
    # partial() binds the member eagerly, so each entry keeps its own checkpoint.
    CHECKPOINT_REGISTRY.update(
        (member.name, partial(load_function, member)) for member in checkpoints
    )