examples/babyai/babyai.ipynb
import asyncio
import random
from typing import List, Optional, Tuple
from uuid import UUID
import altair as alt
import pandas as pd
import yaml
from balrog.environments import make_env
from omegaconf import OmegaConf
from tensorzero import AsyncTensorZeroGateway
from tensorzero.util import uuid7
from tqdm import trange
Load config for BALROG environments
# Parse the YAML config with an explicit encoding so it decodes identically on every platform.
with open("config.yml", encoding="utf-8") as f:
    config_dict = yaml.safe_load(f)
config = OmegaConf.create(config_dict)
# Reduce this value if you're getting rate-limited by OpenAI
MAX_CONCURRENT_T0_REQUESTS = 50
# A single semaphore shared by every run_episodes call bounds concurrent inference requests.
semaphore = asyncio.Semaphore(MAX_CONCURRENT_T0_REQUESTS)
The run_episode function executes a single episode of the agent for a BabyAI task (maze navigation game).
async def run_episode(
    client: AsyncTensorZeroGateway,
    variant_name: str,
    env_name: str,
    task_name: str,
    episode_idx: int,
    config: OmegaConf,
    semaphore: asyncio.Semaphore,
    history_length: int = 2,
    seed: int = 0,
    test: bool = False,
) -> dict:
    """Run one BALROG episode driven by the TensorZero ``act`` function.

    At every step the current text observation (and, for variants whose name
    contains ``"history"``, the last ``history_length`` observation/response
    pairs) is sent to the gateway. The parsed action is validated and applied
    to the environment. If inference or parsing fails, a random legal action
    is used instead. The episode stops when the environment terminates or
    truncates, or after ``env.max_steps`` steps. Episode-level feedback
    (``episode_return`` and ``progression``) is then sent to the gateway.

    Args:
        client: TensorZero gateway client used for inference and feedback.
        variant_name: Variant of the ``act`` function to call.
        env_name: BALROG environment name passed to ``make_env``.
        task_name: Task within the environment.
        episode_idx: Episode index; the environment is reset with
            ``episode_idx + seed``.
        config: Configuration passed to ``make_env``.
        semaphore: Bounds the number of concurrent inference requests.
        history_length: Number of past steps included in the prompt
            (values <= 0 include none).
        seed: Base seed added to ``episode_idx``.
        test: Passed as ``dryrun`` to the feedback calls.

    Returns:
        A dict with the variant, task, accumulated input/output token counts,
        episode return, number of steps, failed candidates, the environment's
        ``get_stats()`` entries, the seed used to reset the environment, and
        the episode id.
    """
    # Legal BabyAI actions, used as a random fallback when inference or parsing fails.
    fallback_actions = (
        "turn left",
        "turn right",
        "go forward",
        "pick up",
        "drop",
        "toggle",
    )
    episode_log = {
        "variant": variant_name,
        "task": task_name,
        "input_tokens": 0,
        "output_tokens": 0,
    }
    # Only variants whose name mentions "history" send past steps to the model.
    use_history = "history" in variant_name
    episode_id = uuid7()
    env = make_env(env_name, task_name, config)
    # The seed actually used to reset the env; it is logged so the episode can be reproduced.
    env_seed = episode_idx + seed
    obs, _ = env.reset(seed=env_seed)
    mission = obs["mission"]
    episode_return = 0
    history = []
    # Defined up front so num_steps is well-defined even when max_steps is 0.
    num_steps = 0
    for step in range(env.max_steps):
        num_steps = step + 1
        # Read the observation outside the try block so `state` is always bound
        # (and current) when it is recorded in the history below.
        state = obs["text"]["long_term_context"]
        # history[-0:] would return the whole list, so treat history_length <= 0 as "no history".
        recent_history = history[-history_length:] if history_length > 0 else []
        try:
            # Hold the semaphore only for the network call.
            async with semaphore:
                response = await client.inference(
                    function_name="act",
                    variant_name=variant_name,
                    input={
                        "system": {
                            "mission": mission,
                        },
                        "messages": [
                            {
                                "role": "user",
                                "content": [
                                    {
                                        "type": "text",
                                        "arguments": {
                                            "observation": state,
                                            "history": "\n".join(recent_history),
                                        },
                                    }
                                ],
                            }
                        ],
                    },
                    episode_id=episode_id,
                    cache_options={"enabled": "on"},
                )
            episode_log["input_tokens"] += response.usage.input_tokens
            episode_log["output_tokens"] += response.usage.output_tokens
            action = response.output.parsed["action"]
            # Check if action is valid and set to default if not
            action = env.check_action_validity(action)
        except Exception as e:
            # Deliberate best-effort: any failure falls back to a random legal move.
            print(f"Error occurred: {type(e).__name__}: {e}")
            print("Choosing a random legal move as fallback.")
            action = random.choice(fallback_actions)
        if use_history:
            history.append(f"Observation:{state}\n\nYour Response:\n{action}\n")
        obs, reward, terminated, truncated, _ = env.step(action)
        episode_return += reward
        if terminated or truncated:
            break
    # Environment stats at episode end; run_episodes counts progression == 1.0 as a success.
    stats = env.get_stats()
    progression = stats["progression"]
    # Episode-level feedback; `test` is forwarded as the dryrun flag.
    await client.feedback(
        metric_name="episode_return",
        episode_id=episode_id,
        value=episode_return,
        dryrun=test,
    )
    await client.feedback(
        metric_name="progression",
        episode_id=episode_id,
        value=progression,
        dryrun=test,
    )
    episode_log["episode_return"] = episode_return
    episode_log["num_steps"] = num_steps
    episode_log["failed_candidates"] = env.failed_candidates
    episode_log.update(stats)
    episode_log["seed"] = env_seed
    episode_log["episode_id"] = episode_id
    return episode_log
