LightningDataModule - Comprehensive Guide

Overview

A LightningDataModule is a reusable, shareable class that encapsulates all data processing steps in PyTorch Lightning. It solves the problem of scattered data preparation logic by standardizing how datasets are managed and shared across projects.

Core Problem It Solves

In traditional PyTorch workflows, data handling is fragmented across multiple files, making it difficult to answer questions like:

  • "What splits did you use?"
  • "What transforms were applied?"
  • "How was the data prepared?"

DataModules centralize this information for reproducibility and reusability.

Five Processing Steps

A DataModule organizes data handling into five phases:

  1. Download/tokenize/process - Initial data acquisition
  2. Clean and save - Persist processed data to disk
  3. Load into Dataset - Create PyTorch Dataset objects
  4. Apply transforms - Data augmentation, normalization, etc.
  5. Wrap in DataLoader - Configure batching and loading

Main Methods

prepare_data()

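Downloads and processes data. Lightning calls it from a single process (by default once per node, on local rank 0), so processes on the same node do not download or write the same files concurrently.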

Use for:

  • Downloading datasets
  • Tokenizing text
  • Saving processed data to disk

Important: Do not set state here (e.g., self.x = y). Because prepare_data() runs in only one process, attributes assigned here will be missing in every other process.

Example:

python
def prepare_data(self):
    # Download data (runs once)
    download_dataset("http://example.com/data.zip", "data/")

    # Tokenize and save (runs once)
    tokenize_and_save("data/raw/", "data/processed/")
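
Since state set in prepare_data() does not reach other processes, the usual handoff is through files on disk. The following is a minimal sketch of that pattern, assuming a plain-Python stand-in class (no Lightning dependency) and a hypothetical file name `processed.json`; the generated records stand in for a real download:

```python
import json
import tempfile
from pathlib import Path


class DiskBackedData:
    """Plain-Python stand-in for a DataModule (illustrative only)."""

    def __init__(self, data_dir):
        self.data_dir = Path(data_dir)
        # Hypothetical file name for the processed data.
        self.processed_file = self.data_dir / "processed.json"

    def prepare_data(self):
        # Write to disk only; skip the work if an earlier call already did it.
        if self.processed_file.exists():
            return
        self.data_dir.mkdir(parents=True, exist_ok=True)
        # Stand-in for downloading and tokenizing real data.
        records = [{"id": i, "text": f"sample {i}"} for i in range(10)]
        self.processed_file.write_text(json.dumps(records))

    def setup(self, stage):
        # Every process reads the same file, so state is assigned here.
        self.records = json.loads(self.processed_file.read_text())


with tempfile.TemporaryDirectory() as tmp:
    dm = DiskBackedData(tmp)
    dm.prepare_data()
    dm.prepare_data()  # second call finds the file and returns early
    dm.setup("fit")
    num_records = len(dm.records)
```

Checking for the output file first keeps repeated runs from redoing the download.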

setup(stage)

Creates datasets and applies transforms. Runs on every process in distributed training.

Parameters:

  • stage - 'fit', 'validate', 'test', or 'predict'

Use for:

  • Creating train/val/test splits
  • Building Dataset objects
  • Applying transforms
  • Setting state (self.train_dataset = ...)

Example:

python
def setup(self, stage):
    if stage == 'fit':
        full_dataset = MyDataset("data/processed/")
        self.train_dataset, self.val_dataset = random_split(
            full_dataset, [0.8, 0.2]
        )

    if stage == 'test':
        self.test_dataset = MyDataset("data/processed/test/")

    if stage == 'predict':
        self.predict_dataset = MyDataset("data/processed/predict/")
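
Because setup() runs on every process, an unseeded random_split can produce a different train/val partition in each process unless a global seed has been fixed. A minimal sketch of a seeded split follows; `split_dataset` is a hypothetical helper and the seed value 42 is an arbitrary assumption:

```python
import torch
from torch.utils.data import TensorDataset, random_split


def split_dataset(dataset, val_fraction=0.2, seed=42):
    """Split into train/val subsets using a dedicated, seeded generator."""
    val_size = int(len(dataset) * val_fraction)
    train_size = len(dataset) - val_size
    generator = torch.Generator().manual_seed(seed)
    return random_split(dataset, [train_size, val_size], generator=generator)


# Toy dataset standing in for MyDataset.
dataset = TensorDataset(torch.arange(100))

# Two independent calls (as two processes would make) give the same split.
train_a, val_a = split_dataset(dataset)
train_b, val_b = split_dataset(dataset)
```

Using a local torch.Generator leaves the global random state untouched, so shuffling and augmentation elsewhere are unaffected.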

train_dataloader()

Returns the training DataLoader.

Example:

python
def train_dataloader(self):
    return DataLoader(
        self.train_dataset,
        batch_size=self.batch_size,
        shuffle=True,
        num_workers=self.num_workers,
        pin_memory=True
    )

val_dataloader()

Returns the validation DataLoader(s).

Example:

python
def val_dataloader(self):
    return DataLoader(
        self.val_dataset,
        batch_size=self.batch_size,
        shuffle=False,
        num_workers=self.num_workers,
        pin_memory=True
    )

test_dataloader()

Returns the test DataLoader(s).

Example:

python
def test_dataloader(self):
    return DataLoader(
        self.test_dataset,
        batch_size=self.batch_size,
        shuffle=False,
        num_workers=self.num_workers
    )

predict_dataloader()

Returns the prediction DataLoader(s).

Example:

python
def predict_dataloader(self):
    return DataLoader(
        self.predict_dataset,
        batch_size=self.batch_size,
        shuffle=False,
        num_workers=self.num_workers
    )

Complete Example

python
import lightning as L
from torch.utils.data import DataLoader, Dataset, random_split
import torch

class MyDataset(Dataset):
    def __init__(self, data_path, transform=None):
        self.data_path = data_path
        self.transform = transform
        self.data = self._load_data()

    def _load_data(self):
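        # Load your data here. Placeholder: 1000 random 3x32x32 "images"
        # (kept small; 1000 float32 images at 224x224 would need roughly 600 MB).
        return torch.randn(1000, 3, 32, 32)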

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        if self.transform:
            sample = self.transform(sample)
        return sample

class MyDataModule(L.LightningDataModule):
    def __init__(self, data_dir="./data", batch_size=32, num_workers=4):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        # Transforms
        self.train_transform = self._get_train_transforms()
        self.test_transform = self._get_test_transforms()

    def _get_train_transforms(self):
        # Define training transforms
        return lambda x: x  # Placeholder

    def _get_test_transforms(self):
        # Define test/val transforms
        return lambda x: x  # Placeholder

    def prepare_data(self):
        # Download data (runs once on single process)
        # download_data(self.data_dir)
        pass

    def setup(self, stage=None):
        # Create datasets (runs on every process)
        # 'validate' is included because trainer.validate() calls setup('validate')
        if stage in ('fit', 'validate', None):
            # Note: both subsets wrap full_dataset, so validation samples
            # also pass through train_transform
            full_dataset = MyDataset(
                self.data_dir,
                transform=self.train_transform
            )
            train_size = int(0.8 * len(full_dataset))
            val_size = len(full_dataset) - train_size
            # Seed the split so every process gets the same partition
            self.train_dataset, self.val_dataset = random_split(
                full_dataset, [train_size, val_size],
                generator=torch.Generator().manual_seed(42)
            )

        if stage in ('test', None):
            self.test_dataset = MyDataset(
                self.data_dir,
                transform=self.test_transform
            )

        if stage in ('predict', None):
            self.predict_dataset = MyDataset(
                self.data_dir,
                transform=self.test_transform
            )

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=self.num_workers > 0
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=self.num_workers > 0
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers
        )

    def predict_dataloader(self):
        return DataLoader(
            self.predict_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers
        )

Usage

python
# Create DataModule
dm = MyDataModule(data_dir="./data", batch_size=64, num_workers=8)

# Use with Trainer
trainer = L.Trainer(max_epochs=10)
trainer.fit(model, datamodule=dm)

# Test
trainer.test(model, datamodule=dm)

# Predict
predictions = trainer.predict(model, datamodule=dm)

# Or use standalone in PyTorch
dm.prepare_data()
dm.setup(stage='fit')
train_loader = dm.train_dataloader()

for batch in train_loader:
    # Your training code
    pass

Additional Hooks

transfer_batch_to_device(batch, device, dataloader_idx)

Custom logic for moving batches to devices.

Example:

python
def transfer_batch_to_device(self, batch, device, dataloader_idx):
    # Custom transfer logic
    if isinstance(batch, dict):
        return {k: v.to(device) for k, v in batch.items()}
    return super().transfer_batch_to_device(batch, device, dataloader_idx)
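
The dictionary comprehension above calls .to(device) on every value, which raises an error if a value is not a tensor (for example, a list of file names). A minimal sketch of a recursive helper that moves tensors and leaves everything else untouched follows; `move_to_device` is a hypothetical name, not part of Lightning:

```python
import torch


def move_to_device(batch, device):
    """Recursively move tensors in nested dicts/lists/tuples to `device`."""
    if isinstance(batch, torch.Tensor):
        return batch.to(device)
    if isinstance(batch, dict):
        return {key: move_to_device(value, device) for key, value in batch.items()}
    if isinstance(batch, (list, tuple)):
        # Rebuild the same container type (plain list or tuple).
        return type(batch)(move_to_device(value, device) for value in batch)
    # Strings, ints, None, etc. are returned unchanged.
    return batch


batch = {
    "image": torch.zeros(2, 3),
    "meta": {"ids": [1, 2], "name": "demo"},
    "pair": (torch.ones(1), "label"),
}
moved = move_to_device(batch, torch.device("cpu"))
```

Inside transfer_batch_to_device, a helper like this could replace the comprehension when batches mix tensors with metadata.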

on_before_batch_transfer(batch, dataloader_idx)

Augment or modify the batch before it is transferred to the device, while the data is still on the CPU.

Example:

python
def on_before_batch_transfer(self, batch, dataloader_idx):
    # Apply CPU-based augmentations
    batch['image'] = apply_augmentation(batch['image'])
    return batch

on_after_batch_transfer(batch, dataloader_idx)

Augment or modify the batch after it has been transferred, so the work runs on the target device (e.g., the GPU).

Example:

python
def on_after_batch_transfer(self, batch, dataloader_idx):
    # Apply GPU-based augmentations
    batch['image'] = gpu_augmentation(batch['image'])
    return batch
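
The `gpu_augmentation` call above is a placeholder. As one possible illustration, here is a minimal sketch of a batch-level random horizontal flip that runs on whatever device the batch already lives on; the function name and the (N, C, H, W) layout are assumptions:

```python
import torch


def random_horizontal_flip(images, p=0.5):
    """Flip each image in an (N, C, H, W) batch left-right with probability p."""
    # One random draw per image, created on the same device as the batch.
    flip_mask = torch.rand(images.shape[0], device=images.device) < p
    flipped = torch.flip(images, dims=[-1])
    # Broadcast the per-image mask over the channel and spatial dimensions.
    return torch.where(flip_mask[:, None, None, None], flipped, images)


images = torch.arange(16.0).reshape(2, 1, 2, 4)
always_flipped = random_horizontal_flip(images, p=1.0)
never_flipped = random_horizontal_flip(images, p=0.0)
```

Because the whole batch is processed with tensor operations, the same code works on CPU or GPU without changes.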

state_dict() / load_state_dict(state_dict)

Save and restore DataModule state for checkpointing.

Example:

python
def state_dict(self):
    return {"current_fold": self.current_fold}

def load_state_dict(self, state_dict):
    self.current_fold = state_dict["current_fold"]

teardown(stage)

Cleanup operations, called at the end of fit, validate, test, or predict.

Example:

python
def teardown(self, stage):
    # Clean up resources
    if stage == 'fit':
        self.train_dataset = None
        self.val_dataset = None

Advanced Patterns

Multiple Validation/Test DataLoaders

Return a list or dictionary of DataLoaders:

python
def val_dataloader(self):
    return [
        DataLoader(self.val_dataset_1, batch_size=32),
        DataLoader(self.val_dataset_2, batch_size=32)
    ]

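# Or with names (for logging); use this instead of the list version,
# since a second val_dataloader in the same class replaces the first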
def val_dataloader(self):
    return {
        "val_easy": DataLoader(self.val_easy, batch_size=32),
        "val_hard": DataLoader(self.val_hard, batch_size=32)
    }

# In LightningModule
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    if dataloader_idx == 0:
        # Handle val_dataset_1
        pass
    else:
        # Handle val_dataset_2
        pass

Cross-Validation

python
import lightning as L
from torch.utils.data import DataLoader, Subset

# MyDataset is the Dataset class defined in the Complete Example
class CrossValidationDataModule(L.LightningDataModule):
    def __init__(self, data_dir, batch_size, num_folds=5):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_folds = num_folds
        self.current_fold = 0

    def setup(self, stage=None):
        full_dataset = MyDataset(self.data_dir)
        fold_size = len(full_dataset) // self.num_folds

        # Create fold indices
        indices = list(range(len(full_dataset)))
        val_start = self.current_fold * fold_size
        val_end = val_start + fold_size

        val_indices = indices[val_start:val_end]
        train_indices = indices[:val_start] + indices[val_end:]

        self.train_dataset = Subset(full_dataset, train_indices)
        self.val_dataset = Subset(full_dataset, val_indices)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size)

    def set_fold(self, fold):
        self.current_fold = fold

    def state_dict(self):
        return {"current_fold": self.current_fold}

    def load_state_dict(self, state_dict):
        self.current_fold = state_dict["current_fold"]

# Usage
dm = CrossValidationDataModule("./data", batch_size=32, num_folds=5)

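for fold in range(5):
    dm.set_fold(fold)
    # Use a freshly initialized model for each fold; reusing one would
    # carry over weights already trained on this fold's validation samples
    trainer = L.Trainer(max_epochs=10)
    trainer.fit(model, datamodule=dm)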
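
The fold arithmetic in setup() above uses integer division and takes indices in dataset order, so when the dataset size is not divisible by num_folds the trailing samples never land in a validation fold. A minimal sketch of an alternative index computation that shuffles once and places every sample in exactly one validation fold; `kfold_indices` is a hypothetical helper and the seed is an arbitrary assumption:

```python
import random


def kfold_indices(num_samples, num_folds, seed=0):
    """Return a list of (train_indices, val_indices) pairs, one per fold."""
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)

    # The first (num_samples % num_folds) folds get one extra sample.
    base, extra = divmod(num_samples, num_folds)
    folds = []
    start = 0
    for fold in range(num_folds):
        size = base + (1 if fold < extra else 0)
        val_indices = indices[start:start + size]
        train_indices = indices[:start] + indices[start + size:]
        folds.append((train_indices, val_indices))
        start += size
    return folds


folds = kfold_indices(num_samples=103, num_folds=5)
val_sizes = [len(val) for _, val in folds]
```

The resulting index lists can be passed to Subset in setup() in place of the slice-based indices.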

Hyperparameter Saving

python
class MyDataModule(L.LightningDataModule):
    def __init__(self, data_dir, batch_size=32, num_workers=4):
        super().__init__()
        # Save hyperparameters
        self.save_hyperparameters()

    def setup(self, stage=None):
        # Access via self.hparams
        print(f"Batch size: {self.hparams.batch_size}")

Best Practices

1. Separate prepare_data and setup

  • prepare_data() - Downloads/processes (single process, no state)
  • setup() - Creates datasets (every process, set state)

2. Use stage Parameter

Check the stage in setup() to avoid unnecessary work:

python
def setup(self, stage):
    if stage == 'fit':
        # Only load train/val data when fitting
        self.train_dataset = ...
        self.val_dataset = ...
    elif stage == 'test':
        # Only load test data when testing
        self.test_dataset = ...
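
Note that the snippet above leaves val_dataset unset when the stage is 'validate', which is what trainer.validate() passes, and does nothing for 'predict'. A minimal sketch of one way to cover all four stages, plus a manual setup() call with no stage, follows; the class is a plain-Python stand-in and `load_split` is a hypothetical loader, not a Lightning API:

```python
def load_split(name):
    # Hypothetical loader; returns a small list instead of a real Dataset.
    return [f"{name}_{i}" for i in range(3)]


class StageAwareData:
    """Plain-Python stand-in that mirrors the setup(stage) contract."""

    def setup(self, stage=None):
        # stage=None (a manual call) prepares every split.
        if stage in ("fit", None):
            self.train_dataset = load_split("train")
        if stage in ("fit", "validate", None):
            self.val_dataset = load_split("val")
        if stage in ("test", None):
            self.test_dataset = load_split("test")
        if stage in ("predict", None):
            self.predict_dataset = load_split("predict")


validate_only = StageAwareData()
validate_only.setup("validate")

everything = StageAwareData()
everything.setup()
```

Independent `if` checks (rather than an if/elif chain) let one split belong to more than one stage.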

3. Pin Memory for GPU Training

Set pin_memory=True in DataLoaders to speed up host-to-GPU copies; it brings no benefit for CPU-only training:

python
def train_dataloader(self):
    return DataLoader(..., pin_memory=True)

4. Use Persistent Workers

Keep worker processes alive between epochs instead of restarting them (requires num_workers > 0):

python
def train_dataloader(self):
    return DataLoader(
        ...,
        num_workers=4,
        persistent_workers=True
    )
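
Both flags depend on the environment: pinned memory only matters when a CUDA device is present, and persistent_workers=True raises a ValueError when num_workers is 0. A minimal sketch of a helper that derives both flags instead of hard-coding them; `make_loader` is a hypothetical name and the toy dataset is an assumption:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def make_loader(dataset, batch_size=32, num_workers=0, shuffle=False):
    """Build a DataLoader whose flags adapt to the environment."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        # Pinned memory only helps when batches are copied to a CUDA device.
        pin_memory=torch.cuda.is_available(),
        # Only valid when worker processes exist.
        persistent_workers=num_workers > 0,
    )


# Toy dataset: 100 samples with 4 features and a binary label.
dataset = TensorDataset(torch.randn(100, 4), torch.randint(0, 2, (100,)))
loader = make_loader(dataset, batch_size=32, num_workers=0)
features, labels = next(iter(loader))
```

With num_workers=0 the loader runs in the main process, which is convenient for debugging; raising num_workers turns persistent workers on automatically.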

5. Avoid Shuffle in Validation/Test

Never shuffle validation or test data; a fixed order keeps metrics and saved predictions comparable across runs:

python
def val_dataloader(self):
    return DataLoader(..., shuffle=False)  # Never True

6. Make DataModules Reusable

Accept configuration parameters in __init__:

python
class MyDataModule(L.LightningDataModule):
    def __init__(self, data_dir, batch_size, num_workers, augment=True):
        super().__init__()
        self.save_hyperparameters()

7. Document Data Structure

Add docstrings explaining data format and expectations:

python
class MyDataModule(L.LightningDataModule):
    """
    DataModule for XYZ dataset.

    Data format: (image, label) tuples
    - image: torch.Tensor of shape (C, H, W)
    - label: int in range [0, num_classes)

    Args:
        data_dir: Path to data directory
        batch_size: Batch size for dataloaders
        num_workers: Number of data loading workers
    """

Common Pitfalls

1. Setting State in prepare_data

Wrong:

python
def prepare_data(self):
    self.dataset = load_data()  # State not transferred to other processes!

Correct:

python
def prepare_data(self):
    download_data()  # Only download, no state

def setup(self, stage):
    self.dataset = load_data()  # Set state here

2. Not Using stage Parameter

Inefficient:

python
def setup(self, stage):
    self.train_dataset = load_train()
    self.val_dataset = load_val()
    self.test_dataset = load_test()  # Loads even when just fitting

Efficient:

python
def setup(self, stage):
    if stage == 'fit':
        self.train_dataset = load_train()
        self.val_dataset = load_val()
    elif stage == 'test':
        self.test_dataset = load_test()

3. Forgetting to Return DataLoaders

Wrong:

python
def train_dataloader(self):
    DataLoader(self.train_dataset, ...)  # Forgot return!

Correct:

python
def train_dataloader(self):
    return DataLoader(self.train_dataset, ...)

Integration with Trainer

python
# Initialize DataModule
dm = MyDataModule(data_dir="./data", batch_size=64)

# All data loading is handled by DataModule
trainer = L.Trainer(max_epochs=10)
trainer.fit(model, datamodule=dm)

# DataModule handles validation too
trainer.validate(model, datamodule=dm)

# And testing
trainer.test(model, datamodule=dm)

# And prediction
predictions = trainer.predict(model, datamodule=dm)