Agent Training with GRPO using TRL

examples/notebooks/grpo_agent.ipynb

With Transformers Reinforcement Learning (TRL), you can train a language model to act as an agent: one that learns to reason, interact with external tools, and improve through reinforcement.

TRL supports training agents that can use external tools as part of their decision process.
In this notebook, the agent has access to the BioGRID database, which it can query using read-only SQL commands to retrieve biological interaction data. The model learns when and how to use tools based on rewards.

We'll fine-tune a model using GRPO (Group Relative Policy Optimization) via TRL. The agent will:

  1. Generate a tool call to query the database when needed.
  2. Receive the tool response and add it to the context.
  3. Learn to improve its tool usage and general capabilities over time through reward signals.
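The loop behind steps 1 and 2 can be sketched end to end with a stubbed model. All names below (`fake_model`, `lookup`, the message shapes) are illustrative, not TRL's internal API:

```python
# Schematic of one agent rollout, with a stub standing in for real generation.
def lookup(key: str) -> str:
    """Toy tool standing in for the database query."""
    return f"found interaction for {key}"

def fake_model(messages):
    """Stub model: first requests a tool call, then emits a final answer."""
    if not any(m["role"] == "tool" for m in messages):
        return {
            "role": "assistant",
            "content": None,
            "tool_calls": [{"function": {"name": "lookup", "arguments": {"key": "JNKK"}}}],
        }
    return {"role": "assistant", "content": "*Yes*"}

tools = {"lookup": lookup}
messages = [{"role": "user", "content": "Does the gene JNKK interact with MAPKK4?"}]

while True:
    reply = fake_model(messages)
    messages.append(reply)
    if not reply.get("tool_calls"):
        break  # no tool call requested: this is the final answer
    for call in reply["tool_calls"]:
        fn = tools[call["function"]["name"]]
        result = fn(**call["function"]["arguments"])
        messages.append({"role": "tool", "content": result})

print(messages[-1]["content"])  # *Yes*
```

During training, GRPO runs many such rollouts per prompt and scores each full conversation with the reward functions defined later.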

Install dependencies

We'll start by installing TRL, which automatically includes the main dependencies like Transformers.
We'll also install trackio (for logging and monitoring training runs), vLLM (for efficient generation), and jmespath (required for the tool-calling support).

python
!pip install -Uq "trl[vllm]" git+https://github.com/huggingface/transformers.git trackio jmespath 

Log in to Hugging Face

Log in to your Hugging Face account to save your fine-tuned model, track your experiment results directly on the Hub, or access gated models. You can find your access token on your account settings page.

python
from huggingface_hub import notebook_login

notebook_login()

Create the database for the tool

For this example, we will use the BioGRID database, a curated resource containing protein, genetic, and chemical interaction data. We've already compiled and uploaded it to the Hub at qgallouedec/biogrid. Below, the dataset is loaded and converted into an SQLite database.

💡 We remove spaces from the column names to make the model's job easier. In real-world deployments, you may keep your original column names and rely on the agent to reason about them. Here, we simplify the schema to make training smoother.

python
import sqlite3
from datasets import load_dataset

# Load dataset
biogrid_dataset = load_dataset("qgallouedec/biogrid", split="train")
df = biogrid_dataset.to_pandas()

# Normalize column names: remove spaces, replace with underscores
df.columns = [c.replace(" ", "_") for c in df.columns]

# Save to SQLite
conn = sqlite3.connect("biogrid.db")
try:
    df.to_sql("interactions", conn, if_exists="replace", index=False)
    print(f"biogrid.db created. Rows stored: {len(df)}")
finally:
    conn.close()

Load the QA dataset

The training objective is to fine-tune a model to answer gene-related questions. The model should learn to use the database query tool to retrieve factual information when needed.

We'll define a formatting function for each sample, adding instructions about the database and how to call it. The model must answer with yes or no. Let's implement the format_example function.

python
import textwrap

def format_example(example):
    question = example["question"]
    preamble = textwrap.dedent("""\
    You have access to the BioGRID SQLite database.
    Use SQL queries to retrieve only the information needed to answer the question.

    Genes may appear in the database in columns `Alt_IDs_Interactor_A`, `Alt_IDs_Interactor_B`, `Aliases_Interactor_A` and `Aliases_Interactor_B`,
    and each entry can contain multiple gene names or synonyms separated by '|', for example:
    'entrez gene/locuslink:JNKK(gene name synonym)|entrez gene/locuslink:MAPKK4(gene name synonym)|...'
    So a gene like 'JNKK' or 'MAPKK4' may appear inside one of these strings.

    If the database schema is unclear or you are unsure about column names:
    - First inspect the schema with `PRAGMA table_info(interactions);`
    - Or preview a few rows with `SELECT * FROM interactions LIMIT 1;`

    Otherwise, directly query the required data.

    Final answer must be enclosed in stars, e.g. *Yes* or *No*.
    Facts:
    - The NCBI Taxonomy identifier for humans is taxid:9606.
    """)
    content = f"{preamble}\nQuestion: {question}"
    prompt = [{"role": "user", "content": content}]
    return {"prompt": prompt}

Now, let's load the QA dataset and apply the formatting function defined above.
For simplicity, we will only use questions that start with “Does the gene…”.
In a real use case, the full dataset can be used.

The QA dataset is available on the Hub.

python
dataset = load_dataset("qgallouedec/biogrid_qa", split="train")
dataset = dataset.filter(
    lambda example: example["question"].startswith("Does the gene ")
)  # keep only simple questions for example
dataset = dataset.map(format_example, remove_columns=["question"])

train_dataset = dataset
eval_dataset = None  # No eval by default, can be added if needed

Create tool for the agent

The query_biogrid function is the tool the model will use to query the database and retrieve factual information.
Each tool must be a standard Python function with type-hinted arguments and return types, and a Google-style docstring describing its purpose, parameters, and return value.

python
from contextlib import contextmanager
import signal

@contextmanager
def timeout(seconds):
    """Context manager that raises TimeoutError if execution exceeds time limit."""

    def timeout_handler(signum, frame):
        raise TimeoutError(f"Operation timed out after {seconds} seconds")

    signal.signal(signal.SIGALRM, timeout_handler)
    signal.alarm(seconds)
    try:
        yield
    finally:
        signal.alarm(0)

def query_biogrid(sql_command: str) -> list[tuple]:
    """
    Execute a read-only SQL command on the BioGRID database.

    BioGRID is a curated biological database that compiles protein, genetic, and chemical interactions from multiple organisms. It provides researchers with experimentally verified interaction data to support studies in systems biology and functional genomics.

    Args:
        sql_command: The SQL command to execute.

    Returns:
        A list of tuples containing the query results.
    """
    with timeout(5):
        conn = sqlite3.connect("file:biogrid.db?mode=ro", uri=True)
        cursor = conn.cursor()
        try:
            cursor.execute(sql_command)
            results = cursor.fetchall()
        finally:
            conn.close()
    return results
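The `mode=ro` URI flag is what makes the tool read-only at the connection level. A self-contained demonstration with a throwaway database (the file name and two-column schema below are made up for the demo):

```python
import os
import sqlite3
import tempfile

# Build a throwaway database mimicking the pattern above (names are demo-only).
path = os.path.join(tempfile.mkdtemp(), "demo.db")
conn = sqlite3.connect(path)
conn.execute("CREATE TABLE interactions (Gene_A TEXT, Gene_B TEXT)")
conn.execute("INSERT INTO interactions VALUES ('JNKK', 'MAPKK4')")
conn.commit()
conn.close()

# Read-only connections can SELECT, but any write raises an OperationalError.
ro = sqlite3.connect(f"file:{path}?mode=ro", uri=True)
rows = ro.execute("SELECT * FROM interactions WHERE Gene_A = 'JNKK'").fetchall()
try:
    ro.execute("DELETE FROM interactions")
except sqlite3.OperationalError as e:
    print("write rejected:", e)
ro.close()
print(rows)  # [('JNKK', 'MAPKK4')]
```

This means even if the model generates a destructive SQL command during training, the database cannot be modified.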

Define reward functions

To guide the agent during training, we define a few simple reward functions:

  • query_reward: evaluates the model’s query strategy; it penalizes more than three queries and generic LIMIT 1 scans, and rewards WHERE clauses and query results that support the final answer.
  • correctness_reward: rewards Yes/No predictions that match the expected answer.
  • structure_reward: rewards a proper assistant structure (tool call → response → optional explanation).

Each function returns a list of floats used by the GRPOTrainer during optimization.
Combined, they encourage effective tool use and factual answers.

python
import re

def query_reward(completions, answer, **kwargs):
    """
    Reward query strategy:
    - Penalize more than 3 queries
    - Penalize generic queries (LIMIT 1)
    - Reward usage of WHERE
    - Reward evidence supporting the final answer
    """
    rewards = []

    for completion, ans in zip(completions, answer, strict=False):
        reward = 0.0
        sql_queries = []
        tool_results = []

        # collect all SQL queries and tool results
        for turn in completion:
            if turn.get("tool_calls"):
                for call in turn["tool_calls"]:
                    sql = call["function"]["arguments"].get("sql_command", "").lower()
                    sql_queries.append(sql)
            if turn.get("role") == "tool" and turn.get("content"):
                tool_results.append(turn["content"])

        # --- penalize too many queries ---
        if len(sql_queries) > 3:
            reward -= 1.5

        # --- check query quality ---
        where_count = 0
        for q in sql_queries:
            if "limit 1" in q:
                reward -= 1.0
            if " where " not in q:
                reward -= 0.5
            else:
                where_count += 1
        reward += min(where_count, 3) * 0.4  # small bonus for WHERE usage

        # --- evidence check: do queries support the answer? ---
        combined_results = []
        error_detected = False

        for res in tool_results:
            if isinstance(res, dict) and "error" in res:
                error_detected = True
            elif isinstance(res, list):
                combined_results.extend(res)

        # if error detected, penalize heavily
        if error_detected:
            reward -= 2.0
        elif len(sql_queries) == 0:
            reward -= 1.5
        else:
            has_hits = len(combined_results) > 0
            correct_answer = ans.lower()
            if (has_hits and correct_answer == "yes") or (not has_hits and correct_answer == "no"):
                reward += 2.0
            else:
                reward -= 1.5

        rewards.append(reward)

    return rewards
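The evidence check at the end of `query_reward` is its core heuristic: non-empty query results should line up with a ground-truth "yes", empty results with "no". Isolated for clarity (the helper name and toy inputs below are ours, not part of the trainer):

```python
# The evidence heuristic from query_reward, extracted for illustration.
def evidence_bonus(combined_results: list, correct_answer: str) -> float:
    has_hits = len(combined_results) > 0
    if (has_hits and correct_answer == "yes") or (not has_hits and correct_answer == "no"):
        return 2.0
    return -1.5

print(evidence_bonus([("JNKK", "MAPKK4")], "yes"))  # 2.0
print(evidence_bonus([], "no"))                     # 2.0
print(evidence_bonus([], "yes"))                    # -1.5
```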


def correctness_reward(completions, answer, **kwargs):
    """
    Reward Yes/No correctness.
    Model must provide final answer enclosed in stars — *yes* or *no*.
    Does not reward informal yes/no buried in text.
    """
    rewards = []
    for completion, ans in zip(completions, answer, strict=False):
        raw = completion[-1]["content"].lower()

        # detect form *yes* or *no*
        match = re.search(r"\*(yes|no)\*", raw)
        guess = match.group(1) if match else None

        reward = 0.0

        if guess is None:
            reward -= 0.5  # invalid format
        elif guess == ans.lower():
            reward += 0.6  # correct under required format
        else:
            reward -= 1.0  # wrong answer

        rewards.append(reward)

    return rewards
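The star-delimited format makes the final answer unambiguous to extract. The detection step on its own (the helper name is ours, the regex is the one used above):

```python
import re

# Only a star-delimited answer counts; an informal "yes" in prose does not.
def extract_answer(text: str):
    match = re.search(r"\*(yes|no)\*", text.lower())
    return match.group(1) if match else None

print(extract_answer("Based on the query results, the answer is *Yes*."))  # yes
print(extract_answer("Yes, they interact."))                               # None
```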


def structure_reward(completions, **kwargs):
    """
    Reward proper assistant structure.
    Encourages a logical sequence: tool call + response + optional extra content.
    """
    rewards = []

    for completion in completions:
        has_call = False
        has_response = False
        has_other = False

        for turn in completion:
            role = turn.get("role")
            if role == "assistant" and turn.get("tool_calls"):
                has_call = True
            elif role == "tool":
                has_response = True
            else:
                content = turn.get("content")
                if content and content.strip() not in ["", "<think>"]:
                    has_other = True

        # Reward sequences
        if has_call and has_response:
            if has_other:
                reward = 0.1
            else:
                reward = 0.05  # still positive even without extra text
        elif has_call and not has_response:
            reward = -0.15
        else:
            reward = 0.0  # neutral if no call

        rewards.append(reward)

    return rewards

Set GRPO Config

Next, we define the GRPOConfig, which controls the main training parameters.
This configuration specifies how the model interacts with vLLM, manages memory, and logs results.

python
from trl import GRPOConfig

output_dir = "grpo_biogrid_qwen_3g-1.7b"

grpo_config = GRPOConfig(
    # Training schedule / optimization
    max_steps=400,                                              # Max number of training steps
    chat_template_kwargs = {"enable_thinking": False},          # Disable thinking to reduce token generation

    # GRPO configuration
    max_completion_length = 1024,                               # Maximum tokens generated per model response

    # vLLM configuration
    use_vllm = True,                                            # Enable vLLM for faster inference during rollouts
    vllm_mode = "colocate",                                     # Run vLLM in colocate mode (same process as training)
    vllm_enable_sleep_mode=False,

    # Logging / reporting
    output_dir = output_dir,                                    # Directory for checkpoints and logs
    report_to="trackio",                                        # Experiment tracking tool (integrates with HF Spaces)
    trackio_space_id = output_dir,                              # HF Space where experiment tracking will be saved
    save_steps = 10,                                            # Interval for saving checkpoints
    log_completions = True,

    # Hub integration
    push_to_hub = True,                                         # Set True to automatically push model to Hugging Face Hub
)

Create GRPOTrainer and Start Training

Next, we initialize the GRPOTrainer, which handles the full reinforcement learning loop.

It receives the model name, reward functions, tool(s), and dataset defined earlier.

Finally, we call trainer.train() to begin fine-tuning, allowing the model to learn how to query the database effectively through iterative feedback.

python
from trl import GRPOTrainer

model_name = "Qwen/Qwen3-1.7B"

trainer = GRPOTrainer(
    model=model_name,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tools=[query_biogrid],
    reward_funcs=[correctness_reward, structure_reward, query_reward],
    args=grpo_config,
)

Show memory stats before training

python
import torch
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)

print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

And train!

python
trainer_stats = trainer.train()

Show memory stats after training

python
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
used_memory_for_training = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory / max_memory * 100, 3)
training_percentage = round(used_memory_for_training / max_memory * 100, 3)

print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
print(f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.")
print(f"Peak reserved memory = {used_memory} GB.")
print(f"Peak reserved memory for training = {used_memory_for_training} GB.")
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
print(f"Peak reserved memory for training % of max memory = {training_percentage} %.")

Let's save the trained model.

python
trainer.save_model(output_dir)
trainer.push_to_hub()

Load the fine-tuned model and run inference using smolagents

After fine-tuning the model with GRPO (TRL) for tool calling, we can test it at inference time using smolagents, a lightweight library for running multi-step agents.

smolagents handles the agent loop for us:

  • Detecting tool calls generated by the model
  • Executing the corresponding tools (e.g. database queries)
  • Feeding the results back to the model until a final answer is produced

Note
Using an agent framework is optional. The fine-tuned model can also be used directly with transformers by manually controlling the inference loop and executing the tools outside the model. Agent frameworks are especially useful when the number of steps or tool calls is not fixed.
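As a taste of the manual route, one key step is parsing tool calls out of the raw generation. Assuming a Qwen-style chat template that wraps each call in `<tool_call>` tags containing a JSON payload (the raw string below is fabricated for illustration):

```python
import json
import re

# Fabricated raw model output in a Qwen-style tool-call format (assumption:
# the chat template emits <tool_call> tags containing a JSON object).
raw = (
    "I need to check the database first.\n"
    '<tool_call>\n{"name": "query_biogrid", "arguments": '
    '{"sql_command": "SELECT * FROM interactions LIMIT 1;"}}\n</tool_call>'
)

match = re.search(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", raw, re.DOTALL)
call = json.loads(match.group(1))
print(call["name"])                      # query_biogrid
print(call["arguments"]["sql_command"])  # SELECT * FROM interactions LIMIT 1;
```

You would then dispatch the parsed call to the matching Python function, append the result as a `tool` message, and generate again; this is exactly what smolagents automates.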

We start by installing the required package:

python
!pip install git+https://github.com/huggingface/smolagents.git

We will use the CodeAgent class from smolagents to instantiate our agent.
First, we need to define the tool the agent can use. This is done using the @tool decorator.

As shown below, the tool definition is exactly the same as the one used during GRPO training with TRL. This consistency is important: the model was trained to emit calls following this schema, and at inference time the agent simply executes the corresponding Python function.

python
from smolagents import tool

@tool
def query_biogrid(sql_command: str) -> list[tuple]:
    """
    Execute a read-only SQL query on the BioGRID database.

    BioGRID is a curated biological database that compiles protein, genetic,
    and chemical interactions from multiple organisms.

    Args:
        sql_command: A read-only SQL query to execute.

    Returns:
        A list of tuples containing the query results.
    """
    with timeout(5):
        conn = sqlite3.connect(
            "file:biogrid.db?mode=ro",
            uri=True,
        )
        cursor = conn.cursor()
        try:
            cursor.execute(sql_command)
            results = cursor.fetchall()
        finally:
            conn.close()

    return results

Now we can instantiate the agent using our fine-tuned model and the database tool defined above. We wrap the model with TransformersModel and pass both the model and the tool when creating the CodeAgent.

python
from smolagents import TransformersModel, CodeAgent

model = TransformersModel(model_id="sergiopaniego/grpo_biogrid_qwen_3g-1.7b", apply_chat_template_kwargs={"enable_thinking": False})

# Create an agent with query_biogrid as tool
agent = CodeAgent(tools=[query_biogrid], model=model)

Finally, we run the agent by passing the full prompt (including the instruction preamble and the question), exactly as it was used during training. This ensures the agent operates under the same context and assumptions learned with GRPO, allowing it to correctly decide when to query the database and how to format the final answer.

python
result = agent.run(train_dataset[0]['prompt'][0]['content'])
print(result)