benchmarking/ai/image_classification/spark.ipynb
import io
import numpy as np
import pandas as pd
import torch
from PIL import Image
from torchvision import transforms
from torchvision.models import resnet18, ResNet18_Weights
from pyspark.sql.functions import col, pandas_udf
from pyspark.sql.types import StringType, ArrayType, FloatType
%%configure -f
{
"conf": {
"spark.sql.execution.arrow.maxRecordsPerBatch": "100"
}
}
# Per-worker singleton cache: each Spark Python worker process loads the model
# once and reuses it across UDF batches instead of reloading per batch.
_model_cache = {"model": None, "weights": None, "device": None}


def get_model():
    """Return the cached ``(model, weights, device)`` triple, loading ResNet-18 lazily.

    On first call, downloads the default pretrained ResNet-18 weights, puts the
    model in eval mode, moves it to the selected device, and caches all three
    values in the module-level ``_model_cache``. Subsequent calls return the
    cached objects.

    Returns:
        tuple: ``(torch.nn.Module, ResNet18_Weights, str)``.
    """
    import os

    # Executors may not have a writable $HOME; redirect torch's download/cache
    # directories to /tmp so the pretrained weights can be fetched.
    os.environ["TORCH_HOME"] = "/tmp/torch"
    os.environ["XDG_CACHE_HOME"] = "/tmp"
    if _model_cache["model"] is None:
        # Fix: the original hard-coded "cuda", which raises at .to(device) on
        # CPU-only executors. Fall back to CPU when no GPU is visible.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        weights = ResNet18_Weights.DEFAULT
        model = resnet18(weights=weights).eval().to(device)
        _model_cache["model"] = model
        _model_cache["weights"] = weights
        _model_cache["device"] = device
    return _model_cache["model"], _model_cache["weights"], _model_cache["device"]
# Preprocessing pipeline applied to each decoded image (an HxWx3 uint8 array):
# ToTensor converts it to a CHW float tensor scaled to [0, 1]; the weights'
# bundled transforms then apply ResNet-18's expected resizing/normalization —
# presumably producing 3x224x224, matching the reshape in predict_batch_udf
# (TODO confirm against the installed torchvision version).
transform = transforms.Compose([
    transforms.ToTensor(),
    ResNet18_Weights.DEFAULT.transforms()
])
@pandas_udf(ArrayType(FloatType()))
def decode_and_preprocess_image_udf(image_data_series: pd.Series) -> pd.Series:
    """Decode raw image bytes and preprocess them for classification.

    Each element of ``image_data_series`` holds the binary content of one image
    file. Every image is decoded as RGB and pushed through the module-level
    ``transform``; the resulting tensor is flattened into a list of floats.
    Null rows and rows that fail to decode map to ``None`` (failures are
    logged to stdout, best-effort).
    """

    def _decode_one(raw_bytes):
        # Propagate nulls without attempting a decode.
        if raw_bytes is None:
            return None
        try:
            pixels = np.array(Image.open(io.BytesIO(raw_bytes)).convert("RGB"))
            # Guard against unexpected array ranks before transforming.
            if len(pixels.shape) != 3:
                raise ValueError(f"Invalid image shape: {pixels.shape}")
            return transform(pixels).flatten().tolist()
        except Exception as e:
            # Best-effort: log and drop the row rather than fail the batch.
            print(f"Error decoding image: {e}")
            return None

    return pd.Series([_decode_one(raw) for raw in image_data_series])
@pandas_udf(StringType())
def predict_batch_udf(norm_images: pd.Series):
    """Classify a batch of preprocessed images.

    ``norm_images`` holds flattened float lists produced by
    ``decode_and_preprocess_image_udf``; each is reassembled into a 3x224x224
    tensor, run through the cached ResNet-18, and mapped to its ImageNet
    category name. If the batch cannot be stacked/reshaped, logs the error and
    returns all-``None``.
    """
    model, weights, device = get_model()
    try:
        stacked = np.vstack(norm_images.tolist())
        batch = stacked.reshape(-1, 3, 224, 224).astype(np.float32)
    except ValueError as e:
        # Malformed batch (inconsistent lengths): log and emit nulls.
        print(f"Error reshaping tensor: {e}")
        return pd.Series([None] * len(norm_images))

    inputs = torch.from_numpy(batch).to(device)
    with torch.inference_mode():
        logits = model(inputs)

    class_ids = logits.argmax(dim=1).detach().cpu()
    categories = weights.meta["categories"]
    return pd.Series([categories[i] for i in class_ids])
# Driver pipeline: resolve the benchmark image URLs, load the raw files, run
# decode + classify, and append (path, label) results to S3.
image_urls = [
    row.image_url
    for row in spark.read.parquet("s3://daft-public-datasets/imagenet/benchmark").collect()
]

binary_df = spark.read.format("binaryFile").load(image_urls)

result_df = (
    binary_df
    .withColumn("processed_image", decode_and_preprocess_image_udf(col("content")))
    # Drop rows whose image bytes were null or failed to decode.
    .filter(col("processed_image").isNotNull())
    .withColumn("label", predict_batch_udf(col("processed_image")))
    .select("path", "label")
)

result_df.write.mode("append").parquet(
    "s3://eventual-dev-benchmarking-results/ai-benchmark-results/image-classification-results"
)