doc/source/train/tutorials/content/getting-started/README.md
This notebook shows how to run distributed data-parallel training with PyTorch on an Anyscale cluster using Ray Train. You train a ResNet-18 model on MNIST across multiple GPUs, with built-in support for checkpointing, metrics reporting, and distributed orchestration.
- Adapt existing PyTorch code with `prepare_model()` and `prepare_data_loader()` for multi-GPU execution.
- Configure scale with `ScalingConfig(num_workers=..., use_gpu=True)` and track outputs with `RunConfig(storage_path=...)`.
- Report metrics and save checkpoints with `ray.train.report(...)`, with best practices for rank-0 checkpointing.

The entire workflow runs fully distributed from the start: you define your training loop once, and Ray handles orchestration, sharding, and checkpointing across the cluster.
Use Ray Train when you face one of the following challenges:
| Challenge | Detail | Solution |
|---|---|---|
| Need to speed up or scale up training | Training jobs might take a long time to complete, or require a lot of compute | Ray Train provides a distributed training framework that scales training to multiple GPUs |
| Minimize overhead of setting up distributed training | You must manage the underlying infrastructure | Ray Train handles the underlying infrastructure through Ray's autoscaling |
| Achieve observability | You need to connect to different nodes and GPUs to find the root cause of failures, fetch logs, traces, etc | Ray Train provides observability through Ray's dashboard, metrics, and traces that let you monitor the training job |
| Ensure reliable training | Training jobs can fail due to hardware failures, network issues, or other unexpected events | Ray Train provides fault tolerance through checkpointing, automatic retries, and the ability to resume training from the last checkpoint |
| Avoid significant code rewrite | You might need to fully rewrite your training loop to support distributed training | Ray Train provides a suite of integrations with the PyTorch ecosystem, Tree-based methods (XGB, LGBM), and more to minimize the amount of code changes needed |
The preceding diagram shows the lifecycle of a single training step in PyTorch DistributedDataParallel (DDP) when orchestrated by Ray Train:
Model Replication
Ray Train initializes the model on GPU rank 0 and broadcasts it to all other workers so that each has an identical copy.
Sharded Data Loading
Ray Train automatically splits the dataset into non-overlapping shards. Each worker processes only its shard, ensuring efficient parallelism without duplicate samples.
Forward & Backward Passes
Each worker runs a forward pass and computes gradients locally during the backward pass.
Gradient Synchronization
An AllReduce step aggregates gradients across workers, ensuring that model updates stay consistent across all GPUs.
Weight Updates
Once Ray Train synchronizes gradients, each worker applies the update, keeping model replicas in sync.
Checkpointing & Metrics
By convention, only the rank 0 worker saves checkpoints and logs metrics to persistent storage. This avoids duplication while preserving progress and results.
With Ray Train, you don't need to manage process groups or samplers manually—utilities like prepare_model() and prepare_data_loader() wrap these details so your code works out of the box in a distributed setting.
| Schematic overview of DistributedDataParallel (DDP) training: (1) Ray Train replicates the model from the <code>GPU rank 0</code> to all other workers; (2) each worker receives a shard of the dataset and processes a mini-batch; (3) during the backward pass, Ray Train averages gradients across GPUs; (4) checkpoint and metrics from rank 0 GPU save to the persistent storage. |
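The gradient synchronization and weight update steps described above can be illustrated with a small single-process simulation. This is only a sketch: it uses plain Python lists instead of tensors, no real GPUs or communication backend, and the helper names `all_reduce_mean` and `sgd_step` are hypothetical, not part of Ray Train or PyTorch.

```python
# Simplified simulation of one DDP step on 4 "workers" (plain Python, no GPUs).
# all_reduce_mean and sgd_step are hypothetical helpers for illustration only.

def all_reduce_mean(per_worker_grads):
    """Average gradients element-wise across workers, as AllReduce does in DDP."""
    num_workers = len(per_worker_grads)
    return [sum(column) / num_workers for column in zip(*per_worker_grads)]


def sgd_step(weights, grads, lr):
    """Plain SGD update, applied identically on every worker."""
    return [w - lr * g for w, g in zip(weights, grads)]


# Model replication: every worker starts from an identical copy of the weights.
replicas = [[1.0, -2.0] for _ in range(4)]

# Each worker computes different local gradients from its own data shard.
local_grads = [[0.5, 0.0], [1.0, 0.5], [0.0, 1.0], [0.5, 0.5]]

# Gradient synchronization: every worker ends up with the same averaged gradient.
synced = all_reduce_mean(local_grads)

# Weight update: identical gradients keep all replicas in sync.
replicas = [sgd_step(weights, synced, lr=0.5) for weights in replicas]

print(synced)       # [0.5, 0.5]
print(replicas[0])  # [0.75, -2.25]
```

Because every replica applies the same averaged gradient to the same starting weights, the copies never drift apart, which is why only one worker needs to save a checkpoint.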
Start by importing all the libraries you need for this tutorial.
- Standard library: `os`, `datetime`, `tempfile`, `csv`, `shutil`, and `gc` help with file paths, checkpointing, cleanup, and general housekeeping.
- Data and visualization: `pandas`, `numpy`, `matplotlib`, and `PIL` help you inspect the dataset and plot sample images.
- PyTorch (`torch`, `CrossEntropyLoss`, `Adam`) plus `torchvision` for loading MNIST and building a ResNet-18 model.
- Ray Train: `ScalingConfig`, `RunConfig`, and `TorchTrainer`. These components handle cluster scaling, experiment output storage, and execution of your training loop across GPUs.

This notebook assumes Ray is already running (for example, inside an Anyscale cluster), so you don't need to call `ray.init()` manually.
# 00. Runtime setup — install same deps as build.sh and set env vars
import os
import sys
import subprocess
# Non-secret env var
os.environ["RAY_TRAIN_V2_ENABLED"] = "1"
# Install Python dependencies
!pip install --no-cache-dir torch==2.8.0 torchvision==0.23.0
# 01. Imports
# --- Standard library: file IO, paths, timestamps, temp dirs, cleanup ---
import csv # Simple CSV logging for metrics in single-GPU section
import datetime # Timestamps for run directories / filenames
import os # Filesystem utilities (paths, env vars)
import tempfile # Ephemeral dirs for checkpoint staging with ray.train.report()
import shutil # Cleanup of artifacts (later cells)
import gc # Manual garbage collection to cleanup after inference
from pathlib import Path # Convenient, cross-platform path handling
# --- Visualization & data wrangling ---
import matplotlib.pyplot as plt # Plot sample digits and metrics curves
from PIL import Image # Image utilities for inspection/debug
import numpy as np # Numeric helpers (random sampling, arrays)
import pandas as pd # Read metrics.csv into a DataFrame
# --- PyTorch & TorchVision (model + dataset) ---
import torch
from torch.nn import CrossEntropyLoss # Classification loss for MNIST
from torch.optim import Adam # Optimizer
from torchvision.models import resnet18 # Baseline CNN (we’ll adapt for 1-channel input)
from torchvision.datasets import MNIST # Dataset
from torchvision.transforms import ToTensor, Normalize, Compose # Preprocessing pipeline
# --- Ray Train (distributed orchestration) ---
import ray
from ray.train import ScalingConfig, RunConfig # Configure scale and storage
from ray.train.torch import TorchTrainer # Multi-GPU PyTorch trainer (DDP/FSDP)
Next, download the MNIST dataset using torchvision.datasets.MNIST.
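- The dataset downloads to the directory passed as `root`; the next cell uses `/mnt/cluster_storage/data`.
- With `train=True`, this loads the training split of the dataset.

After downloading, wrap this dataset in a `DataLoader` and apply normalization for use in model training.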
# 02. Download MNIST Dataset
dataset = MNIST(root="/mnt/cluster_storage/data", train=True, download=True)
<b>Note about Anyscale storage options</b>
This tutorial stores the MNIST dataset under <code>/mnt/cluster_storage/</code>, which is Anyscale’s persistent cluster storage.
Unlike node-local NVMe volumes, cluster storage is shared across nodes in your cluster.
Data written here persists across cluster restarts, making it a safe place for datasets, checkpoints, and results.
This is the recommended location for training data and artifacts you want to reuse.
Anyscale also provides each node with its own volume and disk, which it doesn’t share with other nodes.
Local storage is very fast: Anyscale supports the Non-Volatile Memory Express (NVMe) interface.
Local storage isn't persistent: Anyscale deletes its data after instance termination.
Read more about available <a href="https://docs.anyscale.com/configuration/storage" target="_blank">storage</a> options.
</div>Before training, take a quick look at the dataset.
# 03. Visualize Sample Digits
# Create a square figure for plotting 9 samples (3x3 grid)
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
# Loop through grid slots and plot a random digit each time
for i in range(1, cols * rows + 1):
    # Randomly select an index from the dataset
    sample_idx = np.random.randint(0, len(dataset.data))
    img, label = dataset[sample_idx]  # image (PIL) and its digit label
    # Add subplot to the figure
    figure.add_subplot(rows, cols, i)
    plt.title(label)  # show the digit label above each subplot
    plt.axis("off")  # remove axes for cleaner visualization
    plt.imshow(img, cmap="gray")  # display as grayscale image
Now define the ResNet-18 architecture to use for classification.
torchvision.models.resnet18 is pre-configured for 3-channel RGB input and ImageNet classes.conv1) to accept in_channels=1.num_classes=10).This gives a ResNet-18 tailored for MNIST while preserving the rest of the architecture.
# 04. Define ResNet-18 Model for MNIST
def build_resnet18():
    # Start with a torchvision ResNet-18 backbone
    # Set num_classes=10 since MNIST has digits 0–9
    model = resnet18(num_classes=10)
    # Override the first convolution layer:
    # - Default expects 3 channels (RGB images)
    # - MNIST is grayscale → only 1 channel
    # - Keep kernel size/stride/padding consistent with original ResNet
    model.conv1 = torch.nn.Conv2d(
        in_channels=1,  # input = grayscale
        out_channels=64,  # number of filters remains the same as original ResNet
        kernel_size=(7, 7),
        stride=(2, 2),
        padding=(3, 3),
        bias=False,
    )
    # Return the customized ResNet-18
    return model
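To see what the 7×7, stride-2, padding-3 first layer does to a 28×28 MNIST image, the standard convolution output-size formula, `floor((n + 2p - k) / s) + 1`, can be worked through by hand. The helper name `conv_output_size` below is hypothetical and exists only for this arithmetic; it assumes a dilation of 1.

```python
def conv_output_size(size: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of a convolution or pooling layer (dilation = 1)."""
    return (size + 2 * padding - kernel) // stride + 1


# MNIST images are 28x28; the customized conv1 halves the resolution.
mnist_after_conv1 = conv_output_size(28, kernel=7, stride=2, padding=3)

# For comparison: a 224x224 input, a common ImageNet crop size.
imagenet_after_conv1 = conv_output_size(224, kernel=7, stride=2, padding=3)

print(mnist_after_conv1)     # 14
print(imagenet_after_conv1)  # 112
```

The feature maps are much smaller for MNIST than for ImageNet-sized inputs, but the layer shapes still line up, which is why only the input channel count needs to change.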
<b>Migration roadmap: from standalone PyTorch to PyTorch with Ray Train</b>
The following are the steps to take a regular PyTorch training loop and run it in a fully distributed setup with Ray Train.
<ol> <li><b>Configure scale and GPUs</b> — decide how many workers and whether each should use a GPU.</li> <li><b>Wrap the model with Ray Train</b> — use <code>prepare_model()</code> to move the ResNet to the right device and wrap it in DDP automatically.</li> <li><b>Wrap the dataset with Ray Train</b> — use <code>prepare_data_loader()</code> so each worker gets a distinct shard of MNIST, moved to the correct device.</li> <li><b>Add metrics & checkpointing</b> — report training loss and save checkpoints with <code>ray.train.report()</code> from rank-0.</li> <li><b>Configure persistent storage</b> — store outputs under <code>/mnt/cluster_storage/</code> so that results and checkpoints are available across the cluster.</li> </ol> </div>Ray Train is built around four key concepts:
- Training function (`train_loop_ray_train`): A Python function that contains your model training logic.

| High-level architecture of how Ray Train |
This is the per-worker training function that Ray executes on each process/GPU. It keeps your PyTorch code intact while Ray handles process launch, device placement, and data sharding.
Key points:
- Inputs through `config`: pass hyperparameters like `num_epochs` and a `global_batch_size`.
- Model and optimizer: `load_model_ray_train()` returns a model already wrapped by Ray Train (DDP + correct device). Use `Adam` and `CrossEntropyLoss` for MNIST.
- Batch sizing: `per_worker_batch = global_batch_size // world_size`.
- Data sharding: `build_data_loader_ray_train(...)` returns a `DataLoader` wrapped with a `DistributedSampler`; each worker sees a disjoint shard.
- Epoch control: `data_loader.sampler.set_epoch(epoch)` ensures proper shuffling across epochs in distributed mode.
- Training step: `zero_grad` → `backward` → `step`.
- Metrics and checkpointing: `print_metrics_ray_train(...)` logs loss; `save_checkpoint_and_metrics_ray_train(...)` calls `ray.train.report(...)` (rank-0 saves the checkpoint).

Pass this function to `TorchTrainer`, which runs it concurrently on all workers.
See what this data-parallel training loop looks like with Ray Train and PyTorch.
# 05. Define the Ray Train per-worker training loop
def train_loop_ray_train(config: dict):  # pass in hyperparameters in config
    # config holds hyperparameters passed from TorchTrainer (e.g. num_epochs, global_batch_size)
    # Define loss function for MNIST classification
    criterion = CrossEntropyLoss()
    # Build and prepare the model for distributed training.
    # load_model_ray_train() calls ray.train.torch.prepare_model()
    # → moves model to GPU and wraps it in DistributedDataParallel (DDP).
    model = load_model_ray_train()
    # Standard optimizer (learning rate fixed for demo)
    optimizer = Adam(model.parameters(), lr=1e-5)
    # Calculate the batch size for each worker
    global_batch_size = config["global_batch_size"]
    world_size = ray.train.get_context().get_world_size()  # total # of workers in the job
    batch_size = global_batch_size // world_size  # split global batch evenly
    print(f"{world_size=}\n{batch_size=}")
    # Wrap DataLoader with prepare_data_loader()
    # → applies DistributedSampler (shards data across workers)
    # → ensures batches are automatically moved to correct device
    data_loader = build_data_loader_ray_train(batch_size=batch_size)
    # ----------------------- Training loop ----------------------- #
    for epoch in range(config["num_epochs"]):
        # Ensure each worker shuffles its shard differently every epoch
        data_loader.sampler.set_epoch(epoch)
        # Iterate over batches (sharded across workers)
        for images, labels in data_loader:
            outputs = model(images)  # forward pass
            loss = criterion(outputs, labels)  # compute loss
            optimizer.zero_grad()  # reset gradients
            loss.backward()  # backward pass (grads averaged across workers through DDP)
            optimizer.step()  # update model weights
        # After each epoch: report loss and log metrics
        metrics = print_metrics_ray_train(loss, epoch)
        # Save checkpoint (only rank-0 worker persists the model)
        save_checkpoint_and_metrics_ray_train(model, metrics)
<b>Main training loop</b>
<ul> <li><strong>`global_batch_size`</strong>: the total number of samples processed in a single training step of the entire training job. <ul> <li>It's computed as: <code>batch size * DDP workers * gradient accumulation steps</code>.</li> </ul> </li> <li>Notice that images and labels are no longer manually moved to device (<code>images.to("cuda")</code>). This is done by <a href="https://docs.ray.io/en/latest/train/api/doc/ray.train.torch.prepare_data_loader.html#ray-train-torch-prepare-data-loader" target="_blank"> prepare_data_loader() </a>. </li> <li>The configuration passed here is defined below and is passed to Ray Train's <a href="https://docs.ray.io/en/latest/train/api/doc/ray.train.torch.TorchTrainer.html#ray-train-torch-torchtrainer" target="_blank">TorchTrainer</a>.</li> <li> <a href="https://docs.ray.io/en/latest/train/api/doc/ray.train.v2.api.context.TrainContext.html#ray-train-v2-api-context-traincontext" target="_blank"> TrainContext </a> lets you get useful information about the training such as node rank, world size, world rank, experiment name. </li> <li><code>load_model_ray_train</code> and <code>build_data_loader_ray_train</code> are implemented below.</li> </ul> </div>
train_loop_config
The `train_loop_config` is a simple dictionary of hyperparameters that Ray passes into your training loop (`train_loop_ray_train`).
- It acts as the bridge between `TorchTrainer` and your per-worker training code.
- Ray passes it as the `config` argument of `train_loop_ray_train`.

In this example:
- `num_epochs` → how many full passes through the dataset to run.
- `global_batch_size` → the total batch size across all workers (the training loop splits this evenly across GPUs).

You can add other parameters here (like `learning_rate`, `embedding_dim`, etc.) and they're automatically accessible in your training loop through `config["param_name"]`.
# 06. Define the configuration dictionary passed into the training loop
# train_loop_config is provided to TorchTrainer and injected into
# train_loop_ray_train(config) as the "config" argument.
# → Any values defined here are accessible inside the training loop.
train_loop_config = {
    "num_epochs": 2,  # Number of full passes through the dataset
    "global_batch_size": 128,  # Effective batch size across ALL workers
    # (the training loop splits this evenly per worker, e.g.
    # with 8 workers → 16 samples/worker/step)
}
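The batch-size arithmetic behind this config can be checked with a short sketch. Both helper names, `per_worker_batch_size` and `effective_global_batch`, are hypothetical and not part of Ray Train; the divisibility check is an added assumption, since the training loop in this tutorial uses plain floor division and would silently drop any remainder.

```python
def per_worker_batch_size(global_batch_size: int, world_size: int) -> int:
    """Split the global batch evenly, as the training loop does with `//`."""
    if global_batch_size % world_size != 0:
        # Floor division would silently drop the remainder, so the job
        # would train with a smaller global batch than requested.
        raise ValueError(f"{global_batch_size=} is not divisible by {world_size=}")
    return global_batch_size // world_size


def effective_global_batch(per_worker: int, world_size: int, accumulation_steps: int = 1) -> int:
    """Samples that contribute to one optimizer step across the whole job."""
    return per_worker * world_size * accumulation_steps


print(per_worker_batch_size(128, world_size=8))  # 16
print(effective_global_batch(16, world_size=8))  # 128
```

Choosing a `global_batch_size` that is a multiple of `num_workers` keeps the effective batch size equal to the value you configured.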
ScalingConfig
The `ScalingConfig` tells Ray Train how many workers to launch and what resources each worker should use.
- `num_workers=8` → Run the training loop on 8 parallel workers. Each worker runs the same code on a different shard of the data.
- `use_gpu=True` → Assign one GPU per worker. If you set this to `False`, each worker trains on CPU instead.

This declarative config is what allows Ray to handle cluster orchestration for you: you don't need to manually start processes or set CUDA devices.
Later, you can pass this scaling_config into the TorchTrainer to launch distributed training.
# 07. Configure the scaling of the training job
# ScalingConfig defines how many parallel training workers Ray should launch
# and whether each worker should be assigned a GPU or CPU.
# → Each worker runs train_loop_ray_train(config) independently,
# with Ray handling synchronization through DDP under the hood.
scaling_config = ScalingConfig(
    num_workers=8,  # Launch 8 training workers (1 process per worker)
    use_gpu=True,  # Allocate 1 GPU to each worker
)
See the <a href="https://docs.ray.io/en/latest/train/api/doc/ray.train.ScalingConfig.html#ray-train-scalingconfig" target="_blank">ScalingConfig</a> API reference.
See docs on configuring <a href="https://docs.ray.io/en/latest/train/user-guides/using-gpus.html" target="_blank">scale and GPUs</a> for more details.
</div>
prepare_model()
Next, define a helper function to build and prepare the model for Ray Train.
- Start with the model from `build_resnet18()`.
- Instead of manually calling `model.to("cuda")` and wrapping it in `DistributedDataParallel` (DDP), use `ray.train.torch.prepare_model()`.

This means the same code works whether you're training on 1 GPU or 100 GPUs—no manual device placement or DDP boilerplate required.
# 08. Build and prepare the model for Ray Train
def load_model_ray_train() -> torch.nn.Module:
    model = build_resnet18()
    # prepare_model() → move to correct device + wrap in DDP automatically
    model = ray.train.torch.prepare_model(model)
    return model
prepare_data_loader()
Now define a helper that builds the MNIST `DataLoader` and makes it Ray Train–ready.
Apply standard preprocessing:
- `ToTensor()` → convert PIL images to PyTorch tensors
- `Normalize((0.5,), (0.5,))` → center and scale pixel values

Construct a PyTorch `DataLoader` with batching and shuffling.
Finally, wrap it with `prepare_data_loader()`, which automatically:
- Moves each batch to the correct device.
- Adds a `DistributedSampler` when running with multiple workers, so that each worker processes a unique shard of the dataset.

This utility lets you use the same `DataLoader` code whether you're training on one GPU or many; Ray handles the distributed sharding and device placement for you.
# 09. Build a Ray Train–ready DataLoader for MNIST
def build_data_loader_ray_train(batch_size: int) -> torch.utils.data.DataLoader:
    # Define preprocessing: convert to tensor + normalize pixel values
    transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
    # Load the MNIST training set from persistent cluster storage
    train_data = MNIST(
        root="/mnt/cluster_storage/data",
        train=True,
        download=True,
        transform=transform,
    )
    # Standard PyTorch DataLoader (batching, shuffling, drop last incomplete batch)
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=batch_size, shuffle=True, drop_last=True
    )
    # prepare_data_loader():
    # - Adds a DistributedSampler when using multiple workers
    # - Moves batches to the correct device automatically
    train_loader = ray.train.torch.prepare_data_loader(train_loader)
    return train_loader
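The sharding behavior of a distributed sampler, and the reason the training loop calls `set_epoch(epoch)`, can be sketched in plain Python. This is a simplified illustration, not PyTorch's implementation: the function `shard_indices` is hypothetical, it uses Python's `random` module rather than a torch generator, and it ignores the padding or trimming that the real `DistributedSampler` applies when the dataset size isn't divisible by the number of workers.

```python
import random


def shard_indices(num_samples: int, world_size: int, rank: int, epoch: int, seed: int = 0):
    """Return the sample indices one worker sees in one epoch (simplified)."""
    indices = list(range(num_samples))
    # Every rank seeds the shuffle with the same (seed + epoch), so all
    # workers agree on a single permutation for that epoch.
    random.Random(seed + epoch).shuffle(indices)
    # Each rank then takes every world_size-th index, starting at its rank,
    # which yields disjoint shards that together cover the dataset.
    return indices[rank::world_size]


world_size = 4
shards = [shard_indices(16, world_size, rank, epoch=0) for rank in range(world_size)]

# No sample appears on two workers, and no sample is skipped.
covered = sorted(i for shard in shards for i in shard)
print(covered == list(range(16)))
```

Without the epoch in the seed, every epoch would reuse the same permutation, so each worker would see its samples in the same order every time.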
<b>Ray Data integration</b>
This step isn't necessary if you are integrating your Ray Train workload with Ray Data. Ray Data is especially useful when preprocessing is CPU-heavy and you want to run preprocessing and training on separate instances.
</div>During training, it's important to log metrics like loss values so you can monitor progress.
This helper function prints metrics from every worker:
- Uses `ray.train.get_context().get_world_rank()` to identify which worker is reporting.

# 10. Report training metrics from each worker
def print_metrics_ray_train(loss: torch.Tensor, epoch: int) -> dict:
    metrics = {"loss": loss.item(), "epoch": epoch}
    world_rank = ray.train.get_context().get_world_rank()  # report from all workers
    print(f"{metrics=} {world_rank=}")
    return metrics
If you want to log only from the rank 0 worker, use this code:
def print_metrics_ray_train(loss: torch.Tensor, epoch: int) -> dict:
    metrics = {"loss": loss.item(), "epoch": epoch}
    world_rank = ray.train.get_context().get_world_rank()
    if world_rank == 0:  # print only from the rank 0 worker
        print(f"{metrics=} {world_rank=}")
    return metrics
Report intermediate metrics and checkpoints using the ray.train.report utility function.
This helper function:
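- Stages the model weights in a temporary directory with `torch.save()`.
  - Under DDP the model is wrapped, so it calls `model.module.state_dict()` to unwrap it.
- Calls `ray.train.report()` to:
  - Log the metrics.
  - Attach a `Checkpoint` object created from the staged directory.

This way, each epoch produces both metrics for monitoring and a checkpoint for recovery or inference.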
# 11. Save checkpoint and report metrics with Ray Train
def save_checkpoint_and_metrics_ray_train(model: torch.nn.Module, metrics: dict[str, float]) -> None:
    # Create a temporary directory to stage checkpoint files
    with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
        # Save the model weights.
        # Note: under DDP the model is wrapped in DistributedDataParallel,
        # so we unwrap it with `.module` before calling state_dict().
        torch.save(
            model.module.state_dict(),  # note the `.module` to unwrap the DistributedDataParallel
            os.path.join(temp_checkpoint_dir, "model.pt"),
        )
        # Report metrics and attach a checkpoint to Ray Train.
        # → metrics are logged centrally
        # → checkpoint allows resuming training or running inference later
        ray.train.report(
            metrics,
            checkpoint=ray.train.Checkpoint.from_directory(temp_checkpoint_dir),
        )
The preceding diagram shows how a checkpoint moves from local storage (temporary directory on a worker) to persistent cluster or cloud storage.
Key points to remember:
- Call `ray.train.report` on all workers to keep the training loop synchronized.

To avoid redundant writes, update the checkpointing function so that only the rank-0 worker saves the model weights.
- Checking `ray.train.get_context().get_world_rank()` ensures that only worker 0 performs the checkpointing.
- All workers still call `ray.train.report`, but only rank-0 attaches the actual checkpoint. This keeps the training loop synchronized.

This pattern is the recommended best practice:
# 12. Save checkpoint only from the rank-0 worker
def save_checkpoint_and_metrics_ray_train(model: torch.nn.Module, metrics: dict[str, float]) -> None:
    with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
        checkpoint = None
        # Only the rank-0 worker writes the checkpoint file
        if ray.train.get_context().get_world_rank() == 0:
            torch.save(
                model.module.state_dict(),  # unwrap DDP before saving
                os.path.join(temp_checkpoint_dir, "model.pt"),
            )
            checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)
        # All workers still call ray.train.report()
        # → keeps training loop synchronized
        # → metrics are logged from each worker
        # → only rank-0 attaches a checkpoint
        ray.train.report(
            metrics,
            checkpoint=checkpoint,
        )
Check the guide on saving and loading checkpoints for more details and best practices.
RunConfig
To tell Ray Train where to store results, checkpoints, and logs, use a `RunConfig`.
- `storage_path` → Base directory for all outputs of this training run.
  - This example uses `/mnt/cluster_storage/training/`, which is persistent shared storage across all nodes.
- `name` → A human-readable name for the run (for example, `"distributed-mnist-resnet18"`). This is used to namespace output files.

Together, these settings define how Ray organizes and persists all artifacts from your training job.
# 13. Configure persistent storage and run name
storage_path = "/mnt/cluster_storage/training/"
run_config = RunConfig(
    storage_path=storage_path,  # where to store checkpoints/logs
    name="distributed-mnist-resnet18",  # identifier for this run
)
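As a quick illustration of how `storage_path` and `name` combine, the run's top-level output directory can be composed with `pathlib`. This sketch only shows the top-level directory that this tutorial's results land under; the variable names `base_dir`, `run_name`, and `run_dir` are hypothetical, and the layout of files inside the directory isn't modeled here.

```python
from pathlib import PurePosixPath

# Same values as the RunConfig above; variable names are illustrative only.
base_dir = "/mnt/cluster_storage/training/"
run_name = "distributed-mnist-resnet18"

# The run name namespaces all outputs under the base storage path.
run_dir = PurePosixPath(base_dir) / run_name
print(run_dir)  # /mnt/cluster_storage/training/distributed-mnist-resnet18
```

Using a distinct `name` per experiment keeps checkpoints and results from different runs in separate directories under the same shared storage.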
This diagram shows the same DDP workflow as before, but now with Ray Train utilities highlighted:
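- `ray.train.torch.prepare_data_loader()`
  - Shards the dataset across workers with a `DistributedSampler`.
- `ray.train.torch.prepare_model()`
  - Wraps the model in `DistributedDataParallel` (DDP) so gradients are synchronized across workers.
  - Removes the need for manual `.to("cuda")` calls or DDP boilerplate.
- `ray.train.report()`
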
By combining these helpers, Ray Train takes care of the data sharding, model replication, gradient synchronization, and checkpoint lifecycle—letting you keep your training loop clean and close to standard PyTorch.
TorchTrainer
Now bring everything together with a `TorchTrainer`.
The `TorchTrainer` is the high-level Ray Train class that:
- Runs your training loop (`train_loop_ray_train`) across the cluster.
- Applies the `scaling_config` (number of workers, GPUs/CPUs).
- Uses the `run_config` to decide where results and checkpoints are stored.
- Passes `train_loop_config` (hyperparameters like `num_epochs` and `global_batch_size`) into the training loop.

This object encapsulates the distributed orchestration, so you can start training with a call to `trainer.fit()`.
# 14. Set up the TorchTrainer
trainer = TorchTrainer(
    train_loop_ray_train,  # training loop to run on each worker
    scaling_config=scaling_config,  # number of workers and resource config
    run_config=run_config,  # storage path + run name for artifacts
    train_loop_config=train_loop_config,  # hyperparameters passed to the loop
)
trainer.fit()
Calling `trainer.fit()` starts the distributed training job and blocks until it completes.
When the job launches, you can see logs that confirm:
- Ray Train assigns a rank to each worker (for example, `world_rank=0` and `world_rank=1`).
- Each worker moves the model to its GPU (for example, `cuda:0`) and wraps it in `DistributedDataParallel` (DDP).

These logs are a quick sanity check that Ray Train is correctly orchestrating multi-GPU training across your cluster.
# 15. Launch distributed training
# trainer.fit() starts the training job:
# - Spawns workers according to scaling_config
# - Runs train_loop_ray_train() on each worker
# - Collects metrics and checkpoints into result
result = trainer.fit()
When trainer.fit() finishes, it returns a Result object.
This object contains:
- The full history of reported metrics (`result.metrics_dataframe`).

In the preceding output, you can see:
- The location of the run's outputs (`/mnt/cluster_storage/training/distributed-mnist-resnet18/...`).

This makes it easy to both analyze training performance and restore the trained model later for inference.
# 16. Show the training results
result # contains metrics, checkpoints, and run history
The Result object also includes a metrics_dataframe, which stores the full history of metrics reported during training.
- It includes the metrics reported from the training loop (`loss`, `epoch`).

In the example below, you can see the training loss steadily decreasing across two epochs.
# 17. Display the full metrics history as a pandas DataFrame
result.metrics_dataframe
After training, you can reload the model and use it for predictions.
Define a Ray actor (ModelWorker) that loads the checkpointed ResNet-18 onto a GPU and serves inference requests.
Initialization (`__init__`):
- Reads the checkpoint directory using `checkpoint.as_directory()`.

Prediction (`predict`):
- Accepts a single image (`[C,H,W]`) or a batch (`[B,C,H,W]`).
- Runs under `torch.inference_mode()` for efficiency.

Finally, launch the actor with `ModelWorker.remote(result.checkpoint)`.
This spawns a dedicated process with 1 GPU attached that can serve predictions using the trained model.
# 18. Define a Ray actor to load the trained model and run inference
@ray.remote(num_gpus=1)  # allocate 1 GPU to this actor
class ModelWorker:
    def __init__(self, checkpoint):
        # Load model weights from the Ray checkpoint (on CPU first)
        with checkpoint.as_directory() as ckpt_dir:
            model_path = os.path.join(ckpt_dir, "model.pt")
            state_dict = torch.load(
                model_path,
                map_location=torch.device("cpu"),
                weights_only=True,
            )
        # Rebuild the model, load weights, move to GPU, and set to eval mode
        self.model = build_resnet18()
        self.model.load_state_dict(state_dict)
        self.model.to("cuda")
        self.model.eval()

    @torch.inference_mode()  # disable autograd for faster inference
    def predict(self, batch):
        """
        batch: torch.Tensor or numpy array with shape [B,C,H,W] or [C,H,W]
        returns: list[int] predicted class indices
        """
        x = torch.as_tensor(batch)
        if x.ndim == 3:  # single image → add batch dimension
            x = x.unsqueeze(0)  # shape becomes [1,C,H,W]
        x = x.to("cuda", non_blocking=True)
        logits = self.model(x)
        preds = torch.argmax(logits, dim=1)
        return preds.detach().cpu().tolist()

# Create a fresh actor instance (avoid naming conflicts)
worker = ModelWorker.remote(result.checkpoint)
With the ModelWorker actor running on GPU, you can now generate predictions on random samples from the MNIST dataset and plot them.
Steps in this cell:
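Normalization on CPU
- Converts each image to a tensor with `ToTensor()`.
- Applies channel-wise normalization (`0.5` mean / std).

Prediction on GPU with Actor
- Adds a batch dimension so the input has shape `[1, C, H, W]`.
- Sends the tensor to the `ModelWorker` for inference.
- `ray.get(worker.predict.remote(x))` retrieves the predicted class index.

Plot Results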
This demonstrates a practical workflow: CPU-based preprocessing + GPU-based inference in a Ray actor.
# 19. CPU preprocessing + GPU inference using a Ray actor
to_tensor = ToTensor()
def normalize_cpu(img):
    # Convert image (PIL) to tensor on CPU → shape [C,H,W]
    t = to_tensor(img)  # [C,H,W] on CPU
    C = t.shape[0]
    # Apply channel-wise normalization (grayscale vs RGB)
    if C == 3:
        norm = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    else:
        norm = Normalize((0.5,), (0.5,))
    return norm(t)
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
# Plot a 3x3 grid of random MNIST samples with predictions
for i in range(1, cols * rows + 1):
    idx = np.random.randint(0, len(dataset))
    img, label = dataset[idx]
    # Preprocess on CPU, add batch dim → [1,C,H,W]
    x = normalize_cpu(img).unsqueeze(0)
    # Run inference on GPU using the Ray actor, fetch result
    pred = ray.get(worker.predict.remote(x))[0]  # int
    # Plot image with true label and predicted label
    figure.add_subplot(rows, cols, i)
    plt.title(f"label: {label}; pred: {int(pred)}")
    plt.axis("off")
    arr = np.array(img)
    plt.imshow(arr, cmap="gray" if arr.ndim == 2 else None)
plt.tight_layout()
plt.show()
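The normalization used in `normalize_cpu` is simple arithmetic: `ToTensor()` scales pixel values to the range 0 to 1, and `Normalize` with mean 0.5 and standard deviation 0.5 computes `(value - mean) / std` for each pixel. The sketch below works that formula through on scalar values; the function name `normalize_pixel` is hypothetical and stands in for the per-element computation only.

```python
def normalize_pixel(value: float, mean: float = 0.5, std: float = 0.5) -> float:
    """Per-pixel arithmetic of Normalize: subtract the mean, divide by the std."""
    return (value - mean) / std


# Black, mid-gray, and white pixels after ToTensor() scaling to [0, 1].
print([normalize_pixel(v) for v in (0.0, 0.5, 1.0)])  # [-1.0, 0.0, 1.0]
```

Inference has to apply the same transform that training used; otherwise the model sees inputs in a different value range than it learned from.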
Once you're done running inference, it's a good practice to free up resources:
- `ray.kill(worker, no_restart=True)` → stops the `ModelWorker` actor and releases its GPU.
- `del worker` + `gc.collect()` → drop local references so Python's garbage collector can clean up.

This ensures the GPU is no longer pinned by the actor and becomes available for other jobs.
# 20.
# stop the actor process and free its GPU
ray.kill(worker, no_restart=True)
# drop local references so nothing pins it
del worker
# Forcing garbage collection is optional:
# - Cluster resources are already freed by ray.kill()
# - Python will clean up the local handle eventually
# - gc.collect() is usually unnecessary unless debugging memory issues
gc.collect()
In this module you extend distributed training with Ray Train by adding Ray Data to the pipeline. Instead of relying on a local PyTorch DataLoader, you stream batches directly from a distributed Ray Dataset, enabling scalable preprocessing and just-in-time data loading across the cluster.
- Replace the PyTorch `DataLoader` with `iter_torch_batches()` to stream batches into your training loop.
- Pass Ray Datasets to `TorchTrainer` with the `datasets` parameter.

With Ray Data, you can scale preprocessing and training independently: CPUs handle input pipelines while GPUs focus on training, which raises utilization and throughput in your distributed workloads.
Note that the code blocks for this module depend on the previous module, Introduction to Ray Train.
Use both Ray Train and Ray Data when you face one of the following challenges:
| Challenge | Detail | Solution |
|---|---|---|
| Need to perform online or just-in-time data processing | The training pipeline requires processing data on the fly, such as data augmentation, normalization, or other transformations that may differ for each training epoch. | Ray Train's integration with Ray Data makes it easy to implement just-in-time data processing. |
| Need to improve hardware utilization | Training and data processing need to be scaled independently to keep GPUs fully utilized, especially when preprocessing is CPU-intensive. | Ray Data can distribute data processing across multiple CPU nodes, while Ray Train runs the training loop on GPUs. |
| Need a consistent interface for loading data | The training process may need to load data from various sources, such as Parquet, CSV, or lakehouses. | Ray Data provides a consistent interface for loading, shuffling, sharding, and batching data for training loops. |
Reimplement the training loop, but this time using Ray Data instead of a PyTorch DataLoader.
Key differences from the previous version:
- Data loader → built with `build_data_loader_ray_train_ray_data()`, which streams batches from a Ray Dataset shard (details in the following block).
- Batch size → still `global_batch_size // world_size`, but batches are now dictionaries with keys `"image"` and `"label"`.
- No need to call `sampler.set_epoch()` or `to("cuda")`.

The rest of the loop (forward pass, loss computation, backward pass, optimizer step, metric logging, and checkpointing) stays the same.
This pattern shows how seamlessly Ray Data integrates with Ray Train, replacing DataLoader while keeping the training logic identical.
# 01. Training loop using Ray Data
def train_loop_ray_train_ray_data(config: dict):
    # Same as before: define loss, model, optimizer
    criterion = CrossEntropyLoss()
    model = load_model_ray_train()
    optimizer = Adam(model.parameters(), lr=1e-3)
    # Different: build data loader from Ray Data instead of PyTorch DataLoader
    global_batch_size = config["global_batch_size"]
    batch_size = global_batch_size // ray.train.get_context().get_world_size()
    data_loader = build_data_loader_ray_train_ray_data(batch_size=batch_size)
    # Same: loop over epochs
    for epoch in range(config["num_epochs"]):
        # Different: no sampler.set_epoch(), Ray Data handles shuffling internally
        # Different: batches are dicts {"image": ..., "label": ...} not tuples
        for batch in data_loader:
            outputs = model(batch["image"])
            loss = criterion(outputs, batch["label"])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Same: report metrics and save checkpoint each epoch
        metrics = print_metrics_ray_train(loss, epoch)
        save_checkpoint_and_metrics_ray_train(model, metrics)
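
The per-worker batch size computation above can be checked in isolation. The following is a minimal sketch in plain Python, not Ray code: per_worker_batch_size is a hypothetical helper, and world_size stands in for the value returned by ray.train.get_context().get_world_size(). It also shows that integer division drops the remainder when the global batch size isn't divisible by the number of workers.

```python
def per_worker_batch_size(global_batch_size: int, world_size: int) -> int:
    """Split a global batch size evenly across workers (remainder is dropped)."""
    if world_size < 1:
        raise ValueError("world_size must be at least 1")
    return global_batch_size // world_size


# 512 samples per step shared by 4 workers gives 128 samples per worker
print(per_worker_batch_size(512, 4))

# With 3 workers the effective global batch is 3 * 170 = 510, not 512
effective = 3 * per_worker_batch_size(512, 3)
print(effective)
```

If the exact global batch size matters for your experiment, pick a value that is divisible by the number of workers.
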
Instead of using PyTorch’s DataLoader, you can build a loader from a Ray Dataset shard.
- ray.train.get_dataset_shard("train") → retrieves the shard of the training dataset assigned to the current worker.
- .iter_torch_batches() → streams the shard as PyTorch-compatible batches.
  - Batches are dictionaries (for example, {"image": tensor, "label": tensor}).
  - Supports options such as batch_size and prefetch_batches for performance tuning.

This integration ensures that data is sharded, shuffled, and moved to the right device automatically, while still looking and feeling like a familiar PyTorch data loader.
Note: Use iter_torch_batches to build a PyTorch-compatible data loader from a Ray Dataset.
# 02. Build a Ray Data–backed data loader
def build_data_loader_ray_train_ray_data(batch_size: int, prefetch_batches: int = 2):
    # Different: instead of creating a PyTorch DataLoader,
    # fetch the training dataset shard for this worker
    dataset_iterator = ray.train.get_dataset_shard("train")
    # Convert the shard into a PyTorch-style iterator
    # - Returns dict batches: {"image": ..., "label": ...}
    # - prefetch_batches controls pipeline buffering
    data_loader = dataset_iterator.iter_torch_batches(
        batch_size=batch_size, prefetch_batches=prefetch_batches
    )
    return data_loader
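
To build intuition for what prefetch_batches tunes, here is a conceptual sketch in plain Python. It isn't how Ray Data is implemented: prefetching_iterator and slow_batches are hypothetical names, and a background thread with a bounded queue stands in for Ray Data's pipeline buffering. The idea is that up to prefetch_batches batches are prepared ahead of the consumer, so the training step spends less time waiting on data loading.

```python
import queue
import threading
from typing import Iterable, Iterator

_SENTINEL = object()  # marks the end of the stream


def prefetching_iterator(source: Iterable[dict], prefetch_batches: int = 2) -> Iterator[dict]:
    """Yield items from source while a background thread buffers upcoming ones."""
    buffer: queue.Queue = queue.Queue(maxsize=max(1, prefetch_batches))

    def producer() -> None:
        for item in source:
            buffer.put(item)  # blocks while the buffer is full
        buffer.put(_SENTINEL)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = buffer.get()
        if item is _SENTINEL:
            return
        yield item


def slow_batches(num_batches: int) -> Iterator[dict]:
    """Hypothetical batch source that yields dict batches with image and label keys."""
    for i in range(num_batches):
        yield {"image": [i] * 4, "label": [i % 10] * 4}


labels = [batch["label"][0] for batch in prefetching_iterator(slow_batches(5))]
print(labels)
```

A larger buffer hides more loading latency at the cost of extra memory, which is the same trade-off to keep in mind when tuning prefetch_batches.
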
Ray Data works best with data in tabular formats such as Parquet.
In this step:
- Build a pandas DataFrame with two columns:
  - "image" → raw image arrays
  - "label" → digit class (0–9)
- Write the DataFrame as a Parquet file under /mnt/cluster_storage/.

Parquet is efficient for both reading and distributed processing, making it a good fit for Ray Data pipelines.
# 03. Convert MNIST dataset into Parquet for Ray Data
# Build a DataFrame with image arrays and labels
df = pd.DataFrame({
    "image": dataset.data.tolist(),  # raw image pixels (as lists)
    "label": dataset.targets,        # digit labels 0–9
})
# Persist the dataset in Parquet format (columnar, efficient for Ray Data)
df.to_parquet("/mnt/cluster_storage/MNIST.parquet")
Now that the training data is stored as Parquet, you can load it back into a Ray Dataset:
- Use ray.data.read_parquet() to create a distributed Ray Dataset from the Parquet file.
- Each row has two columns: "image" (raw pixel array) and "label" (digit class).

This Ray Dataset can be passed to the TorchTrainer for distributed training.
# 04. Load the Parquet dataset into a Ray Dataset
# Read the Parquet file → creates a distributed Ray Dataset
train_ds = ray.data.read_parquet("/mnt/cluster_storage/MNIST.parquet")
To make the dataset usable by PyTorch, apply the same preprocessing steps that the PyTorch data loader applied in the previous module.
- Define a function transform_images(row) that:
  - Converts the "image" array from numpy into a PIL image.
  - Applies ToTensor() → converts the image to a tensor.
  - Applies Normalize((0.5,), (0.5,)) → scales pixel values to the range [-1, 1].
  - Replaces the "image" entry in the row with the transformed tensor.

This function is applied in parallel across the Ray Dataset.
# 05. Define preprocessing transform for Ray Data
def transform_images(row: dict):
    # Convert numpy array to a PIL image, then apply TorchVision transforms
    transform = Compose([
        ToTensor(),                 # convert to tensor
        Normalize((0.5,), (0.5,)),  # normalize to [-1, 1]
    ])
    # Ensure image is in uint8 before conversion
    image_arr = np.array(row["image"], dtype=np.uint8)
    # Apply transforms and replace the "image" field with tensor
    row["image"] = transform(Image.fromarray(image_arr))
    return row
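
As a quick check of the arithmetic behind these transforms, the sketch below reproduces it for single pixel values in plain Python. It assumes the standard torchvision behavior: ToTensor() scales uint8 pixels from [0, 255] to [0, 1], and Normalize((0.5,), (0.5,)) computes (x - mean) / std. normalize_pixel is a hypothetical helper, not part of the tutorial code.

```python
def normalize_pixel(value: int, mean: float = 0.5, std: float = 0.5) -> float:
    """Mimic ToTensor() followed by Normalize(mean, std) for one uint8 pixel."""
    scaled = value / 255.0        # ToTensor(): [0, 255] -> [0.0, 1.0]
    return (scaled - mean) / std  # Normalize(): (x - mean) / std


darkest = normalize_pixel(0)      # lower end of the output range
brightest = normalize_pixel(255)  # upper end of the output range
print(darkest, brightest)
```

Every other pixel value lands between these two extremes, which is why the transform comments describe the output range as [-1, 1].
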
Note: Unlike the PyTorch DataLoader, the preprocessing can now occur on any node in the cluster.
The data is passed to training workers through the Ray object store (a distributed in-memory object store).

Apply the preprocessing function to the dataset using map():
- train_ds.map(transform_images) → runs the transform_images function on every row of the dataset.
- The resulting dataset contains:
  - "image" → normalized PyTorch tensors
  - "label" → unchanged integer labels

This makes the dataset ready to be streamed into the training loop.
# 06. Apply the preprocessing transform across the Ray Dataset
# Run transform_images() on each row (parallelized across cluster workers)
train_ds = train_ds.map(transform_images)
Now, connect the Ray Dataset to the training loop using the datasets parameter in TorchTrainer:
- datasets={"train": train_ds} → makes the transformed dataset available to the training loop as the "train" shard.
- train_loop_ray_train_ray_data → the per-worker training loop that consumes Ray Data batches.
- train_loop_config → passes hyperparameters (num_epochs, global_batch_size).
- scaling_config → specifies the number of workers and GPUs to use (same as before).
- run_config → defines storage for checkpoints and metrics.

This setup allows Ray Train to automatically shard and stream the Ray Dataset into each worker during training.
# 07. Configure TorchTrainer with Ray Data integration
# Wrap Ray Dataset in a dict → accessible as "train" inside the training loop
datasets = {"train": train_ds}
trainer = TorchTrainer(
    train_loop_ray_train_ray_data,  # training loop consuming Ray Data
    train_loop_config={  # hyperparameters
        "num_epochs": 1,
        "global_batch_size": 512,
    },
    scaling_config=scaling_config,  # number of workers + GPU/CPU resources
    run_config=RunConfig(
        storage_path=storage_path,
        name="dist-MNIST-res18-ray-data",
    ),  # where to store checkpoints/logs
    datasets=datasets,  # provide Ray Dataset shards to workers
)
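
Conceptually, passing a dataset through the datasets parameter means each worker iterates over its own portion of the data. The sketch below illustrates that idea with a plain Python list. It's not Ray Train's actual splitting algorithm, and split_into_shards is a hypothetical helper. It only shows the property that matters for the training loop: shards are disjoint and together cover the whole dataset.

```python
def split_into_shards(rows: list, num_workers: int) -> list[list]:
    """Assign rows to workers round-robin so shard sizes differ by at most one."""
    return [rows[rank::num_workers] for rank in range(num_workers)]


rows = [{"label": i % 10} for i in range(10)]
shards = split_into_shards(rows, num_workers=4)

for rank, shard in enumerate(shards):
    print(f"worker {rank}: {len(shard)} rows")
```

Inside the real training loop, ray.train.get_dataset_shard("train") plays the role of looking up the shard for the current worker.
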
Finally, call trainer.fit() to start the distributed training job. This call:
- Launches the workers specified in scaling_config.
- Runs the training loop (train_loop_ray_train_ray_data) on every worker in parallel.

With this call, you now have a fully end-to-end distributed pipeline where Ray Data handles ingestion and preprocessing, and Ray Train handles multi-GPU training.
# 08. Start the distributed training job with Ray Data integration
# Launches the training loop across all workers
# - Streams preprocessed Ray Dataset batches into each worker
# - Reports metrics and checkpoints to cluster storage
trainer.fit()
In this module you learn how Ray Train handles failures and how to make your training jobs resilient with checkpointing and recovery. You see both automatic retries and manual restoration, and how to modify the training loop so it can safely resume from the latest checkpoint.
- Modify the training loop to call get_checkpoint() to enable checkpoint loading.
- Configure FailureConfig to set retry behavior.

With fault tolerance enabled, you can run long, large-scale training jobs confidently, knowing they can recover from failures without starting over.
To support fault tolerance, extend the training loop so it can resume from a previously saved checkpoint.
Key additions:
- Call ray.train.get_checkpoint() to check if a checkpoint is available.
- If a checkpoint exists, restore:
  - The model state (model.pt)
  - The optimizer state (optimizer.pt)
  - The last completed epoch (extra_state.pt)
- Update start_epoch so training resumes from the correct place.

The rest of the loop (forward pass, backward pass, optimizer step, and metrics reporting) is the same, except it now starts from start_epoch instead of 0.
# 01. Training loop with checkpoint loading for fault tolerance
def train_loop_ray_train_with_checkpoint_loading(config: dict):
    # Same setup as before: loss, model, optimizer
    criterion = CrossEntropyLoss()
    model = load_model_ray_train()
    optimizer = Adam(model.parameters(), lr=1e-3)
    # Same data loader logic as before
    global_batch_size = config["global_batch_size"]
    batch_size = global_batch_size // ray.train.get_context().get_world_size()
    data_loader = build_data_loader_ray_train_ray_data(batch_size=batch_size)
    # Default: start at epoch 0 unless a checkpoint is available
    start_epoch = 0
    # Attempt to load from latest checkpoint
    checkpoint = ray.train.get_checkpoint()
    if checkpoint:
        # Continue training from a previous checkpoint
        with checkpoint.as_directory() as ckpt_dir:
            # Restore model + optimizer state
            model_state_dict = torch.load(
                os.path.join(ckpt_dir, "model.pt"),
            )
            # Load the model and optimizer state
            model.module.load_state_dict(model_state_dict)
            optimizer.load_state_dict(
                torch.load(os.path.join(ckpt_dir, "optimizer.pt"))
            )
            # Resume from last epoch + 1
            start_epoch = (
                torch.load(os.path.join(ckpt_dir, "extra_state.pt"))["epoch"] + 1
            )
    # Same training loop as before except it starts at a parameterized start_epoch
    for epoch in range(start_epoch, config["num_epochs"]):
        for batch in data_loader:
            outputs = model(batch["image"])
            loss = criterion(outputs, batch["label"])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Report metrics and save model + optimizer + epoch state
        metrics = print_metrics_ray_train(loss, epoch)
        # We now save the optimizer and epoch state in addition to the model
        save_checkpoint_and_metrics_ray_train_with_extra_state(
            model, metrics, optimizer, epoch
        )
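
The resume logic can be exercised without a cluster. This minimal sketch mirrors the start_epoch computation above, under stated assumptions: it stores the epoch in a JSON file instead of using torch.save, and run_epochs and the file name extra_state.json are hypothetical. It shows that a second run continues after the last completed epoch, and that a run whose epochs are all complete does no further work.

```python
import json
import os
import tempfile
from typing import Optional


def run_epochs(ckpt_dir: Optional[str], num_epochs: int, save_dir: str) -> list[int]:
    """Run the epochs that remain and return the epoch numbers actually executed."""
    start_epoch = 0
    if ckpt_dir is not None:
        with open(os.path.join(ckpt_dir, "extra_state.json")) as f:
            start_epoch = json.load(f)["epoch"] + 1  # resume after last completed epoch

    executed = []
    for epoch in range(start_epoch, num_epochs):
        executed.append(epoch)  # placeholder for the real training step
        with open(os.path.join(save_dir, "extra_state.json"), "w") as f:
            json.dump({"epoch": epoch}, f)  # record the last completed epoch
    return executed


with tempfile.TemporaryDirectory() as tmp:
    first = run_epochs(None, num_epochs=2, save_dir=tmp)    # fresh start
    resumed = run_epochs(tmp, num_epochs=4, save_dir=tmp)   # continues after epoch 1
    finished = run_epochs(tmp, num_epochs=4, save_dir=tmp)  # nothing left to do

print(first, resumed, finished)
```

Storing the last completed epoch, rather than the next epoch to run, is why the loading code adds 1 before resuming.
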
To support fault-tolerant recovery, extend checkpoint saving to include not just the model, but also the optimizer state and the current epoch.
- model.pt → model weights (unwrap DDP with .module).
- optimizer.pt → optimizer state for resuming training seamlessly.
- extra_state.pt → stores metadata (here, the current epoch).

Only the rank-0 worker writes the checkpoint to avoid duplication, but all workers still call ray.train.report() to keep the loop synchronized.
This ensures that if training is interrupted, Ray Train can restore model weights, optimizer progress, and the correct epoch before continuing.
# 02. Save checkpoint with model, optimizer, and epoch state
def save_checkpoint_and_metrics_ray_train_with_extra_state(
    model: torch.nn.Module,
    metrics: dict[str, float],
    optimizer: torch.optim.Optimizer,
    epoch: int,
) -> None:
    with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
        checkpoint = None
        # Only rank-0 worker saves files to disk
        if ray.train.get_context().get_world_rank() == 0:
            # Save all state required for full recovery
            torch.save(
                model.module.state_dict(),  # unwrap DDP before saving
                os.path.join(temp_checkpoint_dir, "model.pt"),
            )
            torch.save(
                optimizer.state_dict(),  # include optimizer state
                os.path.join(temp_checkpoint_dir, "optimizer.pt"),
            )
            torch.save(
                {"epoch": epoch},  # store last completed epoch
                os.path.join(temp_checkpoint_dir, "extra_state.pt"),
            )
            # Package into a Ray checkpoint
            checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)
        # Report metrics and attach checkpoint (only rank-0 attaches checkpoint)
        ray.train.report(
            metrics,
            checkpoint=checkpoint,
        )
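
The rank-0 pattern can be simulated locally. In this sketch, which is not Ray code, report_from_worker is a hypothetical stand-in for the function above: every simulated worker produces a report, but only rank 0 writes files. A JSON file replaces torch.save so the example has no dependencies.

```python
import json
import os
import tempfile


def report_from_worker(rank: int, epoch: int, out_dir: str) -> dict:
    """Return what a worker would report; only rank 0 attaches a checkpoint path."""
    checkpoint_path = None
    if rank == 0:
        # Only one worker persists state, which avoids duplicate writes
        with open(os.path.join(out_dir, "extra_state.json"), "w") as f:
            json.dump({"epoch": epoch}, f)
        checkpoint_path = out_dir
    # Every worker returns a report, mirroring that all workers call report()
    return {"rank": rank, "epoch": epoch, "checkpoint": checkpoint_path}


with tempfile.TemporaryDirectory() as tmp:
    reports = [report_from_worker(rank, epoch=0, out_dir=tmp) for rank in range(4)]
    written_files = sorted(os.listdir(tmp))

with_checkpoint = [r["rank"] for r in reports if r["checkpoint"] is not None]
print(with_checkpoint, written_files)
```

All four simulated workers report, yet exactly one checkpoint file exists, which is the outcome the rank check in the function above is meant to guarantee.
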
Now that the training loop can load from checkpoints, use FailureConfig to enable automatic retries in case of worker or node failures.
- FailureConfig(max_failures=3) → allows the job to retry up to three times before giving up.
- Pass failure_config into RunConfig so Ray Train knows how to handle failures.

This setup makes training jobs resilient to transient hardware or cluster issues without requiring manual intervention.
# 03. Configure TorchTrainer with fault-tolerance enabled
# Allow up to 3 automatic retries if workers fail
failure_config = ray.train.FailureConfig(max_failures=3)
experiment_name = "fault-tolerant-MNIST-vit"
trainer = TorchTrainer(
    train_loop_per_worker=train_loop_ray_train_with_checkpoint_loading,  # fault-tolerant loop
    train_loop_config={  # hyperparameters
        "num_epochs": 1,
        "global_batch_size": 512,
    },
    scaling_config=scaling_config,  # resource scaling as before
    run_config=ray.train.RunConfig(
        name=experiment_name,
        storage_path=storage_path,  # persistent checkpoint storage
        failure_config=failure_config,  # enable automatic retries
    ),
    datasets=datasets,  # Ray Dataset shard for each worker
)
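
The retry behavior described above can be sketched as a simple loop. This is an illustration of the idea, not Ray Train's implementation: run_with_retries and flaky_training are hypothetical, and the sketch assumes that max_failures counts retries after the first attempt, so max_failures=3 allows up to four attempts in total.

```python
from typing import Callable


def run_with_retries(train_fn: Callable[[int], int], max_failures: int) -> int:
    """Call train_fn until it succeeds, tolerating at most max_failures failures."""
    failures = 0
    last_epoch = -1  # stands in for the latest checkpoint; -1 means none yet
    while True:
        try:
            return train_fn(last_epoch + 1)  # resume after the last completed epoch
        except RuntimeError as err:
            failures += 1
            # The hypothetical failure carries the last epoch that was checkpointed
            last_epoch = err.args[0]
            if failures > max_failures:
                raise


attempt_starts = []


def flaky_training(start_epoch: int, num_epochs: int = 4) -> int:
    """Hypothetical job that crashes once, after epochs 0 and 1 are checkpointed."""
    attempt_starts.append(start_epoch)
    for epoch in range(start_epoch, num_epochs):
        if epoch == 2 and len(attempt_starts) == 1:
            raise RuntimeError(1)  # report epoch 1 as the last completed epoch
    return num_epochs


completed = run_with_retries(flaky_training, max_failures=3)
print(completed, attempt_starts)
```

The second attempt starts at epoch 2 instead of epoch 0, which is the benefit of combining retries with checkpoint loading.
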
Finally, call trainer.fit() to start the training job.
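With the fault-tolerant loop and FailureConfig in place, Ray Train can automatically retry after worker failures (up to max_failures times) and resume from the latest checkpoint instead of starting over.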
This makes your training job robust against transient infrastructure failures.
# 04. Start the fault-tolerant training job
# Launches training with checkpointing + automatic retries enabled
# If workers fail, Ray will reload the latest checkpoint and resume
trainer.fit()
If the maximum number of retries is reached, you can still manually restore training by creating a new TorchTrainer with the same configuration:
- Use train_loop_ray_train_with_checkpoint_loading so the loop can resume from a checkpoint.
- Provide the same run_config (name, storage path, and failure config).

Ray Train detects the latest checkpoint in the specified storage_path and resumes training from that point.
# 05. Manually restore a trainer from the last checkpoint
restored_trainer = TorchTrainer(
    train_loop_per_worker=train_loop_ray_train_with_checkpoint_loading,  # loop supports checkpoint loading
    train_loop_config={  # hyperparameters must match
        "num_epochs": 1,
        "global_batch_size": 512,
    },
    scaling_config=scaling_config,  # same resource setup as before
    run_config=ray.train.RunConfig(
        name="fault-tolerant-MNIST-vit",  # must match previous run name
        storage_path=storage_path,  # path where checkpoints are saved
        failure_config=failure_config,  # still allow retries
    ),
    datasets=datasets,  # same dataset as before
)
Calling restored_trainer.fit() continues training from the most recent checkpoint found in the specified storage path.
The returned Result object confirms that training picked up correctly and contains metrics, checkpoints, and logs.
# 06. Resume training from the last checkpoint
# Fit the restored trainer → continues from last saved epoch
# If all epochs are already complete, training ends immediately
result = restored_trainer.fit()
# Display final training results (metrics, checkpoints, etc.)
result
Finally, remove any tutorial artifacts from persistent cluster storage:
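- The downloaded MNIST dataset (/mnt/cluster_storage/MNIST).
- Training outputs and checkpoints (/mnt/cluster_storage/training).
- The Parquet file created for Ray Data (/mnt/cluster_storage/MNIST.parquet).

This keeps your shared storage clean and prevents leftover files from occupying space.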
Run this only when you're sure you no longer need the data, checkpoints, or Parquet files.
# 07. Cleanup Cluster Storage
# Paths to remove → include MNIST data, training outputs, and MNIST.parquet
paths_to_delete = [
    "/mnt/cluster_storage/MNIST",
    "/mnt/cluster_storage/training",
    "/mnt/cluster_storage/MNIST.parquet",
]
for path in paths_to_delete:
    if os.path.exists(path):
        # Handle directories vs. files
        if os.path.isdir(path):
            shutil.rmtree(path)  # recursively delete directory
        else:
            os.remove(path)  # delete single file
        print(f"Deleted: {path}")
    else:
        print(f"Not found: {path}")
You completed a full, production-style workflow with Ray Train on Anyscale, then extended it with Ray Data, and finally added fault tolerance. Here's what you accomplished across the three modules:
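Introduction to Ray Train:
- Configured a TorchTrainer using ScalingConfig and RunConfig.
- Prepared the model and data loader with prepare_model() and prepare_data_loader().
- Reported metrics and checkpoints with ray.train.report(...) (rank-0 checkpointing best practice).
- Inspected the Result object and served GPU inference with a Ray actor.

Ray Train with Ray Data:
- Built a data loader with iter_torch_batches() and consumed dict batches in the training loop.
- Passed the Ray Dataset to the trainer with datasets={"train": ...}.

Fault tolerance:
- Resumed from checkpoints with ray.train.get_checkpoint().
- Configured FailureConfig(max_failures=...) for automatic retries.
- Restored training manually by reusing the same RunConfig.

Next steps:
- Increase num_workers, try multi-node clusters, or switch to FSDP with prepare_model(parallel_strategy="fsdp").
- Report additional metrics with ray.train.report().

In the next tutorials, you can find end-to-end workload examples for using Ray Train on Anyscale (for example, recommendation systems, vision, NLP, generative models).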
You only need to pick one of these workloads to work through in the course—but feel free to explore more.
With these patterns—distributed training, scalable data ingestion, and resilient recovery—you're ready to run larger, longer, and more reliable training jobs on Anyscale.