
Finetune

Module that implements easy finetuning of any model in the library.
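
As an illustration of the intended workflow, here is a minimal sketch. The checkpoint name and the data module are placeholders, and the training loop assumes the module is used as a regular PyTorch Lightning module.

import pytorch_lightning as pl

from thunder.finetune import FinetuneCTCModule

# Placeholder checkpoint name - replace with a checkpoint supported by load_pretrained.
module = FinetuneCTCModule(checkpoint_name="QuartzNet5x5LS-En")

# my_datamodule is a hypothetical LightningDataModule providing the finetuning data.
trainer = pl.Trainer(max_epochs=10)
trainer.fit(module, datamodule=my_datamodule)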

FinetuneCTCModule (BaseCTCModule)

Source code in thunder/finetune.py
class FinetuneCTCModule(BaseCTCModule):
    def __init__(
        self,
        checkpoint_name: str,
        checkpoint_kwargs: Dict[str, Any] = None,
        decoder_class: ModuleBuilderType = None,
        decoder_kwargs: Dict[str, Any] = None,
        tokens: List[str] = None,
        text_kwargs: Dict[str, Any] = None,
        optimizer_class: OptimizerBuilderType = torch.optim.AdamW,
        optimizer_kwargs: Dict[str, Any] = None,
        lr_scheduler_class: SchedulerBuilderType = None,
        lr_scheduler_kwargs: Dict[str, Any] = None,
    ):
        """Generic finetune module, load any combination of encoder/decoder and custom tokens

        Args:
            checkpoint_name: Name of the base checkpoint to load
            checkpoint_kwargs: Additional kwargs to the checkpoint loading function.
            decoder_class: Optional class to override the decoder loaded from the checkpoint.
            decoder_kwargs: Additional kwargs to the decoder_class.
            tokens: If passed a list of tokens, the decoder from the base checkpoint will be replaced by the one in decoder_class, and a new text transform will be built using those tokens.
            text_kwargs: Additional kwargs to the text_transform class, when tokens is not None.
            optimizer_class: Optimizer to use during training.
            optimizer_kwargs: Optional extra kwargs to the optimizer.
            lr_scheduler_class: Optional class to use a learning rate scheduler with the optimizer.
            lr_scheduler_kwargs: Optional extra kwargs to the learning rate scheduler.
        """
        self.save_hyperparameters()
        checkpoint_kwargs = checkpoint_kwargs or {}
        decoder_kwargs = decoder_kwargs or {}
        text_kwargs = text_kwargs or {}

        if tokens is not None and decoder_class is None:
            # Missing decoder
            raise ValueError(
                "New tokens were specified, but the module also needs to know the decoder class to initialize properly."
            )

        if tokens is None and decoder_class is not None:
            # Missing tokens
            raise ValueError(
                "A new decoder was specified, but the module also needs to know the tokens to initialize properly."
            )

        checkpoint_data = load_pretrained(checkpoint_name, **checkpoint_kwargs)

        if decoder_class is None:
            # Keep original decoder/text processing
            text_transform = checkpoint_data.text_transform
            decoder = checkpoint_data.decoder
        else:
            # Changing the decoder layer and text processing
            text_transform = BatchTextTransformer(tokens, **text_kwargs)
            decoder = decoder_class(
                checkpoint_data.encoder_final_dimension,
                text_transform.num_tokens,
                **decoder_kwargs,
            )

        super().__init__(
            encoder=checkpoint_data.encoder,
            decoder=decoder,
            audio_transform=checkpoint_data.audio_transform,
            text_transform=text_transform,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
            lr_scheduler_class=lr_scheduler_class,
            lr_scheduler_kwargs=lr_scheduler_kwargs,
        )
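
Note that tokens and decoder_class must be supplied together, as enforced by the checks above. A sketch of both valid configurations follows; the checkpoint name is a placeholder and linear_decoder is an illustrative builder, not part of the library.

import torch.nn as nn

from thunder.finetune import FinetuneCTCModule

# Configuration 1: reuse the decoder and text transform stored in the checkpoint.
module = FinetuneCTCModule(checkpoint_name="QuartzNet5x5LS-En")

# Configuration 2: replace decoder and vocabulary together.
# The builder is called as decoder_class(encoder_final_dimension, num_tokens, **decoder_kwargs).
def linear_decoder(encoder_dim: int, num_tokens: int) -> nn.Module:
    return nn.Conv1d(encoder_dim, num_tokens, kernel_size=1)

module = FinetuneCTCModule(
    checkpoint_name="QuartzNet5x5LS-En",
    decoder_class=linear_decoder,
    tokens=list("abcdefghijklmnopqrstuvwxyz '"),
)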

__init__(self, checkpoint_name, checkpoint_kwargs=None, decoder_class=None, decoder_kwargs=None, tokens=None, text_kwargs=None, optimizer_class=torch.optim.AdamW, optimizer_kwargs=None, lr_scheduler_class=None, lr_scheduler_kwargs=None) special

Generic finetune module, loading any combination of encoder/decoder and custom tokens

Parameters:

checkpoint_name (str): Name of the base checkpoint to load. Required.
checkpoint_kwargs (Dict[str, Any]): Additional kwargs to the checkpoint loading function. Default: None.
decoder_class (Union[Type[torch.nn.Module], Callable[..., torch.nn.Module]]): Optional class to override the decoder loaded from the checkpoint. Default: None.
decoder_kwargs (Dict[str, Any]): Additional kwargs to the decoder_class. Default: None.
tokens (List[str]): If a list of tokens is passed, the decoder from the base checkpoint will be replaced by the one built from decoder_class, and a new text transform will be built using those tokens. Default: None.
text_kwargs (Dict[str, Any]): Additional kwargs to the text_transform class, used when tokens is not None. Default: None.
optimizer_class (Union[Type[torch.optim.Optimizer], Callable[..., torch.optim.Optimizer]]): Optimizer to use during training. Default: torch.optim.AdamW.
optimizer_kwargs (Dict[str, Any]): Optional extra kwargs to the optimizer. Default: None.
lr_scheduler_class (Union[Type[torch.optim.lr_scheduler._LRScheduler], Type[torch.optim.lr_scheduler.ReduceLROnPlateau], Callable[..., Union[torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau]]]): Optional class to use a learning rate scheduler with the optimizer. Default: None.
lr_scheduler_kwargs (Dict[str, Any]): Optional extra kwargs to the learning rate scheduler. Default: None.
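
The optimizer and scheduler are also configurable through builder classes and their kwargs. A brief sketch, again with a placeholder checkpoint name, assuming the kwargs are forwarded unchanged to the respective constructors:

import torch

from thunder.finetune import FinetuneCTCModule

module = FinetuneCTCModule(
    checkpoint_name="QuartzNet5x5LS-En",  # placeholder checkpoint name
    optimizer_class=torch.optim.SGD,
    optimizer_kwargs={"lr": 1e-4, "momentum": 0.9},
    lr_scheduler_class=torch.optim.lr_scheduler.CosineAnnealingLR,
    lr_scheduler_kwargs={"T_max": 50},
)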