
Datamodule

Implements PyTorch Lightning's LightningDataModule for audio datasets.

BaseDataModule (LightningDataModule)

Source code in thunder/data/datamodule.py
class BaseDataModule(LightningDataModule):
    def __init__(
        self,
        batch_size: int = 10,
        num_workers: int = 8,
    ):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers

    def get_dataset(self, split: str) -> BaseSpeechDataset:
        """Function to get the corresponding dataset to the specified split.
        This should be implemented by subclasses.

        Args:
            split: One of "train", "valid" or "test".

        Returns:
            The corresponding dataset.
        """
        raise NotImplementedError()

    def setup(self, stage: Optional[str] = None):
        if stage in (None, "fit"):
            self.train_dataset = self.get_dataset(split="train")
            self.val_dataset = self.get_dataset(split="valid")
        if stage in (None, "test"):
            self.test_dataset = self.get_dataset(split="test")

    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            collate_fn=asr_collate,
            num_workers=self.num_workers,
            shuffle=True,
            pin_memory=True,
        )

    def val_dataloader(self) -> DataLoader:
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            collate_fn=asr_collate,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def test_dataloader(self) -> DataLoader:
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            collate_fn=asr_collate,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    @property
    def steps_per_epoch(self) -> int:
        """Number of steps for each training epoch. Used for learning rate scheduling.

        Returns:
            Number of steps
        """
        return len(self.train_dataset) // self.batch_size

steps_per_epoch: int property readonly

Number of steps for each training epoch. Used for learning rate scheduling.

Returns:

    int: Number of steps
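
Because the property is simply len(train_dataset) // batch_size, it can be fed directly to step-based learning rate schedulers. A minimal sketch, assuming the scheduler is configured inside a LightningModule that keeps a reference to the datamodule as self.dm (a hypothetical attribute, not thunder API):

def configure_optimizers(self):
    # Assumes `import torch` and that this method lives on a LightningModule
    # holding an already set-up datamodule as `self.dm` (hypothetical attribute).
    optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=1e-3,
        steps_per_epoch=self.dm.steps_per_epoch,  # len(train_dataset) // batch_size
        epochs=self.trainer.max_epochs,
    )
    return [optimizer], [{"scheduler": scheduler, "interval": "step"}]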

get_dataset(self, split)

Get the dataset corresponding to the specified split. This should be implemented by subclasses.

Parameters:

    split (str): One of "train", "valid" or "test". Required.

Returns:

    BaseSpeechDataset: The corresponding dataset.

Source code in thunder/data/datamodule.py
def get_dataset(self, split: str) -> BaseSpeechDataset:
    """Function to get the corresponding dataset to the specified split.
    This should be implemented by subclasses.

    Args:
        split: One of "train", "valid" or "test".

    Returns:
        The corresponding dataset.
    """
    raise NotImplementedError()
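
Since get_dataset is the only method a subclass has to provide, creating a new datamodule amounts to mapping a split name onto a dataset. A minimal sketch, where CSVSpeechDataset is a hypothetical BaseSpeechDataset subclass that reads one CSV file per split:

class CSVDataModule(BaseDataModule):
    def __init__(self, csv_dir: str, batch_size: int = 10, num_workers: int = 8):
        super().__init__(batch_size=batch_size, num_workers=num_workers)
        self.csv_dir = csv_dir

    def get_dataset(self, split: str) -> BaseSpeechDataset:
        # split is "train", "valid" or "test", as passed by setup()
        return CSVSpeechDataset(f"{self.csv_dir}/{split}.csv")  # hypothetical dataset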

setup(self, stage=None)

Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.

Parameters:

    stage (Optional[str]): either 'fit', 'validate', 'test', or 'predict'. Default: None.

Example::

class LitModel(...):
    def __init__(self):
        self.l1 = None

    def prepare_data(self):
        download_data()
        tokenize()

        # don't do this
        self.something = else

    def setup(self, stage):
        data = load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)
Source code in thunder/data/datamodule.py
def setup(self, stage: Optional[str] = None):
    if stage in (None, "fit"):
        self.train_dataset = self.get_dataset(split="train")
        self.val_dataset = self.get_dataset(split="valid")
    if stage in (None, "test"):
        self.test_dataset = self.get_dataset(split="test")
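
In this implementation the stage argument only distinguishes fitting from testing: "fit" (or None) builds the train and validation datasets, while "test" (or None) builds the test dataset. A short usage sketch, assuming a hypothetical subclass MyDataModule that implements get_dataset:

dm = MyDataModule(batch_size=4)  # MyDataModule is assumed, not part of thunder
dm.setup("fit")    # populates dm.train_dataset and dm.val_dataset
dm.setup("test")   # populates dm.test_dataset
dm.setup()         # stage=None populates all three splits

When the datamodule is used through a Lightning Trainer, setup is called automatically, so calling it by hand is mainly useful for standalone inspection of the datasets.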

test_dataloader(self)

Implement one or multiple PyTorch DataLoaders for testing.

For data processing use the following pattern:

- download in prepare_data
- process and split in setup

However, the above are only necessary for distributed processing.

Warning

Do not assign state in prepare_data.

This hook is used together with:

  • Trainer.test()
  • prepare_data()
  • setup()

Note

Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Returns:

    A torch.utils.data.DataLoader or a sequence of them specifying testing samples.

Example::

def test_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False, transform=transform,
                    download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=False
    )

    return loader

# can also return multiple dataloaders
def test_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]

Note

If you don't need a test dataset and a test_step, you don't need to implement this method.

Note

In the case where you return multiple test dataloaders, the test_step will have an argument dataloader_idx which matches the order here.

Source code in thunder/data/datamodule.py
def test_dataloader(self) -> DataLoader:
    return DataLoader(
        self.test_dataset,
        batch_size=self.batch_size,
        shuffle=False,
        collate_fn=asr_collate,
        num_workers=self.num_workers,
        pin_memory=True,
    )

train_dataloader(self)

Implement one or more PyTorch DataLoaders for training.

Returns:

    A collection of torch.utils.data.DataLoader specifying training samples. In the case of multiple dataloaders, please see the section on multiple dataloaders in the PyTorch Lightning documentation.

The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.

For data processing use the following pattern:

- download in prepare_data
- process and split in setup

However, the above are only necessary for distributed processing.

Warning

Do not assign state in prepare_data.

This hook is used together with:

  • Trainer.fit()
  • prepare_data()
  • setup()

Note

Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Example::

# single dataloader
def train_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=True, transform=transform,
                    download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=True
    )
    return loader

# multiple dataloaders, return as list
def train_dataloader(self):
    mnist = MNIST(...)
    cifar = CIFAR(...)
    mnist_loader = torch.utils.data.DataLoader(
        dataset=mnist, batch_size=self.batch_size, shuffle=True
    )
    cifar_loader = torch.utils.data.DataLoader(
        dataset=cifar, batch_size=self.batch_size, shuffle=True
    )
    # each batch will be a list of tensors: [batch_mnist, batch_cifar]
    return [mnist_loader, cifar_loader]

# multiple dataloaders, return as dict
def train_dataloader(self):
    mnist = MNIST(...)
    cifar = CIFAR(...)
    mnist_loader = torch.utils.data.DataLoader(
        dataset=mnist, batch_size=self.batch_size, shuffle=True
    )
    cifar_loader = torch.utils.data.DataLoader(
        dataset=cifar, batch_size=self.batch_size, shuffle=True
    )
    # each batch will be a dict of tensors: {'mnist': batch_mnist, 'cifar': batch_cifar}
    return {'mnist': mnist_loader, 'cifar': cifar_loader}
Source code in thunder/data/datamodule.py
def train_dataloader(self) -> DataLoader:
    return DataLoader(
        self.train_dataset,
        batch_size=self.batch_size,
        collate_fn=asr_collate,
        num_workers=self.num_workers,
        shuffle=True,
        pin_memory=True,
    )

val_dataloader(self)

Implement one or multiple PyTorch DataLoaders for validation.

The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.

It's recommended that all data downloads and preparation happen in prepare_data.

This hook is used together with:

  • Trainer.fit()
  • Trainer.validate()
  • prepare_data()
  • setup()

Note

Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Returns:

    A torch.utils.data.DataLoader or a sequence of them specifying validation samples.

Examples::

def val_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False,
                    transform=transform, download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=False
    )

    return loader

# can also return multiple dataloaders
def val_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]

Note

If you don't need a validation dataset and a validation_step, you don't need to implement this method.

Note

In the case where you return multiple validation dataloaders, the validation_step will have an argument dataloader_idx which matches the order here.

Source code in thunder/data/datamodule.py
def val_dataloader(self) -> DataLoader:
    return DataLoader(
        self.val_dataset,
        batch_size=self.batch_size,
        shuffle=False,
        collate_fn=asr_collate,
        num_workers=self.num_workers,
        pin_memory=True,
    )

ManifestDatamodule (BaseDataModule)

Source code in thunder/data/datamodule.py
class ManifestDatamodule(BaseDataModule):
    def __init__(
        self,
        train_manifest: str,
        val_manifest: str,
        test_manifest: str,
        force_mono: bool = True,
        sample_rate: int = 16000,
        batch_size: int = 10,
        num_workers: int = 8,
    ):
        """Datamodule compatible with the NEMO manifest data format.

        Args:
            train_manifest: Training manifest file
            val_manifest: Validation manifest file
            test_manifest: Test manifest file
            force_mono: Check [`ManifestSpeechDataset`][thunder.data.dataset.ManifestSpeechDataset]
            sample_rate: Check [`ManifestSpeechDataset`][thunder.data.dataset.ManifestSpeechDataset]
            batch_size: Batch size used by dataloader
            num_workers: Number of workers used by dataloader
        """
        super().__init__(
            batch_size=batch_size,
            num_workers=num_workers,
        )
        self.manifest_mapping = {
            "train": train_manifest,
            "valid": val_manifest,
            "test": test_manifest,
        }
        self.force_mono = force_mono
        self.sample_rate = sample_rate

    def get_dataset(self, split: str) -> ManifestSpeechDataset:
        return ManifestSpeechDataset(
            self.manifest_mapping[split], self.force_mono, self.sample_rate
        )

__init__(self, train_manifest, val_manifest, test_manifest, force_mono=True, sample_rate=16000, batch_size=10, num_workers=8) special

Datamodule compatible with the NEMO manifest data format.

Parameters:

    train_manifest (str): Training manifest file. Required.
    val_manifest (str): Validation manifest file. Required.
    test_manifest (str): Test manifest file. Required.
    force_mono (bool): Check ManifestSpeechDataset. Default: True.
    sample_rate (int): Check ManifestSpeechDataset. Default: 16000.
    batch_size (int): Batch size used by dataloader. Default: 10.
    num_workers (int): Number of workers used by dataloader. Default: 8.
Source code in thunder/data/datamodule.py
def __init__(
    self,
    train_manifest: str,
    val_manifest: str,
    test_manifest: str,
    force_mono: bool = True,
    sample_rate: int = 16000,
    batch_size: int = 10,
    num_workers: int = 8,
):
    """Datamodule compatible with the NEMO manifest data format.

    Args:
        train_manifest: Training manifest file
        val_manifest: Validation manifest file
        test_manifest: Test manifest file
        force_mono: Check [`ManifestSpeechDataset`][thunder.data.dataset.ManifestSpeechDataset]
        sample_rate: Check [`ManifestSpeechDataset`][thunder.data.dataset.ManifestSpeechDataset]
        batch_size: Batch size used by dataloader
        num_workers: Number of workers used by dataloader
    """
    super().__init__(
        batch_size=batch_size,
        num_workers=num_workers,
    )
    self.manifest_mapping = {
        "train": train_manifest,
        "valid": val_manifest,
        "test": test_manifest,
    }
    self.force_mono = force_mono
    self.sample_rate = sample_rate
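
A short usage sketch. The manifest paths below are placeholders; each file is assumed to follow the NEMO convention of one JSON object per line (see ManifestSpeechDataset for the exact fields it parses):

# Paths are placeholders for illustration.
dm = ManifestDatamodule(
    train_manifest="manifests/train.json",
    val_manifest="manifests/valid.json",
    test_manifest="manifests/test.json",
    sample_rate=16000,
    batch_size=16,
    num_workers=4,
)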

get_dataset(self, split)

Get the dataset corresponding to the specified split. This should be implemented by subclasses.

Parameters:

    split (str): One of "train", "valid" or "test". Required.

Returns:

    ManifestSpeechDataset: The corresponding dataset.

Source code in thunder/data/datamodule.py
def get_dataset(self, split: str) -> ManifestSpeechDataset:
    return ManifestSpeechDataset(
        self.manifest_mapping[split], self.force_mono, self.sample_rate
    )
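
Because ManifestDatamodule inherits its dataloaders from BaseDataModule, wiring it into training only requires handing it to the Trainer. A hedged sketch, assuming model is a LightningModule and dm is the datamodule built above:

import pytorch_lightning as pl

# `model` is assumed to be a LightningModule; `dm` is the ManifestDatamodule above.
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model, datamodule=dm)   # uses train_dataloader() and val_dataloader()
trainer.test(model, datamodule=dm)  # uses test_dataloader()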