Distributed ML model batch inference on data in DeltaLake

tutorials/delta_lake/2-distributed-batch-inference.ipynb

In this tutorial, we show how to run ML model batch inference on data in a Delta Lake table, distributed across a Ray cluster.

This is a continuation of the previous tutorial on local batch inference. Running locally is a great way to get started and to confirm that your code works before scaling up to a distributed batch inference workload, so make sure to read that tutorial first!

To run this tutorial, you need AWS credentials configured on your machine, because all data is hosted in a requester-pays bucket in AWS S3.

Let's get started!

python
CI = False
python
# Skip this notebook execution in CI because it reads from a requester-pays bucket that needs AWS credentials
if CI:
    import sys

    sys.exit()

Going Distributed

The first step (and the most important one for this demo!) is to switch our Daft runner to the Ray runner and point it at a Ray cluster. This is super simple:

python
import daft

# If you have your own Ray cluster running, feel free to set this to that address!
# RAY_ADDRESS = "ray://localhost:10001"
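# With None, Ray connects to an existing Ray instance if it can discover one (e.g. via the RAY_ADDRESS
# environment variable), otherwise it starts a new local one
RAY_ADDRESS = None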

daft.set_runner_ray(address=RAY_ADDRESS)

Now, we run the same operations as before. The only difference is that instead of execution happening locally on the machine that's running this code, Daft will distribute the computation over your Ray cluster!

python
# Feel free to tweak this variable to have the tutorial run on as many rows as you'd like!
NUM_ROWS = 1000

Retrieving data

We retrieve the data exactly as in the previous tutorial, using the same API and arguments.

python
# Provision Cloud Credentials
import boto3

import daft

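session = boto3.session.Session()
creds = session.get_credentials()
if creds is None:
    raise RuntimeError("No AWS credentials found; configure them (e.g. via `aws configure`) before running this tutorial")

io_config = daft.io.IOConfig(
    s3=daft.io.S3Config(
        key_id=creds.access_key,  # AWS access key ID
        access_key=creds.secret_key,  # AWS secret access key
        session_token=creds.token,
        region_name="us-west-2",
    )
)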

# Retrieve data
df = daft.read_deltalake("s3://daft-public-datasets/imagenet/val-10k-sample-deltalake/", io_config=io_config)

# Prune data: take the first NUM_ROWS rows, then keep only rows with exactly one entry in "object"
df = df.limit(NUM_ROWS)
df = df.where(df["object"].list.length() == 1)
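
Because `limit` runs before `where` in the cell above, the final row count can be smaller than `NUM_ROWS`. The pure-Python sketch below mimics both orderings on a hypothetical list of rows; the `rows` data and the `has_one_object`, `limit_then_filter`, and `filter_then_limit` helpers are illustrative assumptions, not Daft APIs.

```python
# Illustrative only: plain Python lists stand in for a Daft DataFrame.
from itertools import islice


def has_one_object(row: dict) -> bool:
    # Mirrors df["object"].list.length() == 1 from the cell above
    return len(row["object"]) == 1


def limit_then_filter(rows, n):
    # Order used in the tutorial: take the first n rows, then filter them
    return [r for r in islice(rows, n) if has_one_object(r)]


def filter_then_limit(rows, n):
    # Alternative order: filter first, then take up to n matching rows
    return list(islice((r for r in rows if has_one_object(r)), n))


# Hypothetical rows: every third row is annotated with two objects
rows = [{"object": ["a", "b"] if i % 3 == 0 else ["a"]} for i in range(10)]
print(len(limit_then_filter(rows, 6)))  # 4
print(len(filter_then_limit(rows, 6)))  # 6
```

If you want up to `NUM_ROWS` matching rows rather than a filtered subset of the first `NUM_ROWS`, applying the `where` before the `limit` gives that behavior.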

Splitting the data into more partitions

We now split the data into more partitions, giving Ray more independent units of work to run in parallel across the cluster.

python
df = df.into_partitions(16)
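
Conceptually, each partition is an independent chunk of rows that the Ray runner can hand to a separate worker. The toy sketch below imitates that idea with Python's `concurrent.futures`: it splits rows into near-equal chunks and processes them concurrently. The `split_into_partitions` and `process_partition` helpers are hypothetical and are not how Daft or Ray actually split or schedule work.

```python
from concurrent.futures import ThreadPoolExecutor


def split_into_partitions(rows, num_partitions):
    # Split rows into num_partitions contiguous chunks whose sizes differ by at most one
    base, extra = divmod(len(rows), num_partitions)
    partitions, start = [], 0
    for i in range(num_partitions):
        size = base + (1 if i < extra else 0)
        partitions.append(rows[start:start + size])
        start += size
    return partitions


def process_partition(partition):
    # Stand-in for per-partition work such as downloading and decoding images
    return [x * 2 for x in partition]


rows = list(range(100))
partitions = split_into_partitions(rows, 16)

# Each partition is processed independently, so the chunks can run concurrently;
# pool.map returns results in partition order
with ThreadPoolExecutor(max_workers=4) as pool:
    results = [r for part in pool.map(process_partition, partitions) for r in part]

print([len(p) for p in partitions][:5])  # [7, 7, 7, 7, 6]
print(results == [x * 2 for x in rows])  # True
```

Having more partitions than workers is fine: extra partitions simply wait for a free worker, which is why a partition count at or above the number of available workers tends to keep a cluster busy.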

Retrieving the images and preprocessing

Now we continue with exactly the same code as in the local case to retrieve and preprocess our images.

python
# Retrieve images and run preprocessing
df = df.with_column(
    "image_url", "s3://daft-public-datasets/imagenet/val-10k-sample-deltalake/images/" + df["filename"] + ".jpeg"
)
df = df.with_column("image", df["image_url"].download().decode_image())
df = df.with_column("image_resized_small", df["image"].resize(32, 32))
df = df.with_column("image_resized_large", df["image"].resize(256, 256))
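
Resizing matters for the inference step: the UDF stacks each batch of images into a single tensor, which only works when every image has the same height, width, and channel count. The NumPy sketch below uses random arrays as stand-ins for decoded images (an assumption for illustration) to show that fixed-size `(32, 32, 3)` images stack into an NHWC batch, that transposing with axes `(0, 3, 1, 2)` gives the NCHW layout the UDF produces with `permute`, and that images of mixed sizes cannot be stacked.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-ins for three images already resized to 32x32 with 3 (RGB) channels
images = [rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8) for _ in range(3)]

# Stack into a single NHWC batch: (batch, height, width, channels)
batch_nhwc = np.stack(images)

# Reorder to NCHW (batch, channels, height, width), like permute((0, 3, 1, 2)) in the UDF
batch_nchw = batch_nhwc.transpose(0, 3, 1, 2)
print(batch_nhwc.shape, batch_nchw.shape)  # (3, 32, 32, 3) (3, 3, 32, 32)

# Images of different sizes (or channel counts) cannot be stacked into one array
mixed = [np.zeros((32, 32, 3), np.uint8), np.zeros((256, 256, 3), np.uint8)]
try:
    np.stack(mixed)
except ValueError as err:
    print("cannot stack:", err)
```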

Running batch inference with a UDF

Running the UDF is also exactly the same!

python
# Run batch inference over the entire dataset
import numpy as np
import torch
from torchvision.models import ResNet50_Weights, resnet50

import daft


@daft.udf(return_dtype=daft.DataType.string())
class ClassifyImage:
    def __init__(self):
        weights = ResNet50_Weights.DEFAULT
        self.model = resnet50(weights=weights)
        self.model.eval()
        self.preprocess = weights.transforms()
        self.category_map = weights.meta["categories"]

    def __call__(self, images: daft.Series, shape: list[int]):
        if len(images) == 0:
            return []

        # Convert the Daft Series into a list of NumPy arrays with the given (H, W, C) shape
        data = images.cast(daft.DataType.tensor(daft.DataType.uint8(), tuple(shape))).to_pylist()

        # Stack the arrays into a torch tensor and reorder from NHWC to the NCHW layout the model expects
        images_array = torch.tensor(np.array(data)).permute((0, 3, 1, 2))

        # Run the model without tracking gradients; softmax over dim=1 normalizes across classes
        with torch.inference_mode():
            batch = self.preprocess(images_array)
            prediction = self.model(batch).softmax(dim=1)

        # Map each row's highest-scoring class ID back to a human-readable label
        class_ids = prediction.argmax(dim=1).tolist()
        return [self.category_map[class_id] for class_id in class_ids]


# Keep only 3-channel (RGB) images, since the UDF casts each batch to a fixed (H, W, 3) tensor shape
df = df.where(df["image"].apply(lambda img: img.shape[2] == 3, return_dtype=daft.DataType.bool()))

df = df.with_column("predictions_lowres", ClassifyImage(df["image_resized_small"], [32, 32, 3]))
df = df.with_column("predictions_highres", ClassifyImage(df["image_resized_large"], [256, 256, 3]))

# Keep only the result columns and write them out as Parquet files under the my_results.parquet directory
df = df.select(
    "filename",
    "image_url",
    "object",
    "predictions_lowres",
    "predictions_highres",
)
df.write_parquet("my_results.parquet")
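
A note on the softmax axis in the UDF: ResNet-50 returns logits of shape `(batch, num_classes)`, so the softmax has to normalize across classes within each row (`dim=1`). Softmax along the class axis never changes which class scores highest, whereas normalizing down the batch axis (`dim=0`) can. The pure-Python sketch below demonstrates this on made-up logits for two images and three classes; the `softmax_rows`, `softmax_cols`, and `argmax_rows` helpers are illustrative stand-ins, not PyTorch functions.

```python
import math


def softmax(values):
    # Numerically stable softmax over a flat list of numbers
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_rows(matrix):
    # Normalize each row: the analogue of tensor.softmax(dim=1) on (batch, classes)
    return [softmax(row) for row in matrix]


def softmax_cols(matrix):
    # Normalize each column: the analogue of tensor.softmax(dim=0)
    cols = [softmax(list(col)) for col in zip(*matrix)]
    return [list(row) for row in zip(*cols)]


def argmax_rows(matrix):
    # Index of the largest value in each row, like tensor.argmax(dim=1)
    return [max(range(len(row)), key=row.__getitem__) for row in matrix]


# Made-up logits for 2 images x 3 classes
logits = [[5.0, 4.0, 0.0],
          [9.0, 1.0, 0.0]]

print(argmax_rows(logits))                # [0, 0]
print(argmax_rows(softmax_rows(logits)))  # [0, 0]
print(argmax_rows(softmax_cols(logits)))  # [1, 0]
```

The first image's prediction flips from class 0 to class 1 under the column-wise softmax, because its class-0 score is normalized against the other image's much larger class-0 logit.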

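Now, take a look at your handiwork by reading back the results of our distributed Daft job! If `RAY_ADDRESS` points at a remote cluster, note that the Parquet files are written by the Ray workers, so consider writing to shared storage such as an S3 path instead of a local directory.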

python
daft.read_parquet("my_results.parquet").collect()
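
One quick way to inspect the output is to measure how often the low- and high-resolution predictions agree. A minimal sketch, assuming you first pull the two prediction columns into Python lists (for example with `DataFrame.to_pydict()`); the `agreement_rate` helper and the sample labels are hypothetical.

```python
def agreement_rate(preds_a: list[str], preds_b: list[str]) -> float:
    # Fraction of rows where the two prediction columns contain the same label
    if len(preds_a) != len(preds_b):
        raise ValueError("prediction columns must have the same length")
    if not preds_a:
        return 0.0
    matches = sum(a == b for a, b in zip(preds_a, preds_b))
    return matches / len(preds_a)


# Hypothetical labels; in practice these could come from the results DataFrame, e.g.
# results = daft.read_parquet("my_results.parquet").to_pydict()
# agreement_rate(results["predictions_lowres"], results["predictions_highres"])
low = ["goldfish", "tabby", "tabby", "pug"]
high = ["goldfish", "tiger cat", "tabby", "pug"]
print(agreement_rate(low, high))  # 0.75
```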