.. _working_with_pytorch:

Working with PyTorch
====================

Ray Data integrates with the PyTorch ecosystem.

This guide describes how to:
* :ref:`Iterate over your dataset as Torch tensors for model training <iterating_pytorch>`
* :ref:`Write transformations that deal with Torch tensors <transform_pytorch>`
* :ref:`Perform batch inference with Torch models <batch_inference_pytorch>`
* :ref:`Save Datasets containing Torch tensors <saving_pytorch>`
* :ref:`Migrate from PyTorch Datasets to Ray Data <migrate_pytorch>`

.. _iterating_pytorch:

Iterating over Torch tensors for training
-----------------------------------------
To iterate over batches of data in Torch format, call :meth:`Dataset.iter_torch_batches() <ray.data.Dataset.iter_torch_batches>`. Each batch is represented as ``Dict[str, torch.Tensor]``, with one tensor per column in the dataset.

This is useful for training Torch models with batches from your dataset. For configuration details, such as providing a ``collate_fn`` to customize the conversion, see the API reference for :meth:`iter_torch_batches() <ray.data.Dataset.iter_torch_batches>`.
.. testcode::

    import ray
    import torch

    ds = ray.data.read_images("s3://anonymous@ray-example-data/image-datasets/simple")

    for batch in ds.iter_torch_batches(batch_size=2):
        print(batch)

.. testoutput::
    :options: +MOCK

    {'image': tensor([[[[...]]]], dtype=torch.uint8)}
    ...
    {'image': tensor([[[[...]]]], dtype=torch.uint8)}
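A custom ``collate_fn`` controls how each batch dict is converted to tensors. The following is a minimal sketch with a hypothetical ``stack_columns`` helper (not part of Ray Data) that stacks all columns into one 2-D tensor; you would pass it as ``ds.iter_torch_batches(batch_size=2, collate_fn=stack_columns)``:

```python
import numpy as np
import torch

# Hypothetical collate_fn: Ray Data hands it each batch as Dict[str, np.ndarray];
# this one stacks every column into a single (batch_size, num_columns) tensor.
def stack_columns(batch: dict) -> torch.Tensor:
    return torch.stack(
        [torch.as_tensor(batch[col], dtype=torch.float32) for col in sorted(batch)],
        dim=1,
    )

# Standalone check with a fake two-column batch.
batch = {"a": np.array([1.0, 2.0]), "b": np.array([3.0, 4.0])}
print(stack_columns(batch).shape)  # torch.Size([2, 2])
```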
Integration with Ray Train
~~~~~~~~~~~~~~~~~~~~~~~~~~
Ray Data integrates with :ref:`Ray Train <train-docs>` for easy data ingest for data parallel training, with support for PyTorch, PyTorch Lightning, or Hugging Face training.
.. testcode::

    import torch
    from torch import nn

    import ray
    from ray import train
    from ray.train import ScalingConfig
    from ray.train.torch import TorchTrainer


    def train_func():
        model = nn.Sequential(nn.Linear(30, 1), nn.Sigmoid())
        loss_fn = torch.nn.BCELoss()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

        # Datasets can be accessed in your train_func via ``get_dataset_shard``.
        train_data_shard = train.get_dataset_shard("train")

        for epoch_idx in range(2):
            for batch in train_data_shard.iter_torch_batches(batch_size=128, dtypes=torch.float32):
                features = torch.stack([batch[col_name] for col_name in batch.keys() if col_name != "target"], axis=1)
                predictions = model(features)
                train_loss = loss_fn(predictions, batch["target"].unsqueeze(1))
                # Reset gradients before backpropagating; otherwise they accumulate
                # across iterations.
                optimizer.zero_grad()
                train_loss.backward()
                optimizer.step()


    train_dataset = ray.data.read_csv("s3://anonymous@air-example-data/breast_cancer.csv")

    trainer = TorchTrainer(
        train_func,
        datasets={"train": train_dataset},
        scaling_config=ScalingConfig(num_workers=2),
    )
    trainer.fit()
For more details, see the :ref:`Ray Train user guide <data-ingest-torch>`.
.. _transform_pytorch:
Transformations with Torch tensors
----------------------------------
Transformations applied with :meth:`map <ray.data.Dataset.map>` or :meth:`map_batches <ray.data.Dataset.map_batches>` can return Torch tensors.
.. caution::
    Under the hood, Ray Data automatically converts Torch tensors to NumPy arrays. Subsequent transformations accept NumPy arrays as input, not Torch tensors.
.. tab-set::

    .. tab-item:: map

        .. testcode::

            from typing import Dict
            import numpy as np
            import torch
            import ray

            ds = ray.data.read_images("s3://anonymous@ray-example-data/image-datasets/simple")

            def convert_to_torch(row: Dict[str, np.ndarray]) -> Dict[str, torch.Tensor]:
                return {"tensor": torch.as_tensor(row["image"])}

            # The tensor gets converted into a NumPy array under the hood.
            transformed_ds = ds.map(convert_to_torch)
            print(transformed_ds.schema())

            # Subsequent transformations take in a NumPy array as input.
            def check_numpy(row: Dict[str, np.ndarray]):
                assert isinstance(row["tensor"], np.ndarray)
                return row

            transformed_ds.map(check_numpy).take_all()

        .. testoutput::

            Column  Type
            ------  ----
            tensor  ArrowTensorTypeV2(shape=(32, 32, 3), dtype=uint8)

    .. tab-item:: map_batches

        .. testcode::

            from typing import Dict
            import numpy as np
            import torch
            import ray

            ds = ray.data.read_images("s3://anonymous@ray-example-data/image-datasets/simple")

            def convert_to_torch(batch: Dict[str, np.ndarray]) -> Dict[str, torch.Tensor]:
                return {"tensor": torch.as_tensor(batch["image"])}

            # The tensor gets converted into a NumPy array under the hood.
            transformed_ds = ds.map_batches(convert_to_torch, batch_size=2)
            print(transformed_ds.schema())

            # Subsequent transformations take in a NumPy array as input.
            def check_numpy(batch: Dict[str, np.ndarray]):
                assert isinstance(batch["tensor"], np.ndarray)
                return batch

            transformed_ds.map_batches(check_numpy, batch_size=2).take_all()

        .. testoutput::

            Column  Type
            ------  ----
            tensor  ArrowTensorTypeV2(shape=(32, 32, 3), dtype=uint8)
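The conversion between Torch tensors and NumPy arrays is typically cheap: for CPU tensors, ``torch.as_tensor`` and ``Tensor.numpy()`` share memory with the underlying array rather than copying. A small standalone illustration (not Ray-specific):

```python
import numpy as np
import torch

arr = np.zeros((2, 3), dtype=np.float32)
t = torch.as_tensor(arr)  # no copy: the tensor wraps the NumPy buffer

t[0, 0] = 1.0      # writing through the tensor...
print(arr[0, 0])   # 1.0 -- ...is visible through the NumPy array
```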
For more information on transforming data, see :ref:`Transforming data <transforming_data>`.
Built-in PyTorch transforms
~~~~~~~~~~~~~~~~~~~~~~~~~~~

You can use built-in Torch transforms from ``torchvision``, ``torchtext``, and ``torchaudio``.
.. tab-set::

    .. tab-item:: torchvision

        .. testcode::

            from typing import Dict
            import numpy as np
            import torch
            from torchvision import transforms
            import ray

            # Create the Dataset.
            ds = ray.data.read_images("s3://anonymous@ray-example-data/image-datasets/simple")

            # Define the torchvision transform.
            transform = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.CenterCrop(10)
                ]
            )

            # Define the map function.
            def transform_image(row: Dict[str, np.ndarray]) -> Dict[str, torch.Tensor]:
                row["transformed_image"] = transform(row["image"])
                return row

            # Apply the transform over the dataset.
            transformed_ds = ds.map(transform_image)
            print(transformed_ds.schema())

        .. testoutput::

            Column             Type
            ------             ----
            image              ArrowTensorTypeV2(shape=(32, 32, 3), dtype=uint8)
            transformed_image  ArrowTensorTypeV2(shape=(3, 10, 10), dtype=float)

    .. tab-item:: torchtext

        .. testcode::
            :skipif: True

            from typing import Dict, List
            import numpy as np
            from torchtext import transforms
            import ray

            # Create the Dataset.
            ds = ray.data.read_text("s3://anonymous@ray-example-data/simple.txt")

            # Define the torchtext transform.
            VOCAB_FILE = "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt"
            transform = transforms.BERTTokenizer(vocab_path=VOCAB_FILE, do_lower_case=True, return_tokens=True)

            # Define the map_batches function.
            def tokenize_text(batch: Dict[str, np.ndarray]) -> Dict[str, List[str]]:
                batch["tokenized_text"] = transform(list(batch["text"]))
                return batch

            # Apply the transform over the dataset.
            transformed_ds = ds.map_batches(tokenize_text, batch_size=2)
            print(transformed_ds.schema())

        .. testoutput::

            Column          Type
            ------          ----
            text            string
            tokenized_text  list<item: string>
.. _batch_inference_pytorch:

Batch inference with PyTorch
----------------------------

With Ray Datasets, you can do scalable offline batch inference with Torch models by mapping a pre-trained model over your data.
.. testcode::

    from typing import Dict
    import numpy as np
    import torch
    import torch.nn as nn

    import ray

    # Step 1: Create a Ray Dataset from in-memory NumPy arrays.
    # You can also create a Ray Dataset from many other sources and file
    # formats.
    ds = ray.data.from_numpy(np.ones((1, 100)))

    # Step 2: Define a Predictor class for inference.
    # Use a class to initialize the model just once in `__init__`
    # and reuse it for inference across multiple batches.
    class TorchPredictor:
        def __init__(self):
            # Load a dummy neural network.
            # Set `self.model` to your pre-trained PyTorch model.
            self.model = nn.Sequential(
                nn.Linear(in_features=100, out_features=1),
                nn.Sigmoid(),
            )
            self.model.eval()

        # Logic for inference on 1 batch of data.
        def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
            tensor = torch.as_tensor(batch["data"], dtype=torch.float32)
            with torch.inference_mode():
                # Get the predictions from the input batch.
                return {"output": self.model(tensor).numpy()}

    # Step 3: Map the Predictor over the Dataset to get predictions.
    # Use 2 parallel actors for inference. Each actor predicts on a
    # different partition of data.
    predictions = ds.map_batches(TorchPredictor, compute=ray.data.ActorPoolStrategy(size=2))

    # Step 4: Show one prediction output.
    predictions.show(limit=1)

.. testoutput::
    :options: +MOCK

    {'output': array([0.5590901], dtype=float32)}
For more details, see the :ref:`Batch inference user guide <batch_inference_home>`.
.. _saving_pytorch:

Saving Datasets containing Torch tensors
----------------------------------------

Datasets containing Torch tensors can be saved to files, like Parquet or NumPy.

For more information on saving data, see :ref:`Saving data <saving-data>`.
.. caution::
    Torch tensors that are on GPU devices can't be serialized and written to disk. Convert the tensors to CPU (``tensor.to("cpu")``) before saving the data.
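A defensive pattern for the caution above: route tensors through a small helper (hypothetical, named ``to_host`` here) that moves them to host memory before handing them off; ``.to("cpu")`` is a no-op for tensors already on the CPU.

```python
import torch

def to_host(tensor: torch.Tensor):
    # A no-op for CPU tensors; copies GPU tensors to host memory so the
    # result can be serialized and written to disk.
    return tensor.to("cpu").numpy()

arr = to_host(torch.ones(2, 3))
print(arr.shape)  # (2, 3)
```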
.. tab-set::

    .. tab-item:: Parquet

        .. testcode::

            import torch
            import ray

            tensor = torch.Tensor(1)
            ds = ray.data.from_items([{"tensor": tensor}])

            ds.write_parquet("local:///tmp/tensor")

    .. tab-item:: Numpy

        .. testcode::

            import torch
            import ray

            tensor = torch.Tensor(1)
            ds = ray.data.from_items([{"tensor": tensor}])

            ds.write_numpy("local:///tmp/tensor", column="tensor")
.. _migrate_pytorch:

Migrating from PyTorch Datasets and DataLoaders
-----------------------------------------------

If you're currently using PyTorch Datasets and DataLoaders, you can migrate to Ray Data for working with distributed datasets.

PyTorch Datasets are replaced by the :class:`Dataset <ray.data.Dataset>` abstraction, and the PyTorch DataLoader is replaced by :meth:`Dataset.iter_torch_batches() <ray.data.Dataset.iter_torch_batches>`.
Built-in PyTorch Datasets
~~~~~~~~~~~~~~~~~~~~~~~~~

If you're using built-in PyTorch datasets, for example from ``torchvision``, you can convert them to a Ray Dataset with the :meth:`from_torch() <ray.data.from_torch>` API.
.. testcode::
    :skipif: True

    import torchvision

    import ray

    mnist = torchvision.datasets.MNIST(root="/tmp/", download=True)
    ds = ray.data.from_torch(mnist)

    # The data for each item of the Torch dataset is under the "item" key.
    print(ds.schema())

..
    The following `testoutput` is mocked to avoid illustrating download logs like
    "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz".

.. testoutput::
    :options: +MOCK

    Column  Type
    ------  ----
    item    <class 'object'>
Custom PyTorch Datasets
~~~~~~~~~~~~~~~~~~~~~~~
If you have a custom PyTorch Dataset, you can migrate to Ray Data by converting the logic in ``__getitem__`` to Ray Data read and transform operations.

Any logic for reading data from cloud storage and disk can be replaced by one of the Ray Data ``read_*`` APIs, and any transformation logic can be applied as a :meth:`map <ray.data.Dataset.map>` call on the Dataset.

The following example shows a custom PyTorch Dataset and what the analogous version looks like with Ray Data.
.. note::
    Unlike PyTorch Map-style datasets, Ray Datasets aren't indexable.
.. tab-set::

    .. tab-item:: PyTorch Dataset

        .. testcode::

            import tempfile

            import boto3
            from botocore import UNSIGNED
            from botocore.config import Config
            from torchvision import transforms
            from torch.utils.data import Dataset
            from PIL import Image

            class ImageDataset(Dataset):
                def __init__(self, bucket_name: str, dir_path: str):
                    self.s3 = boto3.resource("s3", config=Config(signature_version=UNSIGNED))
                    self.bucket = self.s3.Bucket(bucket_name)
                    self.files = [obj.key for obj in self.bucket.objects.filter(Prefix=dir_path)]

                    self.transform = transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Resize((128, 128)),
                        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                    ])

                def __len__(self):
                    return len(self.files)

                def __getitem__(self, idx):
                    img_name = self.files[idx]

                    # Infer the label from the file name.
                    last_slash_idx = img_name.rfind("/")
                    dot_idx = img_name.rfind(".")
                    label = int(img_name[last_slash_idx + 1:dot_idx])

                    # Download the S3 file locally.
                    obj = self.bucket.Object(img_name)
                    tmp = tempfile.NamedTemporaryFile()
                    tmp_name = "{}.jpg".format(tmp.name)

                    with open(tmp_name, "wb") as f:
                        obj.download_fileobj(f)
                        f.flush()
                    image = Image.open(tmp_name)

                    # Preprocess the image.
                    image = self.transform(image)

                    return image, label

            dataset = ImageDataset(bucket_name="ray-example-data", dir_path="batoidea/JPEGImages/")

    .. tab-item:: Ray Data

        .. testcode::

            from torchvision import transforms

            import ray

            ds = ray.data.read_images("s3://anonymous@ray-example-data/batoidea/JPEGImages", include_paths=True)

            # Extract the label from the file path.
            def extract_label(row: dict):
                filepath = row["path"]
                last_slash_idx = filepath.rfind("/")
                dot_idx = filepath.rfind(".")
                label = int(filepath[last_slash_idx + 1:dot_idx])
                row["label"] = label
                return row

            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Resize((128, 128)),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])

            # Preprocess the images.
            def transform_image(row: dict):
                row["transformed_image"] = transform(row["image"])
                return row

            # Map the transformations over the dataset.
            ds = ds.map(extract_label).map(transform_image)
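The label-parsing step above depends only on string handling, so you can sanity-check it without any cloud access. A quick standalone check with a hypothetical path, using the same parsing as ``extract_label``:

```python
def label_from_path(filepath: str) -> int:
    # The file's base name (minus the extension) is the integer label.
    last_slash_idx = filepath.rfind("/")
    dot_idx = filepath.rfind(".")
    return int(filepath[last_slash_idx + 1:dot_idx])

print(label_from_path("batoidea/JPEGImages/3.jpeg"))  # 3
```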
PyTorch DataLoader
~~~~~~~~~~~~~~~~~~
The PyTorch DataLoader can be replaced by calling :meth:`Dataset.iter_torch_batches() <ray.data.Dataset.iter_torch_batches>` to iterate over batches of the dataset.
The following table describes how the arguments to the PyTorch DataLoader map to Ray Data. Note that the behavior isn't necessarily identical. For exact semantics and usage, see the API reference for :meth:`iter_torch_batches() <ray.data.Dataset.iter_torch_batches>`.
.. list-table::
    :header-rows: 1

    * - PyTorch DataLoader arguments
      - Ray Data API
    * - ``batch_size``
      - ``batch_size`` argument to :meth:`ds.iter_torch_batches() <ray.data.Dataset.iter_torch_batches>`
    * - ``shuffle``
      - ``local_shuffle_buffer_size`` argument to :meth:`ds.iter_torch_batches() <ray.data.Dataset.iter_torch_batches>`
    * - ``collate_fn``
      - ``collate_fn`` argument to :meth:`ds.iter_torch_batches() <ray.data.Dataset.iter_torch_batches>`
    * - ``sampler``
      - Not supported. Can be manually implemented after iterating through the dataset with :meth:`ds.iter_torch_batches() <ray.data.Dataset.iter_torch_batches>`.
    * - ``batch_sampler``
      - Not supported. Can be manually implemented after iterating through the dataset with :meth:`ds.iter_torch_batches() <ray.data.Dataset.iter_torch_batches>`.
    * - ``drop_last``
      - ``drop_last`` argument to :meth:`ds.iter_torch_batches() <ray.data.Dataset.iter_torch_batches>`
    * - ``num_workers``
      - Use the ``prefetch_batches`` argument to :meth:`ds.iter_torch_batches() <ray.data.Dataset.iter_torch_batches>` to indicate how many batches to prefetch. The number of prefetching threads is automatically configured according to ``prefetch_batches``.
    * - ``prefetch_factor``
      - Use the ``prefetch_batches`` argument to :meth:`ds.iter_torch_batches() <ray.data.Dataset.iter_torch_batches>` to indicate how many batches to prefetch. The number of prefetching threads is automatically configured according to ``prefetch_batches``.
    * - ``pin_memory``
      - Pass in ``device`` to :meth:`ds.iter_torch_batches() <ray.data.Dataset.iter_torch_batches>` to get tensors that have already been moved to the correct device.