We define a function to run multiple episodes of the agent for a BabyAI task in parallel.
async def run_episodes(
    client: AsyncTensorZeroGateway,
    variant_name: str,
    env_name: str,
    task_name: str,
    num_episodes: int,
    config: OmegaConf,
    semaphore: asyncio.Semaphore,
    disable_progress_bar: bool = False,
    history_length: int = 2,
    seed: int = 0,
    test: bool = False,
) -> List[dict]:
    """Run ``num_episodes`` episodes of one task concurrently.

    Each episode is scheduled as an asyncio task calling ``run_episode`` with
    ``episode_idx`` in ``range(num_episodes)``. Results are collected in
    completion order, and a progress bar shows the running success count
    (an episode counts as a success when its ``progression`` equals 1.0).

    If any episode raises, the exception propagates to the caller. The
    episodes still pending are cancelled and the progress bar is closed.

    Returns:
        The per-episode log dicts returned by ``run_episode``, in completion
        order.
    """
    progress_bar = trange(
        num_episodes,
        desc=f"{env_name} {task_name} {variant_name}",
        disable=disable_progress_bar,
    )
    tasks = [
        asyncio.create_task(
            run_episode(
                client=client,
                variant_name=variant_name,
                env_name=env_name,
                task_name=task_name,
                episode_idx=episode_idx,
                config=config,
                semaphore=semaphore,
                history_length=history_length,
                seed=seed,
                test=test,
            )
        )
        for episode_idx in range(num_episodes)
    ]
    num_successes = 0
    episode_logs = []
    try:
        for finished in asyncio.as_completed(tasks):
            episode_log = await finished
            if episode_log["progression"] == 1.0:
                num_successes += 1
            episode_logs.append(episode_log)
            progress_bar.update(1)
            progress_bar.set_postfix(
                {"Success": f"{num_successes}/{len(episode_logs)}"},
                refresh=True,
            )
    finally:
        # On failure, stop the remaining episodes instead of leaving them running
        # in the background; cancel() is a no-op for tasks that already finished.
        for task in tasks:
            task.cancel()
        progress_bar.close()
    return episode_logs
