Back to TensorZero

Example: Optimizing Data Extraction (NER) with TensorZero

examples/data-extraction-ner/data-extraction-ner.ipynb

2026.6.08.4 KB
Original Source

Example: Optimizing Data Extraction (NER) with TensorZero

Setup

python
import asyncio
import json
from collections import Counter
from typing import Dict, List, Optional, Tuple

import altair as alt
import pandas as pd
from tensorzero import AsyncTensorZeroGateway, InferenceResponse
from tqdm import tqdm
from tqdm.asyncio import tqdm_asyncio

IMPORTANT: Update the gateway URL below if you're not using the standard setup provided in this example

python
# Base URL of the TensorZero gateway; change it if your gateway is not running at the example's default address
TENSORZERO_GATEWAY_URL = "http://localhost:3000"

Load the Dataset

python
# Number of rows kept from each shuffled split; a small subset keeps the example quick to run
NUM_TRAIN_DATAPOINTS = 500  # training rows (inference + feedback)
NUM_VAL_DATAPOINTS = 500  # validation rows (variant evaluation)
python
def load_dataset(
    path: str,
    num_train: Optional[int] = None,
    num_val: Optional[int] = None,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Load the NER dataset CSV and return shuffled train/validation subsets.

    The CSV is expected to have an ``output`` column of JSON strings (decoded with
    ``json.loads``) and a ``split`` column where ``0`` marks training rows and ``1``
    marks validation rows. Rows with any other ``split`` value are dropped.

    Args:
        path: Path to the CSV file.
        num_train: Maximum number of training rows to keep. Defaults to
            ``NUM_TRAIN_DATAPOINTS`` when omitted.
        num_val: Maximum number of validation rows to keep. Defaults to
            ``NUM_VAL_DATAPOINTS`` when omitted.

    Returns:
        ``(train_df, val_df)``. Each is shuffled with ``random_state=0`` and re-indexed from 0.
    """
    # Default to the module-level limits when no explicit limit is given
    if num_train is None:
        num_train = NUM_TRAIN_DATAPOINTS
    if num_val is None:
        num_val = NUM_VAL_DATAPOINTS

    # Load the dataset and decode the JSON-encoded ground-truth entities.
    # Item assignment (not `df.output = ...`) is the unambiguous way to replace a column.
    df = pd.read_csv(path)
    df["output"] = df["output"].apply(json.loads)

    # Split the dataset into train and validation sets
    train_df = df[df["split"] == 0]
    val_df = df[df["split"] == 1]

    # Shuffle the splits with a fixed seed so runs are reproducible
    train_df = train_df.sample(frac=1, random_state=0).reset_index(drop=True)
    val_df = val_df.sample(frac=1, random_state=0).reset_index(drop=True)

    # Select only a subset of the dataset to speed things up
    return train_df.iloc[:num_train], val_df.iloc[:num_val]
python
train_df, val_df = load_dataset("data/conllpp.csv")

# Report the (rows, columns) shape of each split, one per line
print(f"Train data shape: {train_df.shape}", f"Validation data shape: {val_df.shape}", sep="\n")

Extract Entities

IMPORTANT: Reduce the number of concurrent requests if you're running into rate limits

python
MAX_CONCURRENT_REQUESTS = 10  # upper bound on simultaneous inference requests; lower it if you hit rate limits
python
tensorzero_client = await AsyncTensorZeroGateway.build_http(gateway_url=TENSORZERO_GATEWAY_URL, timeout=15)
semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
python
async def get_entities(
    text: str,
    variant_name: Optional[str] = None,
    dryrun: bool = False,
) -> Optional[InferenceResponse]:
    """Ask the `extract_entities` function to extract entities from `text`.

    The text is sent as a single user message. `variant_name` and `dryrun` are
    forwarded to the gateway unchanged. Calls are throttled by the module-level
    `semaphore`. Any exception is printed and turned into a `None` result so
    that one failed request does not abort a batch.
    """
    request = {"messages": [{"role": "user", "content": text}]}

    async with semaphore:  # throttle concurrent calls to stay under rate limits
        try:
            response = await tensorzero_client.inference(
                function_name="extract_entities",
                input=request,
                dryrun=dryrun,
                variant_name=variant_name,
            )
        except Exception as err:
            print(f"Error occurred: {type(err).__name__}: {err}")
            response = None

    return response
python
# Run inference in parallel to speed things up
responses = await tqdm_asyncio.gather(*[get_entities(text) for text in train_df["input"]])

Evaluate the Performance

python
def flatten_dict(d: Dict[str, List[str]]) -> List[str]:
    """Flatten an entity-type -> entities mapping into type-tagged strings.

    Each entity becomes ``"__<TYPE>__::<entity>"``, with the type upper-cased, so
    that identical strings of different types never compare equal. The output
    follows the dict's iteration order, then each list's order. Duplicates are kept.

    Raises:
        TypeError: If any value in ``d`` is not a list.
    """
    res = []
    for k, v in d.items():
        # Validate explicitly rather than with `assert`: asserts are stripped under
        # `python -O`, and a bare string would then be iterated character by character.
        if not isinstance(v, list):
            raise TypeError(f"Expected a list of entities for key {k!r}, got {type(v).__name__}")
        res.extend(f"__{k.upper()}__::{elt}" for elt in v)
    return res


# Exact match: the sharpest NER metric used in this example. Entities are compared as *sets*,
# so an entity repeated in one mapping but listed once in the other still counts as a match.
def compute_exact_match(predicted: Dict[str, List[str]], ground_truth: Dict[str, List[str]]) -> bool:
    """Return True when both mappings flatten to the same set of type-tagged entities."""
    predicted_entities = frozenset(flatten_dict(predicted))
    expected_entities = frozenset(flatten_dict(ground_truth))
    return predicted_entities == expected_entities


# Jaccard similarity between the predicted and ground_truth entities
# (a more lenient metric that gives partial credit for correct entities)
# This is a different implementation from the original code by Predibase, so the metrics won't be directly comparable.
def compute_jaccard_similarity(predicted: Dict[str, List[str]], ground_truth: Dict[str, List[str]]) -> float:
    """Return the multiset Jaccard similarity of predicted vs. ground-truth entities.

    Both mappings are flattened with ``flatten_dict`` and counted as multisets, so
    repeated entities matter. The score is ``sum(min(counts)) / sum(max(counts))``
    over every distinct entity. Two empty inputs are a perfect match (1.0).
    """
    target_count = Counter(flatten_dict(ground_truth))
    pred_count = Counter(flatten_dict(predicted))
    # Counter `&` keeps the per-key minimum count and `|` keeps the per-key maximum
    intersection = sum((target_count & pred_count).values())
    union = sum((target_count | pred_count).values())
    if union == 0:
        # Nothing expected and nothing predicted; return a float to honour the annotation
        return 1.0
    return intersection / union
python
def evaluate_response(
    response: Optional[InferenceResponse], ground_truth_data: Dict[str, List[str]]
) -> Tuple[bool, bool, float]:
    """Score one inference response against its ground-truth entities.

    Args:
        response: The gateway response, or ``None`` if the request itself failed.
        ground_truth_data: Mapping of entity type to its list of expected entity strings.

    Returns:
        ``(valid_output, exact_match, jaccard_similarity)``. When there is no parsed
        output, the result is ``(False, False, 0.0)``.
    """
    predicted = response.output.parsed if response is not None else None

    # `predicted` is None if the model failed to return a valid JSON that complies with the output schema
    valid_output = predicted is not None
    if not valid_output:
        return False, False, 0.0

    # Gate on `is None` (not truthiness), consistent with `valid_output`: a valid but empty
    # prediction such as `{}` must still be scored, e.g. it exactly matches empty ground truth.
    exact_match = compute_exact_match(predicted, ground_truth_data)
    jaccard_similarity = compute_jaccard_similarity(predicted, ground_truth_data)

    return valid_output, exact_match, jaccard_similarity
python
for response, ground_truth in tqdm(zip(responses, train_df["output"]), total=len(responses)):
    # Don't send feedback if the request failed completely
    if response is None:
        continue

    # Evaluate the example
    valid_output, exact_match, jaccard_similarity = evaluate_response(response, ground_truth)

    # Send the metrics feedback to TensorZero
    await tensorzero_client.feedback(
        metric_name="valid_output",
        value=valid_output,
        inference_id=response.inference_id,
    )

    await tensorzero_client.feedback(
        metric_name="exact_match",
        value=exact_match,
        inference_id=response.inference_id,
    )

    await tensorzero_client.feedback(
        metric_name="jaccard_similarity",
        value=jaccard_similarity,
        inference_id=response.inference_id,
    )

    # Send the demonstration feedback to TensorZero
    await tensorzero_client.feedback(
        metric_name="demonstration",
        value=ground_truth,
        inference_id=response.inference_id,
    )

Validation Set

IMPORTANT: Update the list below when you create new variants in `tensorzero.toml`.

python
# Variants defined in `tensorzero.toml` to compare on the validation set.
# Uncomment the fine-tuned entry once that variant exists in `tensorzero.toml`.
VARIANTS_TO_EVALUATE = [
    "gpt_4o",
    "gpt_4o_mini",
    # "gpt_4o_mini_fine_tuned",
]
python
scores = {}  # variant_name => (valid_output, exact_match, jaccard_similarity)

for variant_name in VARIANTS_TO_EVALUATE:
    # Run inference on the validation set
    responses = await tqdm_asyncio.gather(
        *[
            get_entities(
                text,
                variant_name=variant_name,  # pin to the specific variant we want to evaluate
                dryrun=True,  # don't store results to avoid leaking data
            )
            for text in val_df["input"]
        ],
        desc=f"Evaluating variant: {variant_name}",
    )

    # Evaluate the performance of the variant
    valid_output_scores = []
    exact_match_scores = []
    jaccard_similarity_scores = []

    for response, ground_truth in zip(responses, val_df["output"]):
        valid_output, exact_match, jaccard_similarity = evaluate_response(response, ground_truth)
        valid_output_scores.append(valid_output)
        exact_match_scores.append(exact_match)
        jaccard_similarity_scores.append(jaccard_similarity)

    scores[variant_name] = {
        "valid_output": valid_output_scores,
        "exact_match": exact_match_scores,
        "jaccard_similarity": jaccard_similarity_scores,
    }

    # Print the performance of the variant
    print(f"Valid Output: {sum(valid_output_scores) / len(valid_output_scores):.1%}")
    print(f"Exact Match: {sum(exact_match_scores) / len(exact_match_scores):.1%}")
    print(f"Jaccard Similarity (mean): {sum(jaccard_similarity_scores) / len(jaccard_similarity_scores):.1%}")
    print()

Plot Results

python
# One row per (variant, metric) holding the mean per-example score, in long format for Altair.
# Rows are ordered by variant, then exact_match before jaccard_similarity. An empty score list
# averages to 0 instead of raising ZeroDivisionError.
scores_df = pd.DataFrame(
    [
        {
            "Variant": variant_name,
            "Metric": metric_name,
            "Score": sum(variant_scores[metric_name]) / max(len(variant_scores[metric_name]), 1),
        }
        for variant_name, variant_scores in scores.items()
        for metric_name in ("exact_match", "jaccard_similarity")
    ]
)
python
chart = (
    alt.Chart(scores_df)
    .encode(
        x=alt.X("Score:Q", axis=alt.Axis(format="%"), scale=alt.Scale(domain=[0, 1])),
        y="Variant:N",
        yOffset="Metric:N",
        color="Metric:N",
        text=alt.Text("Score:Q", format=".1%"),
    )
    .properties(title="Metrics by Variant")
)

chart = chart.mark_bar() + chart.mark_text(align="left", dx=2)

chart