examples/haiku-hidden-preferences/haiku.ipynb
import asyncio
import random
import altair as alt
import pandas as pd
from tensorzero import AsyncTensorZeroGateway
from tqdm.asyncio import tqdm_asyncio
IMPORTANT: Update the gateway URL (`TENSORZERO_GATEWAY_URL`) below if you're not using the standard setup provided in this example.
import os

# URL of the TensorZero gateway. Set the TENSORZERO_GATEWAY_URL environment variable to
# point at a non-standard deployment without editing the notebook; an unset or empty
# variable falls back to the standard local setup used by this example.
TENSORZERO_GATEWAY_URL = os.environ.get("TENSORZERO_GATEWAY_URL") or "http://localhost:3000"
# Number of topics used to generate training data and to evaluate variants, respectively
NUM_TRAIN_DATAPOINTS = 500
NUM_VAL_DATAPOINTS = 500
random.seed(0)  # Set seed for reproducibility

# The noun list has one topic per line. Read it as UTF-8 explicitly so the result does not
# depend on the platform's default encoding, and skip blank lines (e.g. extra empty lines at
# the end of the file) so we never ask the model for a haiku about an empty topic.
with open("data/nounlist.txt", "r", encoding="utf-8") as file:
    topics = [line.strip() for line in file if line.strip()]

random.shuffle(topics)
print(f"There are {len(topics)} topics in the list of haiku topics.")

# Disjoint slices of the shuffled list: the first NUM_TRAIN_DATAPOINTS topics for training,
# the next NUM_VAL_DATAPOINTS for validation (fewer if the list is shorter than that).
train_topics = topics[:NUM_TRAIN_DATAPOINTS]
val_topics = topics[NUM_TRAIN_DATAPOINTS : NUM_TRAIN_DATAPOINTS + NUM_VAL_DATAPOINTS]
print(f"Using {len(train_topics)} topics for training and {len(val_topics)} topics for validation.")
IMPORTANT: Reduce the number of concurrent requests (`MAX_CONCURRENT_REQUESTS`) below if you're running into rate limits.
MAX_CONCURRENT_REQUESTS = 50
tensorzero_client = await AsyncTensorZeroGateway.build_http(gateway_url=TENSORZERO_GATEWAY_URL, timeout=30)
async def write_judge_haiku(topic, variant_name):
    """Generate a haiku about ``topic`` and score it with the ``judge_haiku`` function.

    Args:
        topic: Subject substituted into the ``write_haiku`` user template.
        variant_name: ``write_haiku`` variant to pin (validation), or ``None`` when no
            specific variant is requested (training).

    Returns:
        ``(inference_id, score)`` for the ``write_haiku`` inference, or ``None`` if either
        inference fails or the haiku/score cannot be extracted. Errors are printed rather
        than raised so that one bad request does not make the caller's gather fail.
    """
    # Generate a haiku about the given topic
    try:
        write_result = await tensorzero_client.inference(
            function_name="write_haiku",
            variant_name=variant_name,  # only used during validation
            input={
                "messages": [
                    {
                        "role": "user",
                        "content": [{"type": "template", "name": "user", "arguments": {"topic": topic}}],
                    }
                ]
            },
        )
    except Exception as e:
        print(f"Error occurred: {type(e).__name__}: {e}")
        return None

    # The LLM is instructed to conclude with the haiku, so we extract the last 3 lines
    # In a real application, you'll want more sophisticated validation and parsing logic
    # This parsing runs outside the try above, so guard it separately: a response with no
    # content blocks (IndexError) or a first block without usable text (AttributeError)
    # would otherwise propagate out of this coroutine and make the whole gather raise.
    try:
        haiku_text = "\n".join(write_result.content[0].text.strip().split("\n")[-3:])
    except (IndexError, AttributeError) as e:
        print(f"Error occurred: {type(e).__name__}: {e}")
        return None

    # Judge the haiku using a separate TensorZero function
    # We use the same episode_id to associate these inferences
    try:
        judge_result = await tensorzero_client.inference(
            function_name="judge_haiku",
            input={
                "messages": [
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "template",
                                "name": "user",
                                "arguments": {"topic": topic, "haiku": haiku_text},
                            }
                        ],
                    }
                ]
            },
            episode_id=write_result.episode_id,
        )
        # Indexing sits inside the try so a missing/unparsed judge output is handled too
        score = judge_result.output.parsed["score"]
    except Exception as e:
        print(f"Error occurred: {type(e).__name__}: {e}")
        return None

    return (write_result.inference_id, score)
# Run inference in parallel to speed things up
semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
async def ratelimited_write_judge_haiku(topic, variant_name=None):
async with semaphore:
return await write_judge_haiku(topic, variant_name=variant_name)
results = await tqdm_asyncio.gather(*[ratelimited_write_judge_haiku(topic) for topic in train_topics])
async def send_haiku_feedback(inference_id, score):
async with semaphore:
await tensorzero_client.feedback(metric_name="haiku_score", inference_id=inference_id, value=score)
await tqdm_asyncio.gather(*[send_haiku_feedback(*result) for result in results if result is not None]);
IMPORTANT: Update the list below when you create new variants in `tensorzero.toml`.
# Variant names (as declared in `tensorzero.toml`) to benchmark on the validation topics.
# Add entries here as you define new variants, e.g. "gpt_4o_mini_fine_tuned" after fine-tuning.
VARIANTS_TO_EVALUATE = ["gpt_4o_mini"]
scores = {} # variant_name => score
for variant_name in VARIANTS_TO_EVALUATE:
# Run inference on the validation set
val_results = await tqdm_asyncio.gather(
*[
ratelimited_write_judge_haiku(
topic,
variant_name=variant_name, # pin to the specific variant we want to evaluate
)
for topic in val_topics
],
desc=f"Evaluating variant: {variant_name}",
)
# Compute the average score for the variant
scores[variant_name] = sum(result[1] for result in val_results if result is not None) / len(val_results)
# One row per evaluated variant, in long format so the chart can color by metric
scores_df = pd.DataFrame(
    [
        {"Variant": name, "Metric": "haiku_score", "Score": value}
        for name, value in scores.items()
    ]
)
# Shared encodings: score on a fixed 0-100% x-axis, one row per variant, percentage labels
chart = alt.Chart(scores_df).encode(
    x=alt.X("Score:Q", axis=alt.Axis(format="%"), scale=alt.Scale(domain=[0, 1])),
    y="Variant:N",
    color="Metric:N",
    text=alt.Text("Score:Q", format=".1%"),
)
chart = chart.properties(title="Score by Variant")

# Layer bars with left-aligned text labels nudged 2px past the end of each bar
chart = chart.mark_bar() + chart.mark_text(align="left", dx=2)
chart  # display as the cell output