# Evaluation settings: episode i of each task is reset with seed + i.
seed, num_episodes = 200, 20
task_names = config["tasks"]["babyai_tasks"]
The baseline variant uses a simple system prompt that guides the LLM to navigate the maze.
You can find the prompts in config/functions/act/baseline.
results_baseline = []
for task_name in task_names:
async with await AsyncTensorZeroGateway.build_http(gateway_url="http://localhost:3000", timeout=180.0) as client:
results_task = await run_episodes(
client=client,
variant_name="baseline",
env_name="babyai",
task_name=task_name,
num_episodes=num_episodes,
config=config,
semaphore=semaphore,
disable_progress_bar=False,
seed=seed,
test=True,
)
results_baseline.extend(results_task)
The reasoning variant uses a system prompt that guides the LLM to reason about the best course of action.
You can find the prompts in config/functions/act/reasoning.
results_reasoning = []
for task_name in task_names:
async with await AsyncTensorZeroGateway.build_http(gateway_url="http://localhost:3000", timeout=180.0) as client:
results_task = await run_episodes(
client=client,
variant_name="reasoning",
env_name="babyai",
task_name=task_name,
num_episodes=num_episodes,
config=config,
semaphore=semaphore,
disable_progress_bar=False,
seed=seed,
test=True,
)
results_reasoning.extend(results_task)
The history variant uses the previous observations and actions to guide the LLM to navigate the maze.
We include the most recent `history_length` observation/response pairs (8 in the example below) in the `history` field of the prompt.
You can find the prompts in config/functions/act/history.
history_length = 8
results_history = []
for task_name in task_names:
async with await AsyncTensorZeroGateway.build_http(gateway_url="http://localhost:3000", timeout=180.0) as client:
results_task = await run_episodes(
client=client,
variant_name="history",
env_name="babyai",
task_name=task_name,
num_episodes=num_episodes,
config=config,
semaphore=semaphore,
disable_progress_bar=False,
history_length=history_length,
seed=seed,
test=True,
)
results_history.extend(results_task)
The history_and_reasoning variant combines the reasoning variant and the history variant.
You can find the prompts in config/functions/act/history_and_reasoning.
results_history_and_reasoning = []
for task_name in task_names:
async with await AsyncTensorZeroGateway.build_http(gateway_url="http://localhost:3000", timeout=180.0) as client:
results_task = await run_episodes(
client=client,
variant_name="history_and_reasoning",
env_name="babyai",
task_name=task_name,
num_episodes=num_episodes,
config=config,
semaphore=semaphore,
disable_progress_bar=False,
history_length=history_length,
seed=seed,
test=True,
)
results_history_and_reasoning.extend(results_task)
# Gather every variant's episode logs into one table (one row per episode).
all_results = [
    *results_baseline,
    *results_reasoning,
    *results_history,
    *results_history_and_reasoning,
]
df = pd.DataFrame(all_results)
# Per-variant mean and standard error of the success metric.
summary = (
    df.groupby("variant")["progression"]
    .agg(["mean", "sem"])
    .reset_index()
)
base = alt.Chart(summary)
# Horizontal bars at the per-variant mean.
bars = base.mark_bar(color="#1f77b4").encode(
    y=alt.Y("variant:N", title="Variant"),
    x=alt.X("mean:Q", title="Value ± 1 SEM", scale=alt.Scale(zero=False)),
)
# Whiskers spanning mean ± 1 SEM.
error_bars = (
    base.transform_calculate(low="datum.mean - datum.sem", high="datum.mean + datum.sem")
    .mark_errorbar(color="black")
    .encode(y="variant:N", x=alt.X("low:Q", title="Value ± 1 SEM"), x2="high:Q")
)
chart = (bars + error_bars).properties(title="Task Success Rate")
chart
# Per-variant mean and standard error of the episode return.
summary = (
    df.groupby("variant")["episode_return"]
    .agg(["mean", "sem"])
    .reset_index()
)
base = alt.Chart(summary)
# Horizontal bars at the per-variant mean.
bars = base.mark_bar(color="#1f77b4").encode(
    y=alt.Y("variant:N", title="Variant"),
    x=alt.X("mean:Q", title="Value ± 1 SEM", scale=alt.Scale(zero=False)),
)
# Whiskers spanning mean ± 1 SEM.
error_bars = (
    base.transform_calculate(low="datum.mean - datum.sem", high="datum.mean + datum.sem")
    .mark_errorbar(color="black")
    .encode(y="variant:N", x=alt.X("low:Q", title="Value ± 1 SEM"), x2="high:Q")
)
chart = (bars + error_bars).properties(title="Episode Return")
chart
# Per-variant mean and standard error of the number of steps per episode.
summary = (
    df.groupby("variant")["num_steps"]
    .agg(["mean", "sem"])
    .reset_index()
)
base = alt.Chart(summary)
# Horizontal bars at the per-variant mean.
bars = base.mark_bar(color="#1f77b4").encode(
    y=alt.Y("variant:N", title="Variant"),
    x=alt.X("mean:Q", title="Value ± 1 SEM", scale=alt.Scale(zero=False)),
)
# Whiskers spanning mean ± 1 SEM.
error_bars = (
    base.transform_calculate(low="datum.mean - datum.sem", high="datum.mean + datum.sem")
    .mark_errorbar(color="black")
    .encode(y="variant:N", x=alt.X("low:Q", title="Value ± 1 SEM"), x2="high:Q")
)
chart = (bars + error_bars).properties(title="Episode Length")
chart
# Per-variant mean and standard error of generated (output) tokens per episode.
summary = (
    df.groupby("variant")["output_tokens"]
    .agg(["mean", "sem"])
    .reset_index()
)
base = alt.Chart(summary)
# Horizontal bars at the per-variant mean.
bars = base.mark_bar(color="#1f77b4").encode(
    y=alt.Y("variant:N", title="Variant"),
    x=alt.X("mean:Q", title="Value ± 1 SEM", scale=alt.Scale(zero=False)),
)
# Whiskers spanning mean ± 1 SEM.
error_bars = (
    base.transform_calculate(low="datum.mean - datum.sem", high="datum.mean + datum.sem")
    .mark_errorbar(color="black")
    .encode(y="variant:N", x=alt.X("low:Q", title="Value ± 1 SEM"), x2="high:Q")
)
chart = (bars + error_bars).properties(title="Episode Generated Token Count")
chart
# Per-variant mean and standard error of input tokens per episode.
summary = (
    df.groupby("variant")["input_tokens"]
    .agg(["mean", "sem"])
    .reset_index()
)
base = alt.Chart(summary)
# Horizontal bars at the per-variant mean.
bars = base.mark_bar(color="#1f77b4").encode(
    y=alt.Y("variant:N", title="Variant"),
    x=alt.X("mean:Q", title="Value ± 1 SEM", scale=alt.Scale(zero=False)),
)
# Whiskers spanning mean ± 1 SEM.
error_bars = (
    base.transform_calculate(low="datum.mean - datum.sem", high="datum.mean + datum.sem")
    .mark_errorbar(color="black")
    .encode(y="variant:N", x=alt.X("low:Q", title="Value ± 1 SEM"), x2="high:Q")
)
chart = (bars + error_bars).properties(title="Episode Input Token Count")
chart
The results above show that the history_and_reasoning variant yields the best success rate.
Here we describe how to improve the performance of the history_and_reasoning variant by fine-tuning it on a separate set of random episodes.
First we run a large set of episodes for each task using the history_and_reasoning variant to generate data for fine-tuning.
num_episodes_ft = 200
seed_ft = 0
for task_name in task_names:
async with await AsyncTensorZeroGateway.build_http(gateway_url="http://localhost:3000", timeout=180.0) as client:
results_task = await run_episodes(
client=client,
variant_name="history_and_reasoning",
env_name="babyai",
task_name=task_name,
num_episodes=num_episodes_ft,
config=config,
semaphore=semaphore,
disable_progress_bar=False,
history_length=history_length,
seed=seed_ft,
test=False,
)
We provide two options for fine-tuning a model: using a notebook or using the TensorZero UI. You can fine-tune on episodes that successfully completed the task, or on episodes that achieved a sufficiently high return (e.g., 0.7).
See the README.md file for more details.
After fine-tuning, create a history_and_reasoning_sft variant and run the following code to evaluate it.
results_history_and_reasoning_ft = []
for task_name in task_names:
async with await AsyncTensorZeroGateway.build_http(gateway_url="http://localhost:3000", timeout=180.0) as client:
results_task = await run_episodes(
client=client,
variant_name="history_and_reasoning_sft",
env_name="babyai",
task_name=task_name,
num_episodes=num_episodes,
config=config,
semaphore=semaphore,
disable_progress_bar=False,
history_length=history_length,
seed=seed,
test=True,
)
results_history_and_reasoning_ft.extend(results_task)
We combine the results of the fine-tuned model with the results of the history_and_reasoning variant.
# Episode logs of the fine-tuned variant, one row per episode.
df_ft = pd.DataFrame(results_history_and_reasoning_ft)
# Drop any rows already present for the variant(s) in df_ft so re-running this cell
# does not duplicate them, and renumber the index so it stays unique after concatenation.
df = pd.concat(
    [df[~df["variant"].isin(df_ft.get("variant", []))], df_ft],
    ignore_index=True,
)
We see below that the fine-tuned model performs better than the history_and_reasoning variant!
# Per-variant mean and standard error of the success metric (now including the fine-tuned variant).
summary = (
    df.groupby("variant")["progression"]
    .agg(["mean", "sem"])
    .reset_index()
)
base = alt.Chart(summary)
# Horizontal bars at the per-variant mean.
bars = base.mark_bar(color="#1f77b4").encode(
    y=alt.Y("variant:N", title="Variant"),
    x=alt.X("mean:Q", title="Value ± 1 SEM", scale=alt.Scale(zero=False)),
)
# Whiskers spanning mean ± 1 SEM.
error_bars = (
    base.transform_calculate(low="datum.mean - datum.sem", high="datum.mean + datum.sem")
    .mark_errorbar(color="black")
    .encode(y="variant:N", x=alt.X("low:Q", title="Value ± 1 SEM"), x2="high:Q")
)
chart = (bars + error_bars).properties(title="Task Success Rate")
chart
# Per-variant mean and standard error of the episode return.
summary = (
    df.groupby("variant")["episode_return"]
    .agg(["mean", "sem"])
    .reset_index()
)
base = alt.Chart(summary)
# Horizontal bars at the per-variant mean.
bars = base.mark_bar(color="#1f77b4").encode(
    y=alt.Y("variant:N", title="Variant"),
    x=alt.X("mean:Q", title="Value ± 1 SEM", scale=alt.Scale(zero=False)),
)
# Whiskers spanning mean ± 1 SEM.
error_bars = (
    base.transform_calculate(low="datum.mean - datum.sem", high="datum.mean + datum.sem")
    .mark_errorbar(color="black")
    .encode(y="variant:N", x=alt.X("low:Q", title="Value ± 1 SEM"), x2="high:Q")
)
chart = (bars + error_bars).properties(title="Episode Return")
chart
# Per-variant mean and standard error of the number of steps per episode.
summary = (
    df.groupby("variant")["num_steps"]
    .agg(["mean", "sem"])
    .reset_index()
)
base = alt.Chart(summary)
# Horizontal bars at the per-variant mean.
bars = base.mark_bar(color="#1f77b4").encode(
    y=alt.Y("variant:N", title="Variant"),
    x=alt.X("mean:Q", title="Value ± 1 SEM", scale=alt.Scale(zero=False)),
)
# Whiskers spanning mean ± 1 SEM.
error_bars = (
    base.transform_calculate(low="datum.mean - datum.sem", high="datum.mean + datum.sem")
    .mark_errorbar(color="black")
    .encode(y="variant:N", x=alt.X("low:Q", title="Value ± 1 SEM"), x2="high:Q")
)
chart = (bars + error_bars).properties(title="Episode Length")
chart
# Per-variant mean and standard error of generated (output) tokens per episode.
summary = (
    df.groupby("variant")["output_tokens"]
    .agg(["mean", "sem"])
    .reset_index()
)
base = alt.Chart(summary)
# Horizontal bars at the per-variant mean.
bars = base.mark_bar(color="#1f77b4").encode(
    y=alt.Y("variant:N", title="Variant"),
    x=alt.X("mean:Q", title="Value ± 1 SEM", scale=alt.Scale(zero=False)),
)
# Whiskers spanning mean ± 1 SEM.
error_bars = (
    base.transform_calculate(low="datum.mean - datum.sem", high="datum.mean + datum.sem")
    .mark_errorbar(color="black")
    .encode(y="variant:N", x=alt.X("low:Q", title="Value ± 1 SEM"), x2="high:Q")
)
chart = (bars + error_bars).properties(title="Episode Generated Token Count")
chart
# Per-variant mean and standard error of input tokens per episode.
summary = (
    df.groupby("variant")["input_tokens"]
    .agg(["mean", "sem"])
    .reset_index()
)
base = alt.Chart(summary)
# Horizontal bars at the per-variant mean.
bars = base.mark_bar(color="#1f77b4").encode(
    y=alt.Y("variant:N", title="Variant"),
    x=alt.X("mean:Q", title="Value ± 1 SEM", scale=alt.Scale(zero=False)),
)
# Whiskers spanning mean ± 1 SEM.
error_bars = (
    base.transform_calculate(low="datum.mean - datum.sem", high="datum.mean + datum.sem")
    .mark_errorbar(color="black")
    .encode(y="variant:N", x=alt.X("low:Q", title="Value ± 1 SEM"), x2="high:Q")
)
chart = (bars + error_bars).properties(title="Episode Input Token Count")
chart