
Dataloader utils

Helper functions used by the speech dataloaders.

asr_collate(samples)

Collects samples into a batch, sorting them by decreasing audio length and zero-padding each audio to the longest one in the batch.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| samples | List[Tuple[torch.Tensor, str]] | Samples produced by the dataloader, each an (audio, transcription) pair. | required |

Returns:

| Type | Description |
|---|---|
| Tuple[torch.Tensor, torch.Tensor, List[str]] | Tuple containing the padded audios, the audio lengths and the list of corresponding transcriptions, in that order. Samples are sorted by decreasing audio length. |
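To make the ordering and padding concrete, here is a dependency-free sketch that mirrors the same steps on plain Python lists. The name `collate_lists` and the list-based representation are hypothetical and exist only for illustration; they are not part of thunder. Note one difference: `asr_collate` returns the lengths as a float `Tensor`, while this sketch uses plain integers.

```python
from typing import List, Sequence, Tuple


def collate_lists(
    samples: Sequence[Tuple[List[float], str]],
) -> Tuple[List[List[float]], List[int], List[str]]:
    """Hypothetical list-based mirror of asr_collate (illustration only)."""
    # Sort by audio length, longest first, as asr_collate does.
    ordered = sorted(samples, key=lambda s: len(s[0]), reverse=True)
    max_len = len(ordered[0][0]) if ordered else 0
    # Right-pad every audio with zeros up to the longest one.
    padded = [list(audio) + [0.0] * (max_len - len(audio)) for audio, _ in ordered]
    # Unpadded lengths and transcriptions, in the same sorted order.
    lengths = [len(audio) for audio, _ in ordered]
    texts = [text for _, text in ordered]
    return padded, lengths, texts


batch = [([0.1, 0.2], "hi"), ([0.3, 0.4, 0.5], "hey"), ([0.6], "a")]
padded, lengths, texts = collate_lists(batch)
print(padded)   # [[0.3, 0.4, 0.5], [0.1, 0.2, 0.0], [0.6, 0.0, 0.0]]
print(lengths)  # [3, 2, 1]
print(texts)    # ['hey', 'hi', 'a']
```

The transcriptions travel with their audio through the sort, so index `i` of each returned element always refers to the same sample.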

Source code in thunder/data/dataloader_utils.py
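from typing import List, Tuple

from torch import Tensor
from torch.nn.utils.rnn import pad_sequence


def asr_collate(samples: List[Tuple[Tensor, str]]) -> Tuple[Tensor, Tensor, List[str]]:
    """Collects samples into a batch and pads the audio.

    Args:
        samples: Samples produced by the dataloader, each an (audio, transcription) pair.

    Returns:
        Tuple containing padded audios, audio lengths and the list of corresponding transcriptions, in that order.
    """
    # Sort by audio length (last dimension), longest first.
    samples = sorted(samples, key=lambda sample: sample[0].size(-1), reverse=True)
    # Drop singleton dims (e.g. a channel dim of size 1) and zero-pad to the longest audio.
    padded_audios = pad_sequence([s[0].squeeze() for s in samples], batch_first=True)

    # Original (unpadded) length of each audio, stored as a float tensor.
    audio_lengths = Tensor([s[0].size(-1) for s in samples])

    # Transcriptions, in the same order as the sorted audios.
    texts = [s[1] for s in samples]

    return (padded_audios, audio_lengths, texts)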
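The function is meant to be passed as `collate_fn` to a PyTorch `DataLoader`. The sketch below assumes a hypothetical toy in-memory dataset of `(audio, transcription)` pairs with audio shaped `(1, num_samples)`, and, to stay self-contained, repeats the body of `asr_collate` instead of importing it from `thunder.data.dataloader_utils`.

```python
from typing import List, Tuple

import torch
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader


def asr_collate(samples: List[Tuple[Tensor, str]]) -> Tuple[Tensor, Tensor, List[str]]:
    # Same steps as the listing above: sort longest-first, pad, record lengths.
    samples = sorted(samples, key=lambda sample: sample[0].size(-1), reverse=True)
    padded_audios = pad_sequence([s[0].squeeze() for s in samples], batch_first=True)
    audio_lengths = Tensor([s[0].size(-1) for s in samples])
    texts = [s[1] for s in samples]
    return padded_audios, audio_lengths, texts


# Hypothetical toy dataset: (audio of shape (1, num_samples), transcription) pairs.
dataset = [
    (torch.ones(1, 4), "one"),
    (torch.ones(1, 6), "two"),
    (torch.ones(1, 2), "three"),
]

# A plain list works as a map-style dataset (it has __getitem__ and __len__).
loader = DataLoader(dataset, batch_size=3, shuffle=False, collate_fn=asr_collate)
audios, lengths, texts = next(iter(loader))
print(audios.shape)      # torch.Size([3, 6])
print(lengths.tolist())  # [6.0, 4.0, 2.0]
print(texts)             # ['two', 'one', 'three']
```

Because the batch is re-sorted inside the collate function, the order of `texts` follows audio length rather than dataset order.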