# benchmarking/ai/video_object_detection/spark.ipynb
from __future__ import annotations
import io
from typing import Any, Dict, List
import av
import torch
import torchvision
from PIL import Image
from pyspark.sql.functions import col, explode, pandas_udf
from pyspark.sql.types import (
ArrayType,
BinaryType,
FloatType,
IntegerType,
StringType,
StructField,
StructType,
)
from ultralytics import YOLO
%%configure -f
{
"executorCores": 1,
"conf": {
"spark.sql.execution.arrow.maxRecordsPerBatch": "64"
}
}
# Per-worker model cache: YOLO weight loading is expensive, so each Python
# worker process loads the model once and reuses it across UDF batches.
_model_cache = {"model": None}
def get_model():
    """Return the cached YOLO model, loading it on first use.

    Moves the model to GPU when CUDA is available; subsequent calls in
    the same worker process reuse the cached instance.
    """
    cached = _model_cache["model"]
    if cached is None:
        cached = YOLO("yolo11n.pt")
        if torch.cuda.is_available():
            cached.to("cuda")
        _model_cache["model"] = cached
    return cached
@pandas_udf(ArrayType(BinaryType()))
def decode_video_udf(video_bytes_iter):
    """Decode each video blob into a list of 640x640 PNG-encoded frames.

    Takes a pandas Series of raw video bytes and returns a Series of
    lists of PNG bytes, one list per input video.
    """
    import pandas as pd

    def _frame_to_png(frame):
        # Convert one decoded frame to RGB, resize, and encode as PNG.
        resized = Image.fromarray(frame.to_ndarray(format="rgb24")).resize((640, 640))
        encoded = io.BytesIO()
        resized.save(encoded, format="PNG")
        return encoded.getvalue()

    decoded = []
    for video_bytes in video_bytes_iter:
        with av.open(io.BytesIO(video_bytes)) as container:
            decoded.append([_frame_to_png(f) for f in container.decode(video=0)])
    return pd.Series(decoded)
# Spark schema for one frame's detections: each element is a detected
# object with its class label, confidence score, and [x1, y1, x2, y2]
# integer pixel bounding box.
_FEATURE_FIELDS = [
    ("label", StringType()),
    ("confidence", FloatType()),
    ("bbox", ArrayType(IntegerType())),
]
feature_schema = ArrayType(
    StructType([StructField(name, dtype, False) for name, dtype in _FEATURE_FIELDS])
)
def to_features(res: Any) -> List[Dict[str, Any]]:
    """Convert one detection result into a list of feature dicts.

    Each dict matches ``feature_schema``: class label looked up via
    ``res.names``, confidence as a float, and the xyxy box truncated
    to integer pixel coordinates.
    """
    features: List[Dict[str, Any]] = []
    boxes = res.boxes
    for cls_id, conf, bbox in zip(boxes.cls, boxes.conf, boxes.xyxy):
        features.append(
            {
                "label": res.names[int(cls_id)],
                "confidence": float(conf),
                "bbox": [int(coord) for coord in bbox.tolist()],
            }
        )
    return features
@pandas_udf(feature_schema)
def extract_image_features_udf(images):
    """Run YOLO object detection on a batch of PNG-encoded frames.

    ``images`` is a pandas Series of PNG bytes (all 640x640, produced by
    ``decode_video_udf``). Returns a Series of per-frame detection lists
    conforming to ``feature_schema``.
    """
    import pandas as pd

    # BUG FIX: a Series-to-Series pandas_udf must return a pd.Series;
    # the original returned a bare list here, which breaks Arrow
    # serialization when Spark hands the worker an empty batch.
    if len(images) == 0:
        return pd.Series([], dtype=object)
    model = get_model()
    # All frames are resized to 640x640 upstream, so stacking into a
    # single batch tensor is safe.
    tensors = [
        torchvision.transforms.functional.to_tensor(Image.open(io.BytesIO(img)))
        for img in images
    ]
    stack = torch.stack(tensors, dim=0)
    results = model(stack)
    return pd.Series([to_features(r) for r in results])
@pandas_udf(BinaryType())
def crop_udf(frame_bytes_iter, bbox_iter):
    """Crop each detected bounding box out of its PNG frame.

    Pairs a Series of PNG frame bytes with a Series of [x1, y1, x2, y2]
    boxes and returns a Series of PNG-encoded crops. Best-effort: any
    frame that fails to decode or any malformed box yields None instead
    of failing the batch.
    """
    import pandas as pd

    crops = []
    for frame_bytes, bbox in zip(frame_bytes_iter, bbox_iter):
        try:
            source = Image.open(io.BytesIO(frame_bytes)).convert("RGB")
            left, top, right, bottom = bbox
            region = source.crop((left, top, right, bottom))
            out = io.BytesIO()
            region.save(out, format="PNG")
            crop = out.getvalue()
        except Exception:
            crop = None
        crops.append(crop)
    return pd.Series(crops)
# Driver pipeline: videos -> frames -> detections -> per-object crops.
# Load every video file as a (path, content, ...) row of raw bytes.
df = spark.read.format("binaryFile").load("s3://daft-public-data/videos/Hollywood2-actions-videos/Hollywood2/AVIClips/")
# Decode each video into a list of PNG frames, then explode to one row per frame.
df = df.withColumn("frame", decode_video_udf(col("content")))
df = df.withColumn("frame", explode(col("frame")))
# Checkpoint to materialize the expensive decode step before inference.
df = df.checkpoint()
# Detect objects per frame, then explode to one row per detected object.
df = df.withColumn("features", extract_image_features_udf(col("frame")))
df = df.withColumn("feature", explode(col("features")))
# Checkpoint again so the GPU inference results are not recomputed by the crop stage.
df = df.checkpoint()
# Crop each detection out of its frame; drop the bulky binary columns before writing.
df = df.withColumn("object", crop_udf(col("frame"), col("feature.bbox")))
df = df.drop("content", "frame")
df.write.mode("append").parquet("s3://eventual-dev-benchmarking-results/ai-benchmark-results/video-object-detection-result")