Utils
Utility functions
BaseCheckpoint (str, Enum)
Base class that represents a pretrained model checkpoint.
Source code in thunder/utils.py
class BaseCheckpoint(str, Enum):
    """Common parent for enums that list pretrained model checkpoints.

    Members are also ``str`` instances because of the ``str`` mixin.
    """

    @classmethod
    def from_string(cls, name: str) -> "BaseCheckpoint":
        """Look up a checkpoint member by its name.

        Handy for converting argparse/hydra string options into enum values.

        Args:
            name: Name of the checkpoint

        Raises:
            ValueError: Name provided is not a valid checkpoint

        Returns:
            Enum value corresponding to the name
        """
        try:
            member = cls[name]
        except KeyError as unknown_name:
            # Surface a ValueError to the caller while keeping the original
            # KeyError attached as the cause.
            message = "Name provided is not a valid checkpoint"
            raise ValueError(message) from unknown_name
        return member
audio_len(item)
Returns the length of the audio file
Parameters:
Name | Type | Description | Default |
---|---|---|---|
item |
Union[pathlib.Path, str] |
Audio path |
required |
Returns:
Type | Description |
---|---|
float |
Length in seconds of the audio |
Source code in thunder/utils.py
def audio_len(item: Union[Path, str]) -> float:
    """Compute the duration of an audio file.

    Args:
        item: Audio path

    Returns:
        Length in seconds of the audio
    """
    info = torchaudio.info(item)
    frame_count, rate = info.num_frames, info.sample_rate
    # Number of frames divided by frames-per-second gives seconds.
    return frame_count / rate
chain_calls(*funcs)
Chain multiple functions that take only one argument, producing a new function that is the result of calling the individual functions in sequence.
Examples:
f1 = lambda x: 2 * x
f2 = lambda x: 3 * x
f3 = lambda x: 4 * x
g = chain_calls(f1, f2, f3)
assert g(1) == 24
Returns:
Type | Description |
---|---|
Callable |
Single chained function |
Source code in thunder/utils.py
def chain_calls(*funcs: Callable) -> Callable:
    """Chain multiple functions that take only one argument, producing a new function that is the result
    of calling the individual functions in sequence.

    Example:
        ```python
        f1 = lambda x: 2 * x
        f2 = lambda x: 3 * x
        f3 = lambda x: 4 * x
        g = chain_calls(f1, f2, f3)
        assert g(1) == 24
        ```

    Args:
        *funcs: Single-argument callables. They are applied from left to
            right, each one receiving the output of the previous one. When
            no function is given, the chained function returns its argument
            unchanged.

    Returns:
        Single chained function
    """

    def _chained(arg):
        # Thread the value through every function in the order given.
        result = arg
        for func in funcs:
            result = func(result)
        return result

    return _chained
download_checkpoint(name, checkpoint_folder=None)
Download checkpoint by identifier.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
BaseCheckpoint |
Model identifier. Check checkpoint_archives.keys() |
required |
checkpoint_folder |
str |
Folder where the checkpoint will be saved to. |
None |
Returns:
Type | Description |
---|---|
Path |
Path to the saved checkpoint file. |
Source code in thunder/utils.py
def download_checkpoint(
    name: BaseCheckpoint, checkpoint_folder: Union[str, Path, None] = None
) -> Path:
    """Download checkpoint by identifier.

    The download is skipped when a file with the expected name is already
    present in the target folder.

    Args:
        name: Model identifier. Its ``value`` is the URL of the checkpoint.
        checkpoint_folder: Folder where the checkpoint will be saved to.
            Defaults to the folder returned by ``get_default_cache_folder()``.
            The folder is created if it does not exist yet.

    Returns:
        Path to the saved checkpoint file.
    """
    if checkpoint_folder is None:
        checkpoint_folder = get_default_cache_folder()

    url = name.value
    # The local file name is the last path component of the URL.
    filename = url.split("/")[-1]
    checkpoint_path = Path(checkpoint_folder) / filename

    if not checkpoint_path.exists():
        # A caller-supplied folder may not exist yet; create it so the
        # download has somewhere to write.
        checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
        wget.download(url, out=str(checkpoint_path))
    return checkpoint_path
get_default_cache_folder()
Get the default folder where the cached stuff will be saved.
Returns:
Type | Description |
---|---|
Path |
Path of the cache folder. |
Source code in thunder/utils.py
def get_default_cache_folder() -> Path:
    """Return the default cache directory, creating it when missing.

    The directory is ``.thunder`` inside the current user's home folder.

    Returns:
        Path of the cache folder.
    """
    cache_dir = Path.home().joinpath(".thunder")
    # exist_ok makes repeated calls a no-op once the folder is there.
    cache_dir.mkdir(exist_ok=True)
    return cache_dir
get_files(directory, extension)
Find all files in directory with extension.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
directory |
Union[str, pathlib.Path] |
Directory to recursively find the files |
required |
extension |
str |
File extension to search for |
required |
Returns:
Type | Description |
---|---|
List[pathlib.Path] |
List of all the files that match the extension |
Source code in thunder/utils.py
def get_files(directory: Union[str, Path], extension: str) -> List[Path]:
    """Collect every file under a directory whose name ends with a suffix.

    The directory tree is walked recursively and symbolic links to
    directories are followed.

    Args:
        directory: Directory to recursively find the files
        extension: File extension to search for

    Returns:
        List of all the files that match the extension
    """
    return [
        Path(dirpath) / filename
        for dirpath, _, filenames in os.walk(directory, followlinks=True)
        for filename in filenames
        if filename.endswith(extension)
    ]