Back to Tensorzero

Example: LLMs Learn to Navigate Mazes from Experience (BabyAI Benchmark)

examples/babyai/babyai.ipynb

2026.6.0 · 19.2 KB
Original Source

Example: LLMs Learn to Navigate Mazes from Experience (BabyAI Benchmark)

Setup

python
import asyncio
import random
from typing import List, Optional, Tuple
from uuid import UUID

import altair as alt
import pandas as pd
import yaml
from balrog.environments import make_env
from omegaconf import OmegaConf
from tensorzero import AsyncTensorZeroGateway
from tensorzero.util import uuid7
from tqdm import trange

Load config for BALROG environments

python
# Load the BALROG environment settings and wrap them in an OmegaConf object so
# nested keys can be read as attributes (e.g. `config.tasks.babyai_tasks`).
# YAML files are UTF-8, so don't rely on the platform's default encoding.
with open("config.yml", encoding="utf-8") as f:
    config_dict = yaml.safe_load(f)
config = OmegaConf.create(config_dict)
python
# Upper bound on in-flight TensorZero inference requests, shared by every
# episode of every run below. Lower it if OpenAI starts rate-limiting you.
MAX_CONCURRENT_T0_REQUESTS = 50
semaphore = asyncio.Semaphore(value=MAX_CONCURRENT_T0_REQUESTS)

Helper Functions

The run_episode function executes a single episode of the agent for a BabyAI task (maze navigation game).

python
async def run_episode(
    client: AsyncTensorZeroGateway,
    variant_name: str,
    env_name: str,
    task_name: str,
    episode_idx: int,
    config: OmegaConf,
    semaphore: asyncio.Semaphore,
    history_length: int = 2,
    seed: int = 0,
    test: bool = False,
) -> dict:
    """Play one episode of a BALROG task, choosing actions with the ``act`` function.

    At each step the current text observation (and, for history variants, the
    most recent observation/action pairs) is sent to the TensorZero gateway.
    The parsed action is validated and applied to the environment. The episode
    ends on termination/truncation or after ``env.max_steps`` steps. Afterwards
    ``episode_return`` and ``progression`` feedback is sent for the episode.

    Args:
        client: TensorZero gateway client used for inference and feedback.
        variant_name: Variant of the ``act`` function to call. Variants whose
            name contains ``"history"`` receive previous steps as context.
        env_name: BALROG environment name (e.g. ``"babyai"``).
        task_name: Task within the environment.
        episode_idx: Index of this episode; added to ``seed`` to seed the env.
        config: BALROG configuration passed to ``make_env``.
        semaphore: Bounds the number of concurrent inference requests.
        history_length: Number of previous steps included for history
            variants; ``0`` or less sends an empty history.
        seed: Base seed; the environment is reset with ``episode_idx + seed``.
        test: Passed to the feedback calls as ``dryrun``.

    Returns:
        A dict with keys ``variant``, ``task``, ``input_tokens``,
        ``output_tokens``, ``episode_return``, ``num_steps``,
        ``failed_candidates``, every entry of ``env.get_stats()`` (including
        ``progression``), ``episode_idx``, ``seed`` (the seed actually passed
        to ``env.reset``), and ``episode_id``.
    """
    # Actions sampled uniformly when inference or output parsing fails.
    fallback_actions = (
        "turn left",
        "turn right",
        "go forward",
        "pick up",
        "drop",
        "toggle",
    )
    episode_log = {
        "variant": variant_name,
        "task": task_name,
        "input_tokens": 0,
        "output_tokens": 0,
    }
    use_history = "history" in variant_name
    episode_id = uuid7()
    env_seed = episode_idx + seed
    env = make_env(env_name, task_name, config)
    obs, _ = env.reset(seed=env_seed)
    mission = obs["mission"]
    episode_return = 0
    history = []
    # Stays 0 if env.max_steps is 0 (the loop variable would be unbound).
    num_steps = 0
    for step in range(env.max_steps):
        # Read the observation outside the try block so `state` is always
        # defined when it is appended to the history below.
        state = obs["text"]["long_term_context"]
        # history[-0:] would return the *entire* list, so guard length <= 0.
        recent_history = history[-history_length:] if history_length > 0 else []
        try:
            async with semaphore:
                response = await client.inference(
                    function_name="act",
                    variant_name=variant_name,
                    input={
                        "system": {
                            "mission": mission,
                        },
                        "messages": [
                            {
                                "role": "user",
                                "content": [
                                    {
                                        "type": "text",
                                        "arguments": {
                                            "observation": state,
                                            "history": "\n".join(recent_history),
                                        },
                                    }
                                ],
                            }
                        ],
                    },
                    episode_id=episode_id,
                    cache_options={"enabled": "on"},
                )
                episode_log["input_tokens"] += response.usage.input_tokens
                episode_log["output_tokens"] += response.usage.output_tokens
            action = response.output.parsed["action"]
            # Replace an invalid action with the environment's default.
            action = env.check_action_validity(action)
        except Exception as e:
            # Best effort: keep the episode going with a random legal move.
            print(f"Error occurred: {type(e).__name__}: {e}")
            print("Choosing a random legal move as fallback.")
            action = random.choice(fallback_actions)
        if use_history:
            history.append(f"Observation:{state}\n\nYour Response:\n{action}\n")
        obs, reward, terminated, truncated, _ = env.step(action)
        episode_return += reward
        num_steps = step + 1
        if terminated or truncated:
            break

    stats = env.get_stats()
    await client.feedback(
        metric_name="episode_return",
        episode_id=episode_id,
        value=episode_return,
        dryrun=test,
    )
    await client.feedback(
        metric_name="progression",
        episode_id=episode_id,
        value=stats["progression"],
        dryrun=test,
    )
    episode_log["episode_return"] = episode_return
    episode_log["num_steps"] = num_steps
    episode_log["failed_candidates"] = env.failed_candidates
    episode_log.update(stats)
    episode_log["episode_idx"] = episode_idx
    # Record the seed actually used for env.reset (previously this stored
    # only episode_idx under the "seed" key).
    episode_log["seed"] = env_seed
    episode_log["episode_id"] = episode_id
    return episode_log

