Datamodule
Implements PyTorch Lightning's LightningDataModule for audio datasets.
BaseDataModule (LightningDataModule)
Source code in thunder/data/datamodule.py
```python
class BaseDataModule(LightningDataModule):
    def __init__(
        self,
        batch_size: int = 10,
        num_workers: int = 8,
    ):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers

    def get_dataset(self, split: str) -> BaseSpeechDataset:
        """Function to get the corresponding dataset to the specified split.
        This should be implemented by subclasses.

        Args:
            split: One of "train", "valid" or "test".

        Returns:
            The corresponding dataset.
        """
        raise NotImplementedError()

    def setup(self, stage: Optional[str] = None):
        if stage in (None, "fit"):
            self.train_dataset = self.get_dataset(split="train")
            self.val_dataset = self.get_dataset(split="valid")
        if stage in (None, "test"):
            self.test_dataset = self.get_dataset(split="test")

    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            collate_fn=asr_collate,
            num_workers=self.num_workers,
            shuffle=True,
            pin_memory=True,
        )

    def val_dataloader(self) -> DataLoader:
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            collate_fn=asr_collate,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def test_dataloader(self) -> DataLoader:
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            collate_fn=asr_collate,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    @property
    def steps_per_epoch(self) -> int:
        """Number of steps for each training epoch. Used for learning rate scheduling.

        Returns:
            Number of steps
        """
        return len(self.train_dataset) // self.batch_size
```
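Subclasses only need to implement get_dataset; batching, collation with asr_collate, and the stage-aware setup are inherited. A minimal sketch of a custom subclass, assuming a hypothetical MySpeechDataset that implements the BaseSpeechDataset interface:

```python
from thunder.data.datamodule import BaseDataModule

# Hypothetical dataset class; any BaseSpeechDataset subclass works here.
from my_project.datasets import MySpeechDataset


class MyDataModule(BaseDataModule):
    def __init__(self, data_root: str, batch_size: int = 10, num_workers: int = 8):
        super().__init__(batch_size=batch_size, num_workers=num_workers)
        self.data_root = data_root

    def get_dataset(self, split: str) -> MySpeechDataset:
        # split is "train", "valid" or "test"; setup() calls this for the
        # splits needed by the current stage.
        return MySpeechDataset(root=self.data_root, split=split)
```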
steps_per_epoch: int (property, readonly)
Number of steps for each training epoch. Used for learning rate scheduling.
Returns:
Type | Description
---|---
int | Number of steps
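This property is typically consumed by a learning rate scheduler that needs the number of optimizer steps per epoch. A minimal sketch, assuming a LightningModule trained with this datamodule (the model, learning rate and scheduler choice are illustrative, not part of thunder):

```python
import torch
from pytorch_lightning import LightningModule
from torch.optim.lr_scheduler import OneCycleLR


class MyModel(LightningModule):
    # layers, forward and training_step omitted for brevity

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
        scheduler = OneCycleLR(
            optimizer,
            max_lr=1e-3,
            epochs=self.trainer.max_epochs,
            # Batches per epoch, read from the attached datamodule.
            steps_per_epoch=self.trainer.datamodule.steps_per_epoch,
        )
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
```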
get_dataset(self, split)
Function to get the corresponding dataset to the specified split. This should be implemented by subclasses.
Parameters:
Name | Type | Description | Default
---|---|---|---
split | str | One of "train", "valid" or "test". | required
Returns:
Type | Description
---|---
BaseSpeechDataset | The corresponding dataset.
Source code in thunder/data/datamodule.py
```python
def get_dataset(self, split: str) -> BaseSpeechDataset:
    """Function to get the corresponding dataset to the specified split.
    This should be implemented by subclasses.

    Args:
        split: One of "train", "valid" or "test".

    Returns:
        The corresponding dataset.
    """
    raise NotImplementedError()
```
setup(self, stage=None)
Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.
Parameters:
Name | Type | Description | Default
---|---|---|---
stage | Optional[str] | either "fit", "validate", "test", or "predict" | None
Example:
```python
class LitModel(...):
    def __init__(self):
        self.l1 = None

    def prepare_data(self):
        download_data()
        tokenize()

        # don't do this
        self.something = something_else

    def setup(self, stage):
        data = load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)
```
Source code in thunder/data/datamodule.py
```python
def setup(self, stage: Optional[str] = None):
    if stage in (None, "fit"):
        self.train_dataset = self.get_dataset(split="train")
        self.val_dataset = self.get_dataset(split="valid")
    if stage in (None, "test"):
        self.test_dataset = self.get_dataset(split="test")
```
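Under a Trainer the stage argument is passed automatically, but the datamodule can also be driven by hand, which makes the stage logic above explicit. A small sketch, reusing the hypothetical MyDataModule from the earlier example:

```python
dm = MyDataModule(data_root="/data/speech", batch_size=4)

dm.setup("fit")     # builds train_dataset and val_dataset
train_loader = dm.train_dataloader()
val_loader = dm.val_dataloader()

dm.setup("test")    # builds test_dataset
test_loader = dm.test_dataloader()
```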
test_dataloader(self)
Implement one or multiple PyTorch DataLoaders for testing.
For data processing use the following pattern:
- download in prepare_data
- process and split in setup
However, the above are only necessary for distributed processing.
Warning: do not assign state in prepare_data.
See also: Trainer.test, prepare_data, setup.
Note
Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Returns:
Type | Description
---|---
DataLoader | A torch.utils.data.DataLoader or a sequence of them specifying testing samples.
Example:
```python
def test_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False, transform=transform,
                    download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=False
    )
    return loader

# can also return multiple dataloaders
def test_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]
```
Note
If you don't need a test dataset and a test_step, you don't need to implement this method.
Note
In the case where you return multiple test dataloaders, the test_step will have an argument dataloader_idx which matches the order here.
Source code in thunder/data/datamodule.py
```python
def test_dataloader(self) -> DataLoader:
    return DataLoader(
        self.test_dataset,
        batch_size=self.batch_size,
        shuffle=False,
        collate_fn=asr_collate,
        num_workers=self.num_workers,
        pin_memory=True,
    )
```
train_dataloader(self)
Implement one or more PyTorch DataLoaders for training.
Returns:
Type | Description
---|---
DataLoader | A collection of torch.utils.data.DataLoader specifying training samples.
The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.
For data processing use the following pattern:
- download in prepare_data
- process and split in setup
However, the above are only necessary for distributed processing.
Warning: do not assign state in prepare_data.
See also: Trainer.fit, prepare_data, setup.
Note
Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Example:
```python
# single dataloader
def train_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=True, transform=transform,
                    download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=True
    )
    return loader

# multiple dataloaders, return as list
def train_dataloader(self):
    mnist = MNIST(...)
    cifar = CIFAR(...)
    mnist_loader = torch.utils.data.DataLoader(
        dataset=mnist, batch_size=self.batch_size, shuffle=True
    )
    cifar_loader = torch.utils.data.DataLoader(
        dataset=cifar, batch_size=self.batch_size, shuffle=True
    )
    # each batch will be a list of tensors: [batch_mnist, batch_cifar]
    return [mnist_loader, cifar_loader]

# multiple dataloaders, return as dict
def train_dataloader(self):
    mnist = MNIST(...)
    cifar = CIFAR(...)
    mnist_loader = torch.utils.data.DataLoader(
        dataset=mnist, batch_size=self.batch_size, shuffle=True
    )
    cifar_loader = torch.utils.data.DataLoader(
        dataset=cifar, batch_size=self.batch_size, shuffle=True
    )
    # each batch will be a dict of tensors: {'mnist': batch_mnist, 'cifar': batch_cifar}
    return {'mnist': mnist_loader, 'cifar': cifar_loader}
```
Source code in thunder/data/datamodule.py
```python
def train_dataloader(self) -> DataLoader:
    return DataLoader(
        self.train_dataset,
        batch_size=self.batch_size,
        collate_fn=asr_collate,
        num_workers=self.num_workers,
        shuffle=True,
        pin_memory=True,
    )
```
val_dataloader(self)
Implement one or multiple PyTorch DataLoaders for validation.
The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.
It's recommended that all data downloads and preparation happen in prepare_data.
See also: Trainer.fit, Trainer.validate, prepare_data, setup.
Note
Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Returns:
Type | Description
---|---
DataLoader | A torch.utils.data.DataLoader or a sequence of them specifying validation samples.
Examples:
```python
def val_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False,
                    transform=transform, download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=False
    )
    return loader

# can also return multiple dataloaders
def val_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]
```
Note
If you don't need a validation dataset and a validation_step, you don't need to implement this method.
Note
In the case where you return multiple validation dataloaders, the validation_step will have an argument dataloader_idx which matches the order here.
Source code in thunder/data/datamodule.py
```python
def val_dataloader(self) -> DataLoader:
    return DataLoader(
        self.val_dataset,
        batch_size=self.batch_size,
        shuffle=False,
        collate_fn=asr_collate,
        num_workers=self.num_workers,
        pin_memory=True,
    )
```
ManifestDatamodule (BaseDataModule)
Source code in thunder/data/datamodule.py
```python
class ManifestDatamodule(BaseDataModule):
    def __init__(
        self,
        train_manifest: str,
        val_manifest: str,
        test_manifest: str,
        force_mono: bool = True,
        sample_rate: int = 16000,
        batch_size: int = 10,
        num_workers: int = 8,
    ):
        """Datamodule compatible with the NEMO manifest data format.

        Args:
            train_manifest: Training manifest file
            val_manifest: Validation manifest file
            test_manifest: Test manifest file
            force_mono: Check [`ManifestSpeechDataset`][thunder.data.dataset.ManifestSpeechDataset]
            sample_rate: Check [`ManifestSpeechDataset`][thunder.data.dataset.ManifestSpeechDataset]
            batch_size: Batch size used by dataloader
            num_workers: Number of workers used by dataloader
        """
        super().__init__(
            batch_size=batch_size,
            num_workers=num_workers,
        )
        self.manifest_mapping = {
            "train": train_manifest,
            "valid": val_manifest,
            "test": test_manifest,
        }
        self.force_mono = force_mono
        self.sample_rate = sample_rate

    def get_dataset(self, split: str) -> ManifestSpeechDataset:
        return ManifestSpeechDataset(
            self.manifest_mapping[split], self.force_mono, self.sample_rate
        )
```
__init__(self, train_manifest, val_manifest, test_manifest, force_mono=True, sample_rate=16000, batch_size=10, num_workers=8) (special)
Datamodule compatible with the NEMO manifest data format.
Parameters:
Name | Type | Description | Default
---|---|---|---
train_manifest | str | Training manifest file | required
val_manifest | str | Validation manifest file | required
test_manifest | str | Test manifest file | required
force_mono | bool | Check ManifestSpeechDataset | True
sample_rate | int | Check ManifestSpeechDataset | 16000
batch_size | int | Batch size used by dataloader | 10
num_workers | int | Number of workers used by dataloader | 8
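The manifest files follow the NEMO convention of one JSON object per line; the exact fields consumed are defined by ManifestSpeechDataset. A hedged sketch of writing such a file, assuming the usual audio_filepath/duration/text fields:

```python
import json

# Illustrative NEMO-style manifest entries: one JSON object per line,
# with the audio path, duration in seconds and the transcription.
entries = [
    {"audio_filepath": "/data/audio/sample_0001.wav", "duration": 3.42, "text": "hello world"},
    {"audio_filepath": "/data/audio/sample_0002.wav", "duration": 1.87, "text": "good morning"},
]

with open("train_manifest.json", "w") as f:
    for entry in entries:
        f.write(json.dumps(entry) + "\n")
```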
Source code in thunder/data/datamodule.py
```python
def __init__(
    self,
    train_manifest: str,
    val_manifest: str,
    test_manifest: str,
    force_mono: bool = True,
    sample_rate: int = 16000,
    batch_size: int = 10,
    num_workers: int = 8,
):
    """Datamodule compatible with the NEMO manifest data format.

    Args:
        train_manifest: Training manifest file
        val_manifest: Validation manifest file
        test_manifest: Test manifest file
        force_mono: Check [`ManifestSpeechDataset`][thunder.data.dataset.ManifestSpeechDataset]
        sample_rate: Check [`ManifestSpeechDataset`][thunder.data.dataset.ManifestSpeechDataset]
        batch_size: Batch size used by dataloader
        num_workers: Number of workers used by dataloader
    """
    super().__init__(
        batch_size=batch_size,
        num_workers=num_workers,
    )
    self.manifest_mapping = {
        "train": train_manifest,
        "valid": val_manifest,
        "test": test_manifest,
    }
    self.force_mono = force_mono
    self.sample_rate = sample_rate
```
get_dataset(self, split)
Function to get the corresponding dataset to the specified split. This should be implemented by subclasses.
Parameters:
Name | Type | Description | Default
---|---|---|---
split | str | One of "train", "valid" or "test". | required
Returns:
Type | Description
---|---
ManifestSpeechDataset | The corresponding dataset.
Source code in thunder/data/datamodule.py
```python
def get_dataset(self, split: str) -> ManifestSpeechDataset:
    return ManifestSpeechDataset(
        self.manifest_mapping[split], self.force_mono, self.sample_rate
    )
```
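Putting it together, a minimal end-to-end sketch of using ManifestDatamodule with a Trainer; the manifest paths are placeholders and the model can be any LightningModule that accepts batches produced by asr_collate:

```python
from pytorch_lightning import Trainer

from thunder.data.datamodule import ManifestDatamodule

dm = ManifestDatamodule(
    train_manifest="manifests/train_manifest.json",
    val_manifest="manifests/val_manifest.json",
    test_manifest="manifests/test_manifest.json",
    batch_size=16,
    num_workers=4,
)

model = ...  # any LightningModule compatible with batches from asr_collate

trainer = Trainer(max_epochs=10)
trainer.fit(model, datamodule=dm)
trainer.test(model, datamodule=dm)
```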