Finetune
Module that implements easy finetuning of any model in the library.
FinetuneCTCModule (BaseCTCModule)
Source code in thunder/finetune.py
```python
class FinetuneCTCModule(BaseCTCModule):
    def __init__(
        self,
        checkpoint_name: str,
        checkpoint_kwargs: Dict[str, Any] = None,
        decoder_class: ModuleBuilderType = None,
        decoder_kwargs: Dict[str, Any] = None,
        tokens: List[str] = None,
        text_kwargs: Dict[str, Any] = None,
        optimizer_class: OptimizerBuilderType = torch.optim.AdamW,
        optimizer_kwargs: Dict[str, Any] = None,
        lr_scheduler_class: SchedulerBuilderType = None,
        lr_scheduler_kwargs: Dict[str, Any] = None,
    ):
        """Generic finetune module that loads any combination of encoder/decoder and custom tokens.

        Args:
            checkpoint_name: Name of the base checkpoint to load.
            checkpoint_kwargs: Additional kwargs to the checkpoint loading function.
            decoder_class: Optional class to override the decoder from the loaded checkpoint.
            decoder_kwargs: Additional kwargs to the decoder_class.
            tokens: If passed a list of tokens, the decoder from the base checkpoint will be replaced by the one in decoder_class, and a new text transform will be built using those tokens.
            text_kwargs: Additional kwargs to the text_transform class, used when tokens is not None.
            optimizer_class: Optimizer to use during training.
            optimizer_kwargs: Optional extra kwargs to the optimizer.
            lr_scheduler_class: Optional class to use a learning rate scheduler with the optimizer.
            lr_scheduler_kwargs: Optional extra kwargs to the learning rate scheduler.
        """
        self.save_hyperparameters()
        checkpoint_kwargs = checkpoint_kwargs or {}
        decoder_kwargs = decoder_kwargs or {}
        text_kwargs = text_kwargs or {}
        if tokens is not None and decoder_class is None:
            # Missing decoder
            raise ValueError(
                "New tokens were specified, but the module also needs to know the decoder class to initialize properly."
            )
        if tokens is None and decoder_class is not None:
            # Missing tokens
            raise ValueError(
                "A new decoder was specified, but the module also needs to know the tokens to initialize properly."
            )
        checkpoint_data = load_pretrained(checkpoint_name, **checkpoint_kwargs)
        if decoder_class is None:
            # Keep original decoder/text processing
            text_transform = checkpoint_data.text_transform
            decoder = checkpoint_data.decoder
        else:
            # Changing the decoder layer and text processing
            text_transform = BatchTextTransformer(tokens, **text_kwargs)
            decoder = decoder_class(
                checkpoint_data.encoder_final_dimension,
                text_transform.num_tokens,
                **decoder_kwargs,
            )
        super().__init__(
            encoder=checkpoint_data.encoder,
            decoder=decoder,
            audio_transform=checkpoint_data.audio_transform,
            text_transform=text_transform,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
            lr_scheduler_class=lr_scheduler_class,
            lr_scheduler_kwargs=lr_scheduler_kwargs,
        )
```
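For reference, a minimal usage sketch of the simplest case: finetune a pretrained checkpoint while keeping its original decoder and vocabulary. The checkpoint name below is a placeholder (any name accepted by load_pretrained works), and the sketch assumes the module can be trained with a standard pytorch_lightning Trainer, as with the other CTC modules in the library.

```python
# Minimal sketch, assuming a checkpoint name that load_pretrained recognizes
# (the name below is a placeholder) and a LightningDataModule you provide.
import pytorch_lightning as pl

from thunder.finetune import FinetuneCTCModule

module = FinetuneCTCModule(
    checkpoint_name="QuartzNet5x5LS_En",  # placeholder checkpoint identifier
    optimizer_kwargs={"lr": 1e-4},        # forwarded to the default torch.optim.AdamW
)

trainer = pl.Trainer(max_epochs=10)
# trainer.fit(module, datamodule=your_datamodule)  # supply your own data module
```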
__init__(self, checkpoint_name, checkpoint_kwargs=None, decoder_class=None, decoder_kwargs=None, tokens=None, text_kwargs=None, optimizer_class=torch.optim.AdamW, optimizer_kwargs=None, lr_scheduler_class=None, lr_scheduler_kwargs=None)
Generic finetune module that loads any combination of encoder/decoder and custom tokens.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| checkpoint_name | str | Name of the base checkpoint to load. | required |
| checkpoint_kwargs | Dict[str, Any] | Additional kwargs to the checkpoint loading function. | None |
| decoder_class | Union[Type[torch.nn.modules.module.Module], Callable[..., torch.nn.modules.module.Module]] | Optional class to override the decoder from the loaded checkpoint. | None |
| decoder_kwargs | Dict[str, Any] | Additional kwargs to the decoder_class. | None |
| tokens | List[str] | If passed a list of tokens, the decoder from the base checkpoint will be replaced by the one in decoder_class, and a new text transform will be built using those tokens. | None |
| text_kwargs | Dict[str, Any] | Additional kwargs to the text_transform class, used when tokens is not None. | None |
| optimizer_class | Union[Type[torch.optim.optimizer.Optimizer], Callable[..., torch.optim.optimizer.Optimizer]] | Optimizer to use during training. | torch.optim.AdamW |
| optimizer_kwargs | Dict[str, Any] | Optional extra kwargs to the optimizer. | None |
| lr_scheduler_class | Union[Type[torch.optim.lr_scheduler._LRScheduler], Type[torch.optim.lr_scheduler.ReduceLROnPlateau], Callable[..., Union[torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau]]] | Optional class to use a learning rate scheduler with the optimizer. | None |
| lr_scheduler_kwargs | Dict[str, Any] | Optional extra kwargs to the learning rate scheduler. | None |
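As the table indicates, tokens and decoder_class must be given together, and the builder is then called as decoder_class(encoder_final_dimension, num_tokens, **decoder_kwargs). Below is a hedged sketch of that case with a hand-written decoder head; the checkpoint name and the conv1d_decoder helper are illustrative assumptions, not part of the library API.

```python
# Sketch: replace the decoder and vocabulary of a pretrained checkpoint.
# `conv1d_decoder` is a hypothetical builder matching the expected call
# signature (input_dim, num_classes); it is not a thunder API.
import torch
from torch import nn

from thunder.finetune import FinetuneCTCModule


def conv1d_decoder(input_dim: int, num_classes: int) -> nn.Module:
    # 1x1 convolution projecting encoder features to per-token logits
    return nn.Conv1d(input_dim, num_classes, kernel_size=1)


tokens = list("abcdefghijklmnopqrstuvwxyz '")  # new target vocabulary

module = FinetuneCTCModule(
    checkpoint_name="QuartzNet5x5LS_En",  # placeholder checkpoint identifier
    decoder_class=conv1d_decoder,         # required because tokens were given
    tokens=tokens,                        # required because decoder_class was given
    optimizer_class=torch.optim.AdamW,
    optimizer_kwargs={"lr": 3e-4, "weight_decay": 1e-3},
    lr_scheduler_class=torch.optim.lr_scheduler.CosineAnnealingLR,
    lr_scheduler_kwargs={"T_max": 100},
)
```

Passing an optimizer or scheduler follows the same pattern as the other kwargs: the class goes in *_class and its constructor arguments go in the matching *_kwargs dictionary.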