We define a function to run multiple episodes of the agent for a BabyAI task in parallel.

python
async def run_episodes(
    client: AsyncTensorZeroGateway,
    variant_name: str,
    env_name: str,
    task_name: str,
    num_episodes: int,
    config: OmegaConf,
    semaphore: asyncio.Semaphore,
    disable_progress_bar: bool = False,
    history_length: int = 2,
    seed: int = 0,
    test: bool = False,
) -> List[dict]:
    """Run ``num_episodes`` episodes of one task concurrently with ``run_episode``.

    Episode ``i`` is seeded with ``seed + i``. A progress bar reports how many
    finished episodes reached ``progression == 1.0``.

    Args:
        client: TensorZero gateway client shared by all episodes.
        variant_name: Variant of the ``act`` function to evaluate.
        env_name: BALROG environment name (e.g. ``"babyai"``).
        task_name: Task within the environment.
        num_episodes: Number of episodes to run.
        config: BALROG configuration passed through to ``run_episode``.
        semaphore: Bounds concurrent inference requests across episodes.
        disable_progress_bar: Hide the tqdm progress bar if True.
        history_length: Forwarded to ``run_episode``.
        seed: Base seed forwarded to ``run_episode``.
        test: Forwarded to ``run_episode`` (feedback ``dryrun`` flag).

    Returns:
        The per-episode log dicts from ``run_episode``, in completion order
        (not episode order).
    """
    progress_bar = trange(
        num_episodes,
        desc=f"{env_name} {task_name} {variant_name}",
        disable=disable_progress_bar,
    )

    tasks = [
        asyncio.create_task(
            run_episode(
                client=client,
                variant_name=variant_name,
                env_name=env_name,
                task_name=task_name,
                episode_idx=episode_idx,
                config=config,
                semaphore=semaphore,
                history_length=history_length,
                seed=seed,
                test=test,
            )
        )
        for episode_idx in range(num_episodes)
    ]

    num_successes = 0
    episode_logs = []
    try:
        for next_finished in asyncio.as_completed(tasks):
            episode_log = await next_finished
            if episode_log["progression"] == 1.0:
                num_successes += 1
            episode_logs.append(episode_log)
            progress_bar.update(1)
            progress_bar.set_postfix(
                {"Success": f"{num_successes}/{len(episode_logs)}"},
                refresh=True,
            )
    finally:
        # If an episode raised, stop the rest from running (and sending
        # requests) in the background; this is a no-op once all are done.
        for task in tasks:
            if not task.done():
                task.cancel()
        progress_bar.close()
    return episode_logs
