Back to Daft

Spark

benchmarking/ai/image_classification/spark.ipynb

0.7.10 — 3.1 KB
Original Source
python
import io

import numpy as np
import pandas as pd
import torch
from PIL import Image
from torchvision import transforms
from torchvision.models import resnet18, ResNet18_Weights
from pyspark.sql.functions import col, pandas_udf
from pyspark.sql.types import StringType, ArrayType, FloatType
python
%%configure -f
{
  "conf": {
    "spark.sql.execution.arrow.maxRecordsPerBatch": "100"
  }
}
python
# Process-level cache so each Spark executor loads the model exactly once.
_model_cache = {"model": None, "weights": None, "device": None}


def get_model():
    """Lazily load and cache a pretrained ResNet-18 for inference.

    Returns:
        tuple: (model, weights, device) — the eval-mode model, its
        ``ResNet18_Weights`` metadata (used for category names), and the
        device string the model lives on.
    """
    if _model_cache["model"] is None:
        import os

        # Redirect torch/torchvision caches to a writable path on the
        # executor (the default home dir may be read-only on the cluster).
        os.environ["TORCH_HOME"] = "/tmp/torch"
        os.environ["XDG_CACHE_HOME"] = "/tmp"

        # Fix: the original hard-coded "cuda", which raises on CPU-only
        # executors. Fall back to CPU when no GPU is visible.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        weights = ResNet18_Weights.DEFAULT
        model = resnet18(weights=weights).eval().to(device)
        _model_cache["model"] = model
        _model_cache["weights"] = weights
        _model_cache["device"] = device
    return _model_cache["model"], _model_cache["weights"], _model_cache["device"]
python
# Preprocessing pipeline applied per decoded image (HWC uint8 ndarray):
#   1. ToTensor: HWC uint8 -> CHW float32 in [0, 1]
#   2. The ResNet-18 weights' inference preset (resize/crop/normalize).
# NOTE(review): the weights preset accepts tensor input, so chaining after
# ToTensor works, but verify the preset's expected value range — it may
# assume uint8 input for its own conversion. TODO confirm against the
# torchvision ImageClassification preset docs.
transform = transforms.Compose([
    transforms.ToTensor(),
    ResNet18_Weights.DEFAULT.transforms()
])

@pandas_udf(ArrayType(FloatType()))
def decode_and_preprocess_image_udf(image_data_series: pd.Series) -> pd.Series:
    """Decode raw image bytes into flattened, preprocessed float arrays.

    Each element of the input series is expected to be the raw bytes of an
    image (or None). Decodable images are converted to RGB, run through the
    module-level ``transform`` pipeline, and flattened to a Python list of
    floats; undecodable entries (or None inputs) map to None.
    """

    def _decode_one(raw_bytes):
        # Nulls pass through so the caller can filter them out downstream.
        if raw_bytes is None:
            return None
        try:
            rgb = np.array(Image.open(io.BytesIO(raw_bytes)).convert("RGB"))
            # A successful RGB decode should always be rank-3 (H, W, C).
            if rgb.ndim != 3:
                raise ValueError(f"Invalid image shape: {rgb.shape}")
            return transform(rgb).flatten().tolist()
        except Exception as exc:
            # Best-effort: log and drop the bad row rather than fail the batch.
            print(f"Error decoding image: {exc}")
            return None

    return pd.Series([_decode_one(raw) for raw in image_data_series])

@pandas_udf(StringType())
def predict_batch_udf(norm_images: pd.Series):
    """Classify a batch of preprocessed images with cached ResNet-18.

    Each element of ``norm_images`` is a flattened 3x224x224 float image
    (as produced by ``decode_and_preprocess_image_udf``). Returns a series
    of predicted ImageNet category names, or all-None if the batch cannot
    be reshaped into the expected tensor layout.
    """
    model, weights, device = get_model()

    try:
        stacked = np.vstack(norm_images.tolist())
        batch = stacked.reshape(-1, 3, 224, 224).astype(np.float32)
    except ValueError as err:
        # Best-effort: a malformed batch yields nulls instead of a job failure.
        print(f"Error reshaping tensor: {err}")
        return pd.Series([None] * len(norm_images))

    inputs = torch.from_numpy(batch).to(device)
    with torch.inference_mode():
        logits = model(inputs)
        class_ids = logits.argmax(dim=1).detach().cpu()
        # Map class indices to human-readable ImageNet category names.
        labels = [weights.meta["categories"][idx] for idx in class_ids]

    return pd.Series(labels)
python
# Collect the list of image file URLs from the benchmark manifest parquet.
manifest_rows = spark.read.parquet("s3://daft-public-datasets/imagenet/benchmark").collect()
image_paths = [row.image_url for row in manifest_rows]

# Load raw image bytes, decode/preprocess, classify, and keep path + label.
raw_images = spark.read.format("binaryFile").load(image_paths)
labeled = raw_images.withColumn(
    "processed_image", decode_and_preprocess_image_udf(col("content"))
)
labeled = labeled.filter(col("processed_image").isNotNull())
labeled = labeled.withColumn("label", predict_batch_udf(col("processed_image")))
labeled = labeled.select("path", "label")

# Append results so repeated benchmark runs accumulate in one location.
labeled.write.mode("append").parquet(
    "s3://eventual-dev-benchmarking-results/ai-benchmark-results/image-classification-results"
)