examples/notebooks/grpo_agent.ipynb
With Transformer Reinforcement Learning (TRL), you can train a language model to act as an agent: one that learns to reason, interact with external tools, and improve through reinforcement.
TRL supports training agents that can use external tools as part of their decision process.
In this notebook, the agent has access to the BioGRID database, which it can query using read-only SQL commands to retrieve biological interaction data. The model learns when and how to use tools based on rewards.
We'll fine-tune a model using GRPO (Group Relative Policy Optimization) via TRL. The agent will learn when to query the BioGRID database, how to retrieve the facts it needs with read-only SQL, and how to produce a final yes/no answer.
We'll start by installing TRL, which automatically includes the main dependencies like Transformers.
We'll also install trackio (for logging and monitoring training runs), vLLM (for efficient generation), and jmespath (required for tool support).
!pip install -Uq "trl[vllm]" git+https://github.com/huggingface/transformers.git trackio jmespath
Log in to your Hugging Face account to save your fine-tuned model, track your experiment results directly on the Hub, or access gated models. You can find your access token on your account settings page.
from huggingface_hub import notebook_login
notebook_login()
For this example, we will use the BioGRID database, a curated resource containing protein, genetic, and chemical interaction data. We've already compiled and uploaded it to the Hub at qgallouedec/biogrid. Below, the dataset is loaded and converted into a SQLite database.
💡 We remove spaces from the column names to make the model's job easier. In real-world deployments, you may keep your original column names and rely on the agent to reason about them. Here, we simplify the schema to make training smoother.
import sqlite3
from datasets import load_dataset
# Load dataset
biogrid_dataset = load_dataset("qgallouedec/biogrid", split="train")
df = biogrid_dataset.to_pandas()
# Normalize column names: remove spaces, replace with underscores
df.columns = [c.replace(" ", "_") for c in df.columns]
# Save to SQLite
conn = sqlite3.connect("biogrid.db")
try:
df.to_sql("interactions", conn, if_exists="replace", index=False)
print(f"biogrid.db created. Rows stored: {len(df)}")
finally:
conn.close()
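As an optional sanity check, we can inspect the resulting schema and preview one row, using the same read-only queries the prompt later suggests to the model:
conn = sqlite3.connect("file:biogrid.db?mode=ro", uri=True)
try:
    cursor = conn.cursor()
    # Column names as the model will see them
    cursor.execute("PRAGMA table_info(interactions);")
    print([row[1] for row in cursor.fetchall()])
    # One sample row to eyeball the data format
    cursor.execute("SELECT * FROM interactions LIMIT 1;")
    print(cursor.fetchone())
finally:
    conn.close()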
The training objective is to fine-tune a model to answer gene-related questions. The model should learn to use the database query tool to retrieve factual information when needed.
We'll define a formatting function that, for each sample, adds instructions about the database and how to query it. The model must answer with yes or no. Let's implement the format_example function.
import textwrap
def format_example(example):
question = example["question"]
preamble = textwrap.dedent("""\
You have access to the BioGRID SQLite database.
Use SQL queries to retrieve only the information needed to answer the question.
Genes may appear in the database in columns `Alt_IDs_Interactor_A`, `Alt_IDs_Interactor_B`, `Aliases_Interactor_A` and `Aliases_Interactor_B`,
and each entry can contain multiple gene names or synonyms separated by '|', for example:
'entrez gene/locuslink:JNKK(gene name synonym)|entrez gene/locuslink:MAPKK4(gene name synonym)|...'
So a gene like 'JNKK' or 'MAPKK4' may appear inside one of these strings.
If the database schema is unclear or you are unsure about column names:
- First inspect the schema with `PRAGMA table_info(interactions);`
- Or preview a few rows with `SELECT * FROM interactions LIMIT 1;`
Otherwise, directly query the required data.
Final answer must be enclosed in stars, e.g. *Yes* or *No*.
Facts:
- The NCBI Taxonomy identifier for humans is taxid:9606.
""")
content = f"{preamble}\nQuestion: {question}"
prompt = [{"role": "user", "content": content}]
return {"prompt": prompt}
Now, let's load the QA dataset and apply the formatting function defined above.
For simplicity, we will only use questions that start with “Does the gene…”; in a real use case, the full dataset can be used.
The QA dataset is available on the Hub at qgallouedec/biogrid_qa.
dataset = load_dataset("qgallouedec/biogrid_qa", split="train")
dataset = dataset.filter(
lambda example: example["question"].startswith("Does the gene ")
) # keep only simple questions for this example
dataset = dataset.map(format_example, remove_columns=["question"])
train_dataset = dataset
eval_dataset = None # No eval by default, can be added if needed
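To verify the formatting, we can print the first prompt; this inspection step is purely illustrative:
print(train_dataset[0]["prompt"][0]["content"])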
The query_biogrid function is the tool the model will use to query the database and retrieve factual information.
Each tool must be a standard Python function with type-hinted arguments and return types, and a Google-style docstring describing its purpose, parameters, and return value.
from contextlib import contextmanager
import signal
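# Note: signal.alarm is Unix-only and must be used from the main thread;
# substitute another timeout mechanism (e.g. a thread-based one) on other platforms.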
@contextmanager
def timeout(seconds):
"""Context manager that raises TimeoutError if execution exceeds time limit."""
def timeout_handler(signum, frame):
raise TimeoutError(f"Operation timed out after {seconds} seconds")
signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(seconds)
try:
yield
finally:
signal.alarm(0)
def query_biogrid(sql_command: str) -> list[tuple]:
"""
Execute a read-only SQL command on the BioGRID database.
BioGRID is a curated biological database that compiles protein, genetic, and chemical interactions from multiple organisms. It provides researchers with experimentally verified interaction data to support studies in systems biology and functional genomics.
Args:
sql_command: The SQL command to execute.
Returns:
A list of tuples containing the query results.
"""
with timeout(5):
conn = sqlite3.connect("file:biogrid.db?mode=ro", uri=True)
cursor = conn.cursor()
try:
cursor.execute(sql_command)
results = cursor.fetchall()
finally:
conn.close()
return results
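Before training, we can sanity-check the tool by calling it directly; the query below is illustrative:
print(query_biogrid("SELECT COUNT(*) FROM interactions;"))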
To guide the agent during training, we define a few simple reward functions:
- query_reward: evaluates the query strategy. It penalizes more than three queries and generic preview scans, and rewards WHERE clauses and evidence supporting the final answer.
- correctness_reward: rewards Yes/No predictions that match the expected answer.
- structure_reward: rewards a proper assistant structure (tool call → response → optional explanation).
Each function returns a list of floats used by the GRPOTrainer during optimization.
Combined, they encourage effective tool use and factual answers.
import re
def query_reward(completions, answer, **kwargs):
"""
Reward query strategy:
- Penalize more than 3 queries
- Penalize generic preview queries (LIMIT 1)
- Reward usage of WHERE
- Reward evidence supporting the final answer
"""
rewards = []
for completion, ans in zip(completions, answer, strict=False):
reward = 0.0
sql_queries = []
tool_results = []
# collect all SQL queries and tool results
for turn in completion:
if turn.get("tool_calls"):
for call in turn["tool_calls"]:
sql = call["function"]["arguments"].get("sql_command", "").lower()
sql_queries.append(sql)
if turn.get("role") == "tool" and turn.get("content"):
tool_results.append(turn["content"])
# --- penalize too many queries ---
if len(sql_queries) > 3:
reward -= 1.5
# --- check query quality ---
where_count = 0
for q in sql_queries:
if "limit 1" in q:
reward -= 1.0
if " where " not in q:
reward -= 0.5
else:
where_count += 1
reward += min(where_count, 3) * 0.4 # small bonus for WHERE usage
# --- evidence check: do queries support the answer? ---
combined_results = []
error_detected = False
for res in tool_results:
if isinstance(res, dict) and "error" in res:
error_detected = True
elif isinstance(res, list):
combined_results.extend(res)
# if error detected, penalize heavily
if error_detected:
reward -= 2.0
elif len(sql_queries) == 0:
reward -= 1.5
else:
has_hits = len(combined_results) > 0
correct_answer = ans.lower()
if (has_hits and correct_answer == "yes") or (not has_hits and correct_answer == "no"):
reward += 2.0
else:
reward -= 1.5
rewards.append(reward)
return rewards
def correctness_reward(completions, answer, **kwargs):
"""
Reward Yes/No correctness.
Model must provide final answer enclosed in stars — *yes* or *no*.
Does not reward informal yes/no buried in text.
"""
rewards = []
for completion, ans in zip(completions, answer, strict=False):
raw = completion[-1]["content"].lower()
# detect form *yes* or *no*
match = re.search(r"\*(yes|no)\*", raw)
guess = match.group(1) if match else None
reward = 0.0
if guess is None:
reward -= 0.5 # invalid format
elif guess == ans.lower():
reward += 0.6 # correct under required format
else:
reward -= 1.0 # wrong answer
rewards.append(reward)
return rewards
def structure_reward(completions, **kwargs):
"""
Reward proper assistant structure.
Encourages a logical sequence: tool call + response + optional extra content.
"""
rewards = []
for completion in completions:
has_call = False
has_response = False
has_other = False
for turn in completion:
role = turn.get("role")
if role == "assistant" and turn.get("tool_calls"):
has_call = True
elif role == "tool":
has_response = True
else:
content = turn.get("content")
if content and content.strip() not in ["", "<think>"]:
has_other = True
# Reward sequences
if has_call and has_response:
if has_other:
reward = 0.1
else:
reward = 0.05 # still positive even without extra text
elif has_call and not has_response:
reward = -0.15
else:
reward = 0.0 # neutral if no call
rewards.append(reward)
return rewards
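To see how these rewards behave, we can evaluate them on a hand-crafted toy completion. The message structure below is an assumption that mirrors what the reward functions above inspect; the values in the comments follow from their definitions:
toy_completion = [
    {"role": "assistant", "content": "", "tool_calls": [{"function": {
        "name": "query_biogrid",
        "arguments": {"sql_command": "SELECT * FROM interactions WHERE Aliases_Interactor_A LIKE '%JNKK%'"},
    }}]},
    {"role": "tool", "content": [("some matching row",)]},
    {"role": "assistant", "content": "The gene is present in the database. *Yes*"},
]
print(query_reward([toy_completion], ["Yes"]))        # [2.4]: WHERE bonus + evidence consistent with the answer
print(correctness_reward([toy_completion], ["Yes"]))  # [0.6]: correct *Yes* in the required format
print(structure_reward([toy_completion]))             # [0.1]: tool call + tool response + explanation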
Next, we define the GRPOConfig, which controls the main training parameters.
This configuration specifies how the model interacts with vLLM, manages memory, and logs results.
from trl import GRPOConfig
output_dir = "grpo_biogrid_qwen_3g-1.7b"
grpo_config = GRPOConfig(
    # Training schedule / optimization
    max_steps=400,  # Max number of training steps
    chat_template_kwargs={"enable_thinking": False},  # Disable thinking to reduce token generation
    # GRPO configuration
    max_completion_length=1024,  # Maximum tokens generated per model response
    # vLLM configuration
    use_vllm=True,  # Enable vLLM for faster inference during rollouts
    vllm_mode="colocate",  # Run vLLM in colocate mode (same process as training)
    vllm_enable_sleep_mode=False,
    # Logging / reporting
    output_dir=output_dir,  # Directory for checkpoints and logs
    report_to="trackio",  # Experiment tracking tool (integrates with HF Spaces)
    trackio_space_id=output_dir,  # HF Space where experiment tracking will be saved
    save_steps=10,  # Interval for saving checkpoints
    log_completions=True,  # Log sampled completions during training
    # Hub integration
    push_to_hub=True,  # Automatically push the model to the Hugging Face Hub
)
GRPOTrainer and Start Training
Next, we initialize the GRPOTrainer, which handles the full reinforcement learning loop.
It receives the model name, reward functions, tool(s), and dataset defined earlier.
Finally, we call trainer.train() to begin fine-tuning, allowing the model to learn how to query the database effectively through iterative feedback.
from trl import GRPOTrainer
model_name = "Qwen/Qwen3-1.7B"
trainer = GRPOTrainer(
model=model_name,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tools=[query_biogrid],
reward_funcs=[correctness_reward, structure_reward, query_reward],
args=grpo_config,
)
Show memory stats before training
import torch
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")
And train!
trainer_stats = trainer.train()
Show memory stats after training
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
used_memory_for_training = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory / max_memory * 100, 3)
training_percentage = round(used_memory_for_training / max_memory * 100, 3)
print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
print(f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.")
print(f"Peak reserved memory = {used_memory} GB.")
print(f"Peak reserved memory for training = {used_memory_for_training} GB.")
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
print(f"Peak reserved memory for training % of max memory = {training_percentage} %.")
Let's save the trained model.
trainer.save_model(output_dir)
trainer.push_to_hub()
smolagents
After fine-tuning the model with GRPO (TRL) for tool calling, we can test it at inference time using smolagents, a lightweight library for running multi-step agents.
smolagents handles the agent loop for us: it prompts the model, executes any tool calls, feeds the results back to the model, and repeats until a final answer is produced.
Note
Using an agent framework is optional. The fine-tuned model can also be used directly with transformers by manually controlling the inference loop and executing the tools outside the model. Agent frameworks are especially useful when the number of steps or tool calls is not fixed.
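A minimal sketch of that manual loop, assuming the tokenizer's chat template supports tool schemas (generation settings are illustrative):
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(output_dir)
model = AutoModelForCausalLM.from_pretrained(output_dir, device_map="auto")

messages = train_dataset[0]["prompt"]
# Render the chat with the tool schema so the model can emit tool calls
inputs = tokenizer.apply_chat_template(
    messages,
    tools=[query_biogrid],
    add_generation_prompt=True,
    enable_thinking=False,
    return_tensors="pt",
).to(model.device)
outputs = model.generate(inputs, max_new_tokens=512)
print(tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True))
# If the response contains a tool call, parse it, run query_biogrid(...),
# append the result as a "tool" message, and generate again.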
We start by installing the required package:
!pip install git+https://github.com/huggingface/smolagents.git
We will use the CodeAgent class from smolagents to instantiate our agent.
First, we need to define the tool the agent can use. This is done using the @tool decorator.
As shown below, the tool definition mirrors the one used during GRPO training with TRL. This consistency is important: the model was trained to emit calls following this schema, and at inference time the agent simply executes the corresponding Python function.
from smolagents import tool
@tool
def query_biogrid(sql_command: str) -> list[tuple]:
"""
Execute a read-only SQL query on the BioGRID database.
BioGRID is a curated biological database that compiles protein, genetic,
and chemical interactions from multiple organisms.
Args:
sql_command: A read-only SQL query to execute.
Returns:
A list of tuples containing the query results.
"""
with timeout(5):
conn = sqlite3.connect(
"file:biogrid.db?mode=ro",
uri=True,
)
cursor = conn.cursor()
try:
cursor.execute(sql_command)
results = cursor.fetchall()
finally:
conn.close()
return results
Now we can instantiate the agent using our fine-tuned model and the database tool defined above.
We wrap the model with TransformersModel and pass both the model and the tool when creating the CodeAgent.
from smolagents import TransformersModel, CodeAgent
model = TransformersModel(
    model_id="sergiopaniego/grpo_biogrid_qwen_3g-1.7b",  # replace with your own Hub repo if you trained the model yourself
    apply_chat_template_kwargs={"enable_thinking": False},
)
# Create an agent with query_biogrid as tool
agent = CodeAgent(tools=[query_biogrid], model=model)
Finally, we run the agent by passing the full prompt (including the instruction preamble and the question), exactly as it was used during training. This ensures the agent operates under the same context and assumptions learned with GRPO, allowing it to correctly decide when to query the database and how to format the final answer.
result = agent.run(train_dataset[0]['prompt'][0]['content'])
print(result)