python
# Evaluation settings shared by every variant run below: each variant plays
# the same `num_episodes` episodes per task, seeded `seed`, `seed + 1`, ...
seed, num_episodes = 200, 20
task_names = config.tasks.babyai_tasks

Baseline

The baseline variant uses a simple system prompt that guides the LLM to navigate the maze.

You can find the prompts in config/functions/act/baseline.

python
results_baseline = []

for task_name in task_names:
    async with await AsyncTensorZeroGateway.build_http(gateway_url="http://localhost:3000", timeout=180.0) as client:
        results_task = await run_episodes(
            client=client,
            variant_name="baseline",
            env_name="babyai",
            task_name=task_name,
            num_episodes=num_episodes,
            config=config,
            semaphore=semaphore,
            disable_progress_bar=False,
            seed=seed,
            test=True,
        )
        results_baseline.extend(results_task)

Reasoning

The reasoning variant uses a system prompt that guides the LLM to reason about the best course of action.

You can find the prompts in config/functions/act/reasoning.

python
results_reasoning = []
for task_name in task_names:
    async with await AsyncTensorZeroGateway.build_http(gateway_url="http://localhost:3000", timeout=180.0) as client:
        results_task = await run_episodes(
            client=client,
            variant_name="reasoning",
            env_name="babyai",
            task_name=task_name,
            num_episodes=num_episodes,
            config=config,
            semaphore=semaphore,
            disable_progress_bar=False,
            seed=seed,
            test=True,
        )
        results_reasoning.extend(results_task)

History

The history variant passes the agent's previous observations and actions to the LLM to help it navigate the maze. The most recent `history_length` observation–action pairs (eight in the run below) are supplied through the `history` template field.

You can find the prompts in config/functions/act/history.

python
history_length = 8

results_history = []
for task_name in task_names:
    async with await AsyncTensorZeroGateway.build_http(gateway_url="http://localhost:3000", timeout=180.0) as client:
        results_task = await run_episodes(
            client=client,
            variant_name="history",
            env_name="babyai",
            task_name=task_name,
            num_episodes=num_episodes,
            config=config,
            semaphore=semaphore,
            disable_progress_bar=False,
            history_length=history_length,
            seed=seed,
            test=True,
        )
        results_history.extend(results_task)

History and Reasoning

The history_and_reasoning variant combines the reasoning variant and the history variant.

You can find the prompts in config/functions/act/history_and_reasoning.

python
results_history_and_reasoning = []
for task_name in task_names:
    async with await AsyncTensorZeroGateway.build_http(gateway_url="http://localhost:3000", timeout=180.0) as client:
        results_task = await run_episodes(
            client=client,
            variant_name="history_and_reasoning",
            env_name="babyai",
            task_name=task_name,
            num_episodes=num_episodes,
            config=config,
            semaphore=semaphore,
            disable_progress_bar=False,
            history_length=history_length,
            seed=seed,
            test=True,
        )
        results_history_and_reasoning.extend(results_task)

Results

python
# One row per episode, covering all four prompt variants.
all_results = [
    *results_baseline,
    *results_reasoning,
    *results_history,
    *results_history_and_reasoning,
]
df = pd.DataFrame(all_results)

Success Rate

python
# Success rate: per-variant mean (and SEM) of the `progression` stat.
summary = df.groupby("variant")["progression"].agg(["mean", "sem"]).reset_index()

# Bars show the mean. The x scale keeps its default zero baseline (no
# `zero=False`) so bar lengths stay proportional to the plotted values.
bars = (
    alt.Chart(summary)
    .encode(
        y=alt.Y("variant:N", title="Variant"),
        x=alt.X("mean:Q", title="Value ± 1 SEM"),
    )
    .mark_bar(color="#1f77b4")
)

