Skip to content

Utils

Utility functions

BaseCheckpoint (str, Enum)

Base class that represents a pretrained model checkpoint.

Source code in thunder/utils.py
class BaseCheckpoint(str, Enum):
    """Base class that represents a pretrained model checkpoint.

    Members are strings (the class mixes in ``str``); subclasses declare one
    member per available checkpoint.
    """

    @classmethod
    def from_string(cls, name: str) -> "BaseCheckpoint":
        """Look up a checkpoint member by its member name.

        Handy as a ``type=`` converter for argparse or as a hydra resolver,
        since an unknown name is reported as ``ValueError``.

        Args:
            name: Member name of the checkpoint (not its string value).

        Raises:
            ValueError: ``name`` is not the name of any member of ``cls``.

        Returns:
            The enum member called ``name``.
        """
        members = cls.__members__
        try:
            return members[name]
        except KeyError as missing_member:
            # Re-raise as ValueError, keeping the KeyError as the cause.
            raise ValueError(
                "Name provided is not a valid checkpoint"
            ) from missing_member

audio_len(item)

Returns the length of the audio file

Parameters:

Name Type Description Default
item Union[pathlib.Path, str]

Audio path

required

Returns:

Type Description
float

Length in seconds of the audio

Source code in thunder/utils.py
def audio_len(item: Union[Path, str]) -> float:
    """Compute the duration of an audio file from its metadata.

    Args:
        item: Path to the audio file.

    Returns:
        Length in seconds of the audio, i.e. number of frames divided by
        the sample rate reported by ``torchaudio.info``.
    """
    info = torchaudio.info(item)
    frames, rate = info.num_frames, info.sample_rate
    return frames / rate

chain_calls(*funcs)

Chain multiple functions that take only one argument, producing a new function that is the result of calling the individual functions in sequence.

Examples:

f1 = lambda x: 2 * x
f2 = lambda x: 3 * x
f3 = lambda x: 4 * x
g = chain_calls(f1, f2, f3)
assert g(1) == 24

Returns:

Type Description
Callable

Single chained function

Source code in thunder/utils.py
def chain_calls(*funcs: Callable) -> Callable:
    """Chain multiple functions that take only one argument, producing a new function that is the result
    of calling the individual functions in sequence.

    Functions are applied left to right: ``chain_calls(f, g)(x) == g(f(x))``.
    With no functions the returned callable is the identity.

    Example:
    ```python
    f1 = lambda x: 2 * x
    f2 = lambda x: 3 * x
    f3 = lambda x: 4 * x
    g = chain_calls(f1, f2, f3)
    assert g(1) == 24
    ```

    Args:
        *funcs: Single-argument callables to apply in order.

    Returns:
        Single chained function
    """

    def _inner(arg):
        # Feed the running value through each function in turn.
        return functools.reduce(lambda value, func: func(value), funcs, arg)

    return _inner

download_checkpoint(name, checkpoint_folder=None)

Download checkpoint by identifier.

Parameters:

Name Type Description Default
name BaseCheckpoint

Model identifier. Check checkpoint_archives.keys()

required
checkpoint_folder str

Folder where the checkpoint will be saved to.

None

Returns:

Type Description
Path

Path to the saved checkpoint file.

Source code in thunder/utils.py
def download_checkpoint(
    name: BaseCheckpoint, checkpoint_folder: Union[str, Path, None] = None
) -> Path:
    """Download checkpoint by identifier.

    The file is only fetched when it is not already present in the target
    folder, so repeated calls reuse the local copy.

    Args:
        name: Model identifier. Its value is the URL the checkpoint is
            downloaded from.
        checkpoint_folder: Folder where the checkpoint will be saved to.
            Defaults to ``get_default_cache_folder()``. Created (including
            missing parents) if it does not exist yet.

    Returns:
        Path to the saved checkpoint file.
    """
    if checkpoint_folder is None:
        checkpoint_folder = get_default_cache_folder()
    folder = Path(checkpoint_folder)
    # A user-supplied folder may not exist yet; make sure the download has
    # somewhere to land (the default cache folder already creates itself).
    folder.mkdir(parents=True, exist_ok=True)

    url = name.value
    # The last URL path segment becomes the local file name.
    filename = url.split("/")[-1]
    checkpoint_path = folder / filename
    if not checkpoint_path.exists():
        wget.download(url, out=str(checkpoint_path))

    return checkpoint_path

get_default_cache_folder()

Get the default folder where the cached stuff will be saved.

Returns:

Type Description
Path

Path of the cache folder.

Source code in thunder/utils.py
def get_default_cache_folder() -> Path:
    """Return the per-user cache directory, creating it on first use.

    The directory is ``.thunder`` inside the user's home directory. Only that
    final directory is created; if it already exists it is left untouched.

    Returns:
        Path of the cache folder.
    """
    cache_dir = Path.home().joinpath(".thunder")
    cache_dir.mkdir(exist_ok=True)
    return cache_dir

get_files(directory, extension)

Find all files in directory with extension.

Parameters:

Name Type Description Default
directory Union[str, pathlib.Path]

Directory to recursively find the files

required
extension str

File extension to search for

required

Returns:

Type Description
List[pathlib.Path]

List of all the files that match the extension

Source code in thunder/utils.py
def get_files(directory: Union[str, Path], extension: str) -> List[Path]:
    """Find all files in directory with extension.

    The search is recursive and follows symbolic links to directories. A
    linked directory that resolves to one of the directories already on the
    current walk path (e.g. a link pointing at its own parent) is not
    descended into, so self-referential links cannot make the walk loop.

    Args:
        directory: Directory to recursively find the files
        extension: File extension to search for. Matched with a plain
            ``str.endswith`` on the file name (case-sensitive).

    Returns:
        List of all the files that match the extension
    """
    files_found = []
    # Maps each pending directory path (exactly as os.walk will yield it) to
    # the resolved real paths of itself and all its ancestors in this walk.
    real_chains = {}

    for root, dirs, files in os.walk(directory, followlinks=True):
        chain = real_chains.pop(root, None)
        if chain is None:  # the top-level directory
            chain = frozenset({os.path.realpath(root)})

        kept = []
        for dir_name in dirs:
            child = os.path.join(root, dir_name)
            real_child = os.path.realpath(child)
            # Revisiting an ancestor would repeat its subtree forever.
            if real_child in chain:
                continue
            kept.append(dir_name)
            real_chains[child] = chain | {real_child}
        # In-place assignment tells os.walk which subdirectories to descend.
        dirs[:] = kept

        files_found += [Path(root) / f for f in files if f.endswith(extension)]
    return files_found