Dataloader utils
Helper functions used by the speech dataloaders.
asr_collate(samples)
Collects samples into a batch and pads the audio to a common length.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
samples | List[Tuple[torch.Tensor, str]] | Samples produced by the dataloader, each an (audio, transcription) pair | required |
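Returns:
Type | Description |
---|---|
Tuple[torch.Tensor, torch.Tensor, List[str]] | Tuple containing the padded audios, the audio lengths, and the corresponding transcriptions, in that order. Samples are sorted by decreasing audio length, so the batch order may differ from the input order. |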
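The audio lengths are returned because zero-padding hides where each clip actually ends. As a hedged sketch, assuming a downstream model wants a boolean padding mask, the hypothetical helper `lengths_to_mask` below (not part of thunder) rebuilds that information from the padded batch width and the returned lengths:

```python
import torch


def lengths_to_mask(audio_lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    """Hypothetical helper: True where a position holds real audio, False where it is padding."""
    positions = torch.arange(max_len)  # shape (max_len,)
    # Broadcast (1, max_len) against (batch, 1) -> (batch, max_len) boolean mask.
    return positions.unsqueeze(0) < audio_lengths.unsqueeze(1)


# Float lengths, matching what asr_collate returns for a batch padded to width 5.
lengths = torch.tensor([5.0, 4.0, 3.0])
mask = lengths_to_mask(lengths, max_len=5)
```

Comparing an integer `arange` with float lengths relies on PyTorch type promotion, so the float tensor produced by `asr_collate` can be used directly without a cast.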
Source code in thunder/data/dataloader_utils.py
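```python
from typing import List, Tuple
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence


def asr_collate(samples: List[Tuple[Tensor, str]]) -> Tuple[Tensor, Tensor, List[str]]:
    """Function that collects samples and adds padding.
    Args:
        samples: Samples produced by dataloader
    Returns:
        Tuple containing padded audios, audio lengths and the list of corresponding transcriptions in that order.
    """
    # Sort by audio length (last dimension), longest first.
    samples = sorted(samples, key=lambda sample: sample[0].size(-1), reverse=True)
    # Squeeze all size-1 dims (e.g. a mono channel), then zero-pad to (batch, max_len).
    padded_audios = pad_sequence([s[0].squeeze() for s in samples], batch_first=True)
    # Unpadded lengths; the legacy Tensor constructor yields float32 values.
    audio_lengths = Tensor([s[0].size(-1) for s in samples])
    texts = [s[1] for s in samples]
    return (padded_audios, audio_lengths, texts)
```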
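A typical use is passing the function as `collate_fn` to a PyTorch `DataLoader`. The sketch below assumes mono audio stored as `(1, T)` tensors and uses a hypothetical three-item toy dataset; the function body is repeated so the snippet runs without thunder installed.

```python
from typing import List, Tuple

import torch
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader


def asr_collate(samples: List[Tuple[Tensor, str]]) -> Tuple[Tensor, Tensor, List[str]]:
    # Same logic as the listing above, repeated so this snippet is standalone.
    samples = sorted(samples, key=lambda sample: sample[0].size(-1), reverse=True)
    padded_audios = pad_sequence([s[0].squeeze() for s in samples], batch_first=True)
    audio_lengths = Tensor([s[0].size(-1) for s in samples])
    texts = [s[1] for s in samples]
    return (padded_audios, audio_lengths, texts)


# Hypothetical toy dataset: (audio, transcription) pairs with mono audio of shape (1, T).
toy_dataset = [
    (torch.randn(1, 3), "hi"),
    (torch.randn(1, 5), "hello"),
    (torch.randn(1, 4), "hey"),
]

# A plain list works as a map-style dataset; collate_fn turns each list of samples into a batch.
loader = DataLoader(toy_dataset, batch_size=3, collate_fn=asr_collate)
audios, lengths, texts = next(iter(loader))

print(audios.shape)      # torch.Size([3, 5])
print(lengths.tolist())  # [5.0, 4.0, 3.0]
print(texts)             # ['hello', 'hey', 'hi']
```

Note that the transcriptions come back in length-sorted order rather than dataset order, and the shorter clips are filled with zeros up to the longest clip in the batch.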