# Error bars span mean ± 1 standard error of the mean for each variant.
error_bars = (
    alt.Chart(summary)
    .mark_errorbar(color="black")
    .encode(y="variant:N", x=alt.X("low:Q", title="Value ± 1 SEM"), x2="high:Q")
    .transform_calculate(low="datum.mean - datum.sem", high="datum.mean + datum.sem")
)

# Overlay the error bars on the bars
chart = (bars + error_bars).properties(title="Task Success Rate")

chart

Episode Return

python
# Episode return: per-variant mean (and SEM) of the summed rewards.
summary = df.groupby("variant")["episode_return"].agg(["mean", "sem"]).reset_index()

# Bars show the mean. The x scale keeps its default zero baseline (no
# `zero=False`) so bar lengths stay proportional to the plotted values.
bars = (
    alt.Chart(summary)
    .encode(
        y=alt.Y("variant:N", title="Variant"),
        x=alt.X("mean:Q", title="Value ± 1 SEM"),
    )
    .mark_bar(color="#1f77b4")
)

# Error bars span mean ± 1 standard error of the mean for each variant.
error_bars = (
    alt.Chart(summary)
    .mark_errorbar(color="black")
    .encode(y="variant:N", x=alt.X("low:Q", title="Value ± 1 SEM"), x2="high:Q")
    .transform_calculate(low="datum.mean - datum.sem", high="datum.mean + datum.sem")
)

# Overlay the error bars on the bars
chart = (bars + error_bars).properties(title="Episode Return")

chart

Episode Length

python
# Episode length: per-variant mean (and SEM) of the number of steps taken.
summary = df.groupby("variant")["num_steps"].agg(["mean", "sem"]).reset_index()

# Bars show the mean. The x scale keeps its default zero baseline (no
# `zero=False`) so bar lengths stay proportional to the plotted values.
bars = (
    alt.Chart(summary)
    .encode(
        y=alt.Y("variant:N", title="Variant"),
        x=alt.X("mean:Q", title="Value ± 1 SEM"),
    )
    .mark_bar(color="#1f77b4")
)

# Error bars span mean ± 1 standard error of the mean for each variant.
error_bars = (
    alt.Chart(summary)
    .mark_errorbar(color="black")
    .encode(y="variant:N", x=alt.X("low:Q", title="Value ± 1 SEM"), x2="high:Q")
    .transform_calculate(low="datum.mean - datum.sem", high="datum.mean + datum.sem")
)

# Overlay the error bars on the bars
chart = (bars + error_bars).properties(title="Episode Length")

chart

Episode Generated Token Count

python
# Generated tokens: per-variant mean (and SEM) of output tokens per episode.
summary = df.groupby("variant")["output_tokens"].agg(["mean", "sem"]).reset_index()

# Bars show the mean. The x scale keeps its default zero baseline (no
# `zero=False`) so bar lengths stay proportional to the plotted values.
bars = (
    alt.Chart(summary)
    .encode(
        y=alt.Y("variant:N", title="Variant"),
        x=alt.X("mean:Q", title="Value ± 1 SEM"),
    )
    .mark_bar(color="#1f77b4")
)

# Error bars span mean ± 1 standard error of the mean for each variant.
error_bars = (
    alt.Chart(summary)
    .mark_errorbar(color="black")
    .encode(y="variant:N", x=alt.X("low:Q", title="Value ± 1 SEM"), x2="high:Q")
    .transform_calculate(low="datum.mean - datum.sem", high="datum.mean + datum.sem")
)

# Overlay the error bars on the bars
chart = (bars + error_bars).properties(title="Episode Generated Token Count")

chart

Episode Input Token Count

python
# Input tokens: per-variant mean (and SEM) of prompt tokens per episode.
summary = df.groupby("variant")["input_tokens"].agg(["mean", "sem"]).reset_index()

# Bars show the mean. The x scale keeps its default zero baseline (no
# `zero=False`) so bar lengths stay proportional to the plotted values.
bars = (
    alt.Chart(summary)
    .encode(
        y=alt.Y("variant:N", title="Variant"),
        x=alt.X("mean:Q", title="Value ± 1 SEM"),
    )
    .mark_bar(color="#1f77b4")
)

# Error bars span mean ± 1 standard error of the mean for each variant.
error_bars = (
    alt.Chart(summary)
    .mark_errorbar(color="black")
    .encode(y="variant:N", x=alt.X("low:Q", title="Value ± 1 SEM"), x2="high:Q")
    .transform_calculate(low="datum.mean - datum.sem", high="datum.mean + datum.sem")
)

# Overlay the error bars on the bars
chart = (bars + error_bars).properties(title="Episode Input Token Count")

chart

Improving Performance with Supervised Fine-tuning (SFT)

The results above show that the history_and_reasoning variant yields the best success rate. Here we describe how to improve the performance of the history_and_reasoning variant by fine-tuning it on a separate set of random episodes.

First we run a large set of episodes for each task using the history_and_reasoning variant to generate data for fine-tuning.

python
num_episodes_ft = 200
seed_ft = 0

for task_name in task_names:
    async with await AsyncTensorZeroGateway.build_http(gateway_url="http://localhost:3000", timeout=180.0) as client:
        results_task = await run_episodes(
            client=client,
            variant_name="history_and_reasoning",
            env_name="babyai",
            task_name=task_name,
            num_episodes=num_episodes_ft,
            config=config,
            semaphore=semaphore,
            disable_progress_bar=False,
            history_length=history_length,
            seed=seed_ft,
            test=False,
        )

We provide two options for fine-tuning a model: using a notebook or using the TensorZero UI. You can fine-tune on episodes that successfully completed the task, or on episodes that achieved a sufficiently high return (e.g. 0.7).

See the README.md file for more details.

Evaluating the Fine-tuned Variant

After fine-tuning, create a history_and_reasoning_sft variant and run the following code to evaluate it.

python
results_history_and_reasoning_ft = []
for task_name in task_names:
    async with await AsyncTensorZeroGateway.build_http(gateway_url="http://localhost:3000", timeout=180.0) as client:
        results_task = await run_episodes(
            client=client,
            variant_name="history_and_reasoning_sft",
            env_name="babyai",
            task_name=task_name,
            num_episodes=num_episodes,
            config=config,
            semaphore=semaphore,
            disable_progress_bar=False,
            history_length=history_length,
            seed=seed,
            test=True,
        )
        results_history_and_reasoning_ft.extend(results_task)

Results

We combine the results of the fine-tuned model with the results of all the variants evaluated above, including the history_and_reasoning variant it was trained from.

python
df_ft = pd.DataFrame(results_history_and_reasoning_ft)

# Append the fine-tuned variant's episodes to the earlier results. Both frames
# were built from lists and start their index at 0, so renumber the rows to
# avoid duplicate index labels in the combined frame.
df = pd.concat([df, df_ft], ignore_index=True)

We see below that the fine-tuned model performs better than the history_and_reasoning variant!

Success Rate

python
# Success rate: per-variant mean (and SEM) of the `progression` stat.
summary = df.groupby("variant")["progression"].agg(["mean", "sem"]).reset_index()

# Bars show the mean. The x scale keeps its default zero baseline (no
# `zero=False`) so bar lengths stay proportional to the plotted values.
bars = (
    alt.Chart(summary)
    .encode(
        y=alt.Y("variant:N", title="Variant"),
        x=alt.X("mean:Q", title="Value ± 1 SEM"),
    )
    .mark_bar(color="#1f77b4")
)

# Error bars span mean ± 1 standard error of the mean for each variant.
error_bars = (
    alt.Chart(summary)
    .mark_errorbar(color="black")
    .encode(y="variant:N", x=alt.X("low:Q", title="Value ± 1 SEM"), x2="high:Q")
    .transform_calculate(low="datum.mean - datum.sem", high="datum.mean + datum.sem")
)

# Overlay the error bars on the bars
chart = (bars + error_bars).properties(title="Task Success Rate")

chart

Episode Return

python
# Episode return: per-variant mean (and SEM) of the summed rewards.
summary = df.groupby("variant")["episode_return"].agg(["mean", "sem"]).reset_index()

# Bars show the mean. The x scale keeps its default zero baseline (no
# `zero=False`) so bar lengths stay proportional to the plotted values.
bars = (
    alt.Chart(summary)
    .encode(
        y=alt.Y("variant:N", title="Variant"),
        x=alt.X("mean:Q", title="Value ± 1 SEM"),
    )
    .mark_bar(color="#1f77b4")
)

# Error bars span mean ± 1 standard error of the mean for each variant.
error_bars = (
    alt.Chart(summary)
    .mark_errorbar(color="black")
    .encode(y="variant:N", x=alt.X("low:Q", title="Value ± 1 SEM"), x2="high:Q")
    .transform_calculate(low="datum.mean - datum.sem", high="datum.mean + datum.sem")
)

# Overlay the error bars on the bars
chart = (bars + error_bars).properties(title="Episode Return")

chart

Episode Length

python
# Episode length: per-variant mean (and SEM) of the number of steps taken.
summary = df.groupby("variant")["num_steps"].agg(["mean", "sem"]).reset_index()

# Bars show the mean. The x scale keeps its default zero baseline (no
# `zero=False`) so bar lengths stay proportional to the plotted values.
bars = (
    alt.Chart(summary)
    .encode(
        y=alt.Y("variant:N", title="Variant"),
        x=alt.X("mean:Q", title="Value ± 1 SEM"),
    )
    .mark_bar(color="#1f77b4")
)

# Error bars span mean ± 1 standard error of the mean for each variant.
error_bars = (
    alt.Chart(summary)
    .mark_errorbar(color="black")
    .encode(y="variant:N", x=alt.X("low:Q", title="Value ± 1 SEM"), x2="high:Q")
    .transform_calculate(low="datum.mean - datum.sem", high="datum.mean + datum.sem")
)

# Overlay the error bars on the bars
chart = (bars + error_bars).properties(title="Episode Length")

chart

Episode Generated Token Count

python
# Generated tokens: per-variant mean (and SEM) of output tokens per episode.
summary = df.groupby("variant")["output_tokens"].agg(["mean", "sem"]).reset_index()

# Bars show the mean. The x scale keeps its default zero baseline (no
# `zero=False`) so bar lengths stay proportional to the plotted values.
bars = (
    alt.Chart(summary)
    .encode(
        y=alt.Y("variant:N", title="Variant"),
        x=alt.X("mean:Q", title="Value ± 1 SEM"),
    )
    .mark_bar(color="#1f77b4")
)

# Error bars span mean ± 1 standard error of the mean for each variant.
error_bars = (
    alt.Chart(summary)
    .mark_errorbar(color="black")
    .encode(y="variant:N", x=alt.X("low:Q", title="Value ± 1 SEM"), x2="high:Q")
    .transform_calculate(low="datum.mean - datum.sem", high="datum.mean + datum.sem")
)

# Overlay the error bars on the bars
chart = (bars + error_bars).properties(title="Episode Generated Token Count")

chart

Episode Input Token Count

python
# Input tokens: per-variant mean (and SEM) of prompt tokens per episode.
summary = df.groupby("variant")["input_tokens"].agg(["mean", "sem"]).reset_index()

# Bars show the mean. The x scale keeps its default zero baseline (no
# `zero=False`) so bar lengths stay proportional to the plotted values.
bars = (
    alt.Chart(summary)
    .encode(
        y=alt.Y("variant:N", title="Variant"),
        x=alt.X("mean:Q", title="Value ± 1 SEM"),
    )
    .mark_bar(color="#1f77b4")
)

# Error bars span mean ± 1 standard error of the mean for each variant.
error_bars = (
    alt.Chart(summary)
    .mark_errorbar(color="black")
    .encode(y="variant:N", x=alt.X("low:Q", title="Value ± 1 SEM"), x2="high:Q")
    .transform_calculate(low="datum.mean - datum.sem", high="datum.mean + datum.sem")
)

# Overlay the error bars on the bars
chart = (bars + error_bars).properties(title="Episode Input Token Count")

chart