Callbacks

Helper callback functionality that is not essential to the research code

FinetuneEncoderDecoder (BaseFinetuning)

Source code in thunder/callbacks.py
class FinetuneEncoderDecoder(BaseFinetuning):
    def __init__(
        self,
        unfreeze_encoder_at_epoch: int = 1,
        encoder_initial_lr_div: float = 10,
        train_batchnorm: bool = True,
    ):
        """
        Finetune a encoder model based on a learning rate.

        Args:

            unfreeze_encoder_at_epoch: Epoch at which the encoder will be unfreezed.

            encoder_initial_lr_div:
                Used to scale down the encoder learning rate compared to rest of model.

            train_batchnorm: Make Batch Normalization trainable at the beginning of train.
        """
        super().__init__()
        self.unfreeze_encoder_at_epoch = unfreeze_encoder_at_epoch
        self.encoder_initial_lr_div = encoder_initial_lr_div
        self.train_batchnorm = train_batchnorm

    def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        """Check if the LightningModule has the necessary attribute before the train starts

        Args:
            trainer: Lightning Trainer
            pl_module: Lightning Module used during training

        Raises:
            Exception: If LightningModule has no nn.Module `encoder` attribute.
        """
        if hasattr(pl_module, "encoder") and isinstance(pl_module.encoder, nn.Module):
            return
        raise Exception(
            "The LightningModule should have a nn.Module `encoder` attribute"
        )

    def freeze_before_training(self, pl_module: pl.LightningModule):
        """Freeze the encoder initially before the train starts.

        Args:
            pl_module: Lightning Module
        """
        self.freeze(pl_module.encoder, train_bn=self.train_batchnorm)

    def finetune_function(
        self,
        pl_module: pl.LightningModule,
        epoch: int,
        optimizer: Optimizer,
        opt_idx: int,
    ):
        """Unfreezes the encoder at the specified epoch

        Args:
            pl_module: Lightning Module
            epoch: epoch number
            optimizer: optimizer used during training
            opt_idx: optimizer index
        """
        if epoch == self.unfreeze_encoder_at_epoch:
            self.unfreeze_and_add_param_group(
                pl_module.encoder,
                optimizer,
                initial_denom_lr=self.encoder_initial_lr_div,
                train_bn=not self.train_batchnorm,
            )
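
A minimal usage sketch, assuming a LightningModule that exposes an nn.Module `encoder` attribute (the module, data, and hyperparameter values below are placeholders, not part of thunder):

import pytorch_lightning as pl
import torch
from torch import nn

from thunder.callbacks import FinetuneEncoderDecoder


class PlaceholderModule(pl.LightningModule):
    """Stand-in module with the `encoder` attribute the callback expects."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
        self.decoder = nn.Linear(32, 4)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.cross_entropy(self.decoder(self.encoder(x)), y)

    def configure_optimizers(self):
        # Only parameters that are trainable at fit start go into the optimizer;
        # the callback later adds the unfrozen encoder as a separate param group.
        return torch.optim.Adam(
            (p for p in self.parameters() if p.requires_grad), lr=1e-3
        )


trainer = pl.Trainer(
    max_epochs=5,
    callbacks=[
        FinetuneEncoderDecoder(
            unfreeze_encoder_at_epoch=2,   # encoder stays frozen during epochs 0 and 1
            encoder_initial_lr_div=10,     # encoder starts at base lr / 10 once unfrozen
            train_batchnorm=True,
        )
    ],
)
# trainer.fit(PlaceholderModule(), train_dataloaders=...)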

__init__(self, unfreeze_encoder_at_epoch=1, encoder_initial_lr_div=10, train_batchnorm=True) special

Finetune a model by initially freezing its encoder and unfreezing it at a later epoch with a scaled-down learning rate.

Parameters:

    unfreeze_encoder_at_epoch (int, default 1):
        Epoch at which the encoder will be unfrozen.

    encoder_initial_lr_div (float, default 10):
        Used to scale down the encoder learning rate compared to the rest of the model.

    train_batchnorm (bool, default True):
        Make Batch Normalization trainable at the beginning of training.

Source code in thunder/callbacks.py
def __init__(
    self,
    unfreeze_encoder_at_epoch: int = 1,
    encoder_initial_lr_div: float = 10,
    train_batchnorm: bool = True,
):
    """
    Finetune a encoder model based on a learning rate.

    Args:

        unfreeze_encoder_at_epoch: Epoch at which the encoder will be unfreezed.

        encoder_initial_lr_div:
            Used to scale down the encoder learning rate compared to rest of model.

        train_batchnorm: Make Batch Normalization trainable at the beginning of train.
    """
    super().__init__()
    self.unfreeze_encoder_at_epoch = unfreeze_encoder_at_epoch
    self.encoder_initial_lr_div = encoder_initial_lr_div
    self.train_batchnorm = train_batchnorm
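
For illustration only: assuming a base optimizer learning rate of 1e-3, the three arguments combine as sketched below (the values are made up, not defaults of any thunder recipe):

callback = FinetuneEncoderDecoder(
    unfreeze_encoder_at_epoch=2,
    encoder_initial_lr_div=10,
    train_batchnorm=True,
)
# Epochs 0-1: encoder parameters are frozen (requires_grad=False); BatchNorm layers
#             inside the encoder stay trainable because train_batchnorm=True.
# Epoch 2:    the encoder is unfrozen and appended to the optimizer as a new param
#             group with lr = 1e-3 / 10 = 1e-4 (base lr / encoder_initial_lr_div).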

finetune_function(self, pl_module, epoch, optimizer, opt_idx)

Unfreeze the encoder at the specified epoch.

Parameters:

    pl_module (LightningModule, required):
        Lightning Module.

    epoch (int, required):
        Epoch number.

    optimizer (Optimizer, required):
        Optimizer used during training.

    opt_idx (int, required):
        Optimizer index.

Source code in thunder/callbacks.py
def finetune_function(
    self,
    pl_module: pl.LightningModule,
    epoch: int,
    optimizer: Optimizer,
    opt_idx: int,
):
    """Unfreezes the encoder at the specified epoch

    Args:
        pl_module: Lightning Module
        epoch: epoch number
        optimizer: optimizer used during training
        opt_idx: optimizer index
    """
    if epoch == self.unfreeze_encoder_at_epoch:
        self.unfreeze_and_add_param_group(
            pl_module.encoder,
            optimizer,
            initial_denom_lr=self.encoder_initial_lr_div,
            train_bn=not self.train_batchnorm,
        )
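
A sketch of what happens at the unfreeze epoch, using pytorch_lightning's BaseFinetuning static helpers directly (the toy encoder and head modules are invented for illustration):

import torch
from torch import nn
from pytorch_lightning.callbacks.finetuning import BaseFinetuning

encoder = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8))
head = nn.Linear(8, 2)

# Mirrors freeze_before_training: the encoder is frozen, BatchNorm stays trainable.
BaseFinetuning.freeze(encoder, train_bn=True)
optimizer = torch.optim.SGD(head.parameters(), lr=1e-2)

# Mirrors finetune_function at epoch == unfreeze_encoder_at_epoch.
BaseFinetuning.unfreeze_and_add_param_group(
    encoder, optimizer, initial_denom_lr=10, train_bn=False
)

print(len(optimizer.param_groups))      # 2: the head group plus the new encoder group
print(optimizer.param_groups[1]["lr"])  # 0.001, i.e. the base lr 1e-2 divided by 10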

freeze_before_training(self, pl_module)

Freeze the encoder before training starts.

Parameters:

    pl_module (LightningModule, required):
        Lightning Module.

Source code in thunder/callbacks.py
def freeze_before_training(self, pl_module: pl.LightningModule):
    """Freeze the encoder initially before the train starts.

    Args:
        pl_module: Lightning Module
    """
    self.freeze(pl_module.encoder, train_bn=self.train_batchnorm)
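
A small check of the intended effect, calling pytorch_lightning's BaseFinetuning.freeze on a stand-in encoder (the module below is made up for illustration):

from torch import nn
from pytorch_lightning.callbacks.finetuning import BaseFinetuning

encoder = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8))
BaseFinetuning.freeze(encoder, train_bn=True)

# The Linear layer is frozen, while BatchNorm stays trainable because train_bn=True.
print(all(not p.requires_grad for p in encoder[0].parameters()))  # True
print(all(p.requires_grad for p in encoder[1].parameters()))      # True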

on_fit_start(self, trainer, pl_module)

Check if the LightningModule has the necessary attribute before training starts.

Parameters:

    trainer (Trainer, required):
        Lightning Trainer.

    pl_module (LightningModule, required):
        Lightning Module used during training.

Exceptions:

    Exception:
        If the LightningModule has no nn.Module `encoder` attribute.

Source code in thunder/callbacks.py
def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
    """Check if the LightningModule has the necessary attribute before the train starts

    Args:
        trainer: Lightning Trainer
        pl_module: Lightning Module used during training

    Raises:
        Exception: If LightningModule has no nn.Module `encoder` attribute.
    """
    if hasattr(pl_module, "encoder") and isinstance(pl_module.encoder, nn.Module):
        return
    raise Exception(
        "The LightningModule should have a nn.Module `encoder` attribute"
    )
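
The attribute check in practice (both LightningModules below are made-up examples):

import pytorch_lightning as pl
from torch import nn

from thunder.callbacks import FinetuneEncoderDecoder


class WithEncoder(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 8)   # an nn.Module named `encoder`: check passes


class WithoutEncoder(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 8)  # no `encoder` attribute: check raises


callback = FinetuneEncoderDecoder()
trainer = pl.Trainer(callbacks=[callback])
callback.on_fit_start(trainer, WithEncoder())       # returns silently
# callback.on_fit_start(trainer, WithoutEncoder())  # would raise Exception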