Back to Tensorzero

Example: Multimodal (Vision) Finetuning

examples/multimodal-vision-finetuning/main.ipynb

2026.4.13.5 KB
Original Source

Example: Multimodal (Vision) Finetuning

python
import asyncio
import base64
import os
from os import PathLike
from pathlib import Path

import numpy as np
import pandas as pd
from tensorzero import AsyncTensorZeroGateway
from tqdm.asyncio import tqdm_asyncio
python
# Directory that holds the downloaded dataset; image paths are resolved against it.
DATA_PATH = Path("data")

# CSV index of the dataset; `load_data` reads it and uses its
# `document` and `is_train` columns.
LABELS_PATH = DATA_PATH.joinpath("labels.csv")

# Number of slots in the semaphore that bounds simultaneous document requests.
CONCURRENCY = 10

# Variant name passed as `variant_name` on every inference call.
VARIANT_NAME = "baseline"
python
def load_data(path: PathLike):
    """Read the labels CSV and split it into train and test sets.

    The labels are always read from the module-level ``LABELS_PATH``;
    ``path`` is only the directory that the ``document`` column is resolved
    against when checking that every image is present.

    Args:
        path: Directory containing the images referenced by the
            ``document`` column of the labels file.

    Returns:
        A ``(train_df, test_df)`` tuple holding the rows whose ``is_train``
        column equals 1 and 0 respectively, each re-indexed from 0.

    Raises:
        FileNotFoundError: If the labels file or any referenced image is
            missing.
    """
    # Raise explicitly rather than `assert`: assertions are stripped when
    # Python runs with -O, which would silently skip these checks.
    if not LABELS_PATH.exists():
        raise FileNotFoundError(
            f"Labels file {LABELS_PATH} does not exist. See the README.md and ensure you've downloaded the dataset correctly."
        )

    df = pd.read_csv(LABELS_PATH)

    # Coerce once so that any path-like (str, os.PathLike) works with `/`.
    root = Path(path)

    # Sanity check: ensure every image exists. Only the `document` column is
    # needed, so iterate over it directly instead of materialising every row.
    for document in df["document"]:
        img_path = root / Path(document)
        if not img_path.exists():
            raise FileNotFoundError(
                f"Image {img_path} does not exist. See the README.md and ensure you've downloaded the dataset correctly."
            )

    train_df = df[df["is_train"] == 1].reset_index(drop=True)
    test_df = df[df["is_train"] == 0].reset_index(drop=True)

    return train_df, test_df


# Load both splits from the dataset directory and report how many documents each contains.
train_df, test_df = load_data(DATA_PATH)

print(f"Found {len(train_df)} train documents and {len(test_df)} test documents")
python
# Preview 5 randomly chosen training rows (shown because it is the cell's final expression).
train_df.sample(5)
python
# Preview 5 randomly chosen test rows (shown because it is the cell's final expression).
test_df.sample(5)
python
# Create the directory (and its parent) if missing; a no-op when it already exists.
# NOTE(review): presumably where the gateway stores uploaded images — confirm against the TensorZero config.
Path("tensorzero/object_storage").mkdir(parents=True, exist_ok=True)
python
# Build the async gateway client (`t0`) used by the inference and feedback calls below.
# The top-level `await` only works inside a notebook/IPython session.
# NOTE(review): assumes a TensorZero gateway is already listening on localhost:3000 — confirm against the README.
t0 = await AsyncTensorZeroGateway.build_http(
    gateway_url="http://localhost:3000",
)
python
def load_document(path: PathLike):
    """Load a PNG image from the dataset directory and encode it as a base64 string.

    Args:
        path: Location of the image, joined onto ``DATA_PATH``.

    Returns:
        The file's raw bytes, base64-encoded and returned as a ``str``.

    Raises:
        FileNotFoundError: If the resolved file does not exist.
        ValueError: If the file's extension is not ``.png`` (case-insensitive).
    """
    path = DATA_PATH / path

    # Raise explicitly rather than `assert`: assertions are stripped when
    # Python runs with -O, and a bare AssertionError would not say which file
    # was rejected.
    if not path.exists():
        raise FileNotFoundError(f"Document {path} does not exist.")
    # The caller sends the payload with mime type "image/png", so any other
    # extension is rejected up front.
    if path.suffix.lower() != ".png":
        raise ValueError(f"Document {path} is not a .png file.")

    return base64.b64encode(path.read_bytes()).decode("utf-8")
python
async def process_document(row):
    """Classify one document through the gateway and score the prediction.

    Sends the document image to the ``classify_document`` function using
    ``VARIANT_NAME``. For training rows (truthy ``row.is_train``) it also
    submits two feedback records for the inference: the boolean
    ``correct_classification`` metric and a ``demonstration`` carrying the
    true label. Non-training rows are sent with ``dryrun=True`` and produce
    no feedback.

    Args:
        row: A labels row exposing ``document``, ``label`` and ``is_train``.

    Returns:
        ``True`` when the predicted category equals ``row.label``; ``False``
        otherwise, including when the response has no parsed output.
    """
    response = await t0.inference(
        function_name="classify_document",
        input={
            "system": "Categorize this document.",
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image",
                            "mime_type": "image/png",
                            "data": load_document(row.document),
                        },
                    ],
                }
            ],
        },
        # NOTE(review): presumably keeps test-set inferences out of the stored
        # data used for fine-tuning — confirm against the TensorZero docs.
        dryrun=not row.is_train,
        cache_options={
            "enabled": "on",
        },
        variant_name=VARIANT_NAME,
    )

    # NOTE(review): per the TensorZero API reference, `parsed` may be None when
    # the model output fails to parse/validate — confirm for the installed
    # version. Score that as a miss rather than letting a TypeError abort
    # every other in-flight document.
    parsed = response.output.parsed
    correct_classification = parsed is not None and parsed["category"] == row.label

    if row.is_train:
        await t0.feedback(
            metric_name="correct_classification",
            value=correct_classification,
            inference_id=response.inference_id,
        )

        # The demonstration always carries the ground-truth label, regardless
        # of what the model predicted.
        await t0.feedback(
            metric_name="demonstration",
            value={
                "category": row.label,
            },
            inference_id=response.inference_id,
        )

    return correct_classification
python
# Shared limiter: at most CONCURRENCY documents are processed at the same time.
semaphore = asyncio.Semaphore(CONCURRENCY)


async def process_document_with_semaphore(row):
    """Run `process_document` for `row` once a semaphore slot is available.

    Returns whatever `process_document` returns for the row.
    """
    async with semaphore:
        outcome = await process_document(row)
    return outcome
python
scores = await tqdm_asyncio.gather(*[process_document_with_semaphore(row) for _, row in train_df.iterrows()])

print(f"Train Set Accuracy: {np.mean(scores):.1%}")
python
scores = await tqdm_asyncio.gather(*[process_document_with_semaphore(row) for _, row in test_df.iterrows()])

print(f"Test Set Accuracy: {np.mean(scores):.1%}")