
Fine-Tune FunctionGemma using Hugging Face TRL and OpenEnv

examples/notebooks/grpo_functiongemma_browsergym_openenv.ipynb


This guide describes how to fine-tune Google DeepMind's FunctionGemma in the BrowserGym environment provided by OpenEnv, using Hugging Face TRL. The steps covered include:

  • What are GRPO and OpenEnv
  • Setup dependencies for training
  • Initialize OpenEnv's BrowserGym environment
  • Create rollout function with helpers
  • Define the reward functions
  • Load the custom dataset
  • Fine-tune using TRL and the GRPOTrainer
  • Load the fine-tuned model and run inference

Note: The guide is designed to run on Google Colaboratory with access to an NVIDIA A100 GPU (40GB) using FunctionGemma. The workflow can be adapted to other GPU configurations, models, or environments.

What are GRPO and OpenEnv

Group Relative Policy Optimization (GRPO) is a reinforcement learning post-training method widely used for efficiently fine-tuning large language models. GRPO scores groups of sampled completions with reward functions and computes each completion's advantage relative to its group, which removes the need for a separate value model and lets the policy optimize task-specific behaviors directly from reward signals.

OpenEnv provides a standard interface for interacting with agentic execution environments using simple Gymnasium-style APIs, such as step(), reset(), and state(). These APIs facilitate reinforcement learning training loops by allowing models to interact with environments in a structured manner. OpenEnv also offers tools for environment creators to build isolated, secure, and deployable environments that can be shared via common protocols like HTTP or packaged in Docker.

The combination of GRPO and OpenEnv enables efficient fine-tuning of models in controlled, interactive tasks while minimizing resource requirements.
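The reset()/step() interaction pattern that OpenEnv standardizes can be sketched with a small, self-contained toy environment. The names below (ToyEnv, StepResult) are illustrative only, not OpenEnv's actual classes; the point is the loop shape an RL trainer drives:

```python
from dataclasses import dataclass

@dataclass
class StepResult:
    observation: str
    reward: float
    done: bool

class ToyEnv:
    """A toy stand-in for a Gymnasium-style environment (hypothetical, for illustration)."""

    def __init__(self, target: str = "click('13')"):
        self.target = target
        self.steps = 0

    def reset(self) -> StepResult:
        # Start a fresh episode and return the initial observation.
        self.steps = 0
        return StepResult("[13] button 'Click Me!'", 0.0, False)

    def step(self, action: str) -> StepResult:
        # Execute one action; reward 1.0 only when the correct element is clicked.
        self.steps += 1
        success = action == self.target
        return StepResult("page after action", 1.0 if success else 0.0,
                          success or self.steps >= 5)

# A minimal interaction loop: act, observe, accumulate reward until done.
env = ToyEnv()
result = env.reset()
total_reward = 0.0
while not result.done:
    action = "click('13')"  # a real agent would derive this from result.observation
    result = env.step(action)
    total_reward += result.reward
```

The real BrowserGym client used below follows this same contract, with richer observations (goal, accessibility tree, errors) in place of the toy strings.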

Setup dependencies for training

Install the required libraries, including Hugging Face TRL for fine-tuning and OpenEnv for reinforcement learning environments.

python
!pip install -Uq trl[vllm] git+https://huggingface.co/spaces/openenv/browsergym_env liger-kernel trackio

A valid Hugging Face token is required to save the fine-tuned model. In Google Colab, the token can be securely accessed through Colab secrets. Otherwise, it can be provided directly in the login method. Ensure the token has write permissions to allow uploading the model to the Hugging Face Hub during training.

python
from google.colab import userdata
from huggingface_hub import login

# Login into Hugging Face Hub
hf_token = userdata.get('HF_TOKEN') # If you are running inside a Google Colab
login(hf_token)

Initialize OpenEnv's BrowserGym environment

External environments can guide the fine-tuning of LLMs for function calling by providing interactive feedback that enhances performance on task-specific behaviors.

BrowserGym is a unified framework for web-based agent tasks, offering multiple benchmarks through a Gymnasium-compatible API. It enables training on simple synthetic tasks with MiniWoB++ and evaluation on more complex, realistic tasks with WebArena, VisualWebArena, or WorkArena. This setup supports iterative training and assessment of web agents without requiring extensive infrastructure.

BrowserGym supports both LLM and VLM training by providing visual information, including screenshots and DOM data, which can be utilized depending on the model type. This guide focuses on a simple web-based task called "click-test", which is part of the MiniWoB++ benchmark of synthetic web tasks. Environments can be run locally, in Docker containers, or accessed remotely via the Hugging Face Hub. For this example, the remote environment openenv/browsergym_env will be used.

Note: Hosted environments on the Hub currently have limited concurrency. For higher reliability or parallel runs, duplicating the Space to your own account is strongly recommended.

python
from browsergym_env import BrowserGymEnv
space_url = "https://openenv-browsergym-env.hf.space"

client = BrowserGymEnv(base_url=space_url)

Create rollout function with helpers

The rollout function defines how the agent interacts with the environment during GRPO training. It generates model outputs, collects feedback in the form of rewards, and returns the information required for optimization.

In this setup:

  • The function is invoked automatically by the GRPOTrainer (introduced later), which orchestrates the training loop and handles policy updates.
  • It uses TRL's generate_rollout_completions() helper from trl.experimental.openenv for efficient output generation. This leverages vLLM, a high-performance inference engine for large language models, and is integrated within TRL to streamline rollout generation and reward collection during fine-tuning.
  • Each rollout represents a complete interaction loop, where the model acts, receives feedback from the environment, and updates based on reward signals.

Rewards capture various aspects of the agent's performance. Helper functions, such as rollout_once, manage individual episodes, keeping the main rollout_func clean, modular, and reusable.

This modular structure allows GRPO to efficiently sample, evaluate, and refine the model's behavior through reinforcement learning.

Before executing rollouts, a system prompt is defined to instruct the model on how to interact with the environment. This prompt specifies the available BrowserGym actions (such as click, fill, send_keys, and scroll), describes the page structure, and enforces that the model responds with exactly one action per step. It ensures consistent and structured interactions, guiding the model to complete tasks effectively without providing extra explanations or multiple actions.

python
# @title System prompt (click to expand)
SYSTEM_PROMPT = """You control a web browser through BrowserGym actions.
You must complete the given web task by interacting with the page.

Available actions:
- noop() - Do nothing
- click(bid) - Click element with BrowserGym ID (the number in brackets)
- fill(bid, text) - Fill input field with text
- send_keys(text) - Send keyboard input
- scroll(direction) - Scroll up/down

The page structure shows elements as: [bid] element_type 'element_text'
For example: [13] button 'Click Me!' means bid='13'

Reply with exactly ONE action on a single line, e.g.:
click('13')
fill('42', 'hello world')
noop()

Do not include explanations or multiple actions."""

The rollout_func orchestrates the interaction between the model and the remote BrowserGym environment. For each prompt in the batch, it executes a complete episode using the rollout_once function, collecting model outputs and rewards for GRPO optimization.

The parameter max_steps defines the maximum number of steps the model can take within a single episode. This limits the length of the interaction loop, ensuring that episodes terminate even if the task is not completed, and helps maintain efficient training.

During each episode, the function tracks prompt and completion IDs, log probabilities, and both step-wise and final rewards, returning them in a structured format for the trainer to perform policy updates.

python
from trl import GRPOTrainer

max_steps = 10

def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]:
    episode_prompt_ids: list[list[int]] = []
    episode_completion_ids: list[list[int]] = []
    episode_logprobs: list[list[float]] = []
    completion_rewards: list[float] = []

    print(f"\n[DEBUG] rollout_func called with {len(prompts)} prompts (LLM mode, text-only)")

    for i, prompt_text in enumerate(prompts):
        print(f"[DEBUG] Processing prompt {i + 1}/{len(prompts)}")
        episode = rollout_once(
            trainer=trainer,
            env=client,
            tokenizer=trainer.processing_class,
            dataset_prompt=prompt_text,
            max_steps=max_steps,
        )
        episode_prompt_ids.append(episode["prompt_ids"])
        episode_completion_ids.append(episode["completion_ids"])
        episode_logprobs.append(episode["logprobs"])
        completion_rewards.append(episode["completion_reward"])

    return {
        "prompt_ids": episode_prompt_ids,
        "completion_ids": episode_completion_ids,
        "logprobs": episode_logprobs,
        "completion_reward": completion_rewards,
    }

Define rollout_once

The rollout_once function runs one complete interaction loop between the model and the BrowserGym environment using the trainer's generation method.
It executes a single episode, from generating an action to receiving feedback and computing rewards.

Here's the step-by-step breakdown:

  1. Environment reset: Start a new BrowserGym session and initialize the observation.
  2. Prompt construction: Combine the system prompt, environment observation (text-only via the accessibility tree), and any relevant errors or state information to form the model input.
  3. Generation: Use trl.experimental.openenv.generate_rollout_completions() to produce the model's action efficiently with vLLM.
  4. Action parsing and execution: Interpret the model's output and execute the corresponding BrowserGym action (e.g., click, fill, scroll).
  5. Reward calculation: Track step-wise rewards provided by the environment and compute completion rewards based on task success or failure.
  6. Return structured rollout data: Includes prompt/completion IDs, log probabilities, step rewards, and the final reward for the episode.

This modular design allows each episode to be processed independently while providing rich feedback for the GRPO training loop, supporting both task completion and intermediate reward shaping.

python
from trl.experimental.openenv import generate_rollout_completions
from browsergym_env import BrowserGymAction
from transformers import AutoTokenizer

def rollout_once(
    trainer: GRPOTrainer,
    env: BrowserGymEnv,
    tokenizer: AutoTokenizer,
    dataset_prompt: str,
    max_steps: int,
) -> dict[str, list]:
    """Run one episode and collect training data (text-only, no screenshots)."""
    result = env.reset()
    observation = result.observation

    prompt_ids: list[int] = []
    completion_ids: list[int] = []
    logprobs: list[float] = []
    step_rewards: list[float] = []
    completion_rewards: list[float] = []

    for step_num in range(max_steps):
        if result.done:
            break

        # Create prompt from observation (text-only using accessibility tree)
        goal = observation.goal or dataset_prompt
        axtree = observation.axtree_txt or ""
        error = observation.error if observation.last_action_error else ""

        user_prompt = make_user_prompt(goal, step_num, axtree, error)
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt},
        ]
        prompt_text = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False,
        )

        # Generate action with vLLM
        rollout_outputs = generate_rollout_completions(trainer, [prompt_text])[0]
        prompt_ids.extend(rollout_outputs["prompt_ids"])
        completion_ids.extend(rollout_outputs["completion_ids"])
        logprobs.extend(rollout_outputs["logprobs"])

        completion_text = rollout_outputs.get("text") or tokenizer.decode(
            rollout_outputs["completion_ids"], skip_special_tokens=True
        )

        # Parse and execute action
        action_str = parse_action(completion_text)

        print(f"Step {step_num + 1}: {action_str}")

        # Take action in environment
        result = env.step(BrowserGymAction(action_str=action_str))
        observation = result.observation

        # Track rewards
        step_reward = float(result.reward or 0.0)
        step_rewards.append(step_reward)

        # Reward shaping: success is most important
        if result.done and step_reward > 0:
            completion_rewards.append(1.0)  # Task completed successfully
        elif result.done and step_reward == 0:
            completion_rewards.append(0.0)  # Task failed
        else:
            completion_rewards.append(step_reward)  # Intermediate reward

    # Final reward is based on task completion
    final_reward = completion_rewards[-1] if completion_rewards else 0.0

    return {
        "prompt_ids": prompt_ids,
        "completion_ids": completion_ids,
        "logprobs": logprobs,
        "step_rewards": step_rewards,
        "completion_reward": final_reward,
    }

Helper functions

Supporting utilities used in rollout_once:

  • make_user_prompt: builds the user prompt from the task goal, step number, page structure (accessibility tree), and any previous action error.
  • parse_action: extracts a single BrowserGym action from the model's response, falling back to noop() when none is found.
python
# @title Helpers (click to expand)
def make_user_prompt(goal: str, step_num: int, axtree: str, error: str = "") -> str:
    """Create user prompt from observation."""
    prompt_parts = [f"Step {step_num + 1}"]

    if goal:
        prompt_parts.append(f"Goal: {goal}")

    if error:
        prompt_parts.append(f"Previous action error: {error}")

    # Include accessibility tree (truncated for context)
    if axtree:
        max_len = 2000
        axtree_truncated = axtree[:max_len] + "..." if len(axtree) > max_len else axtree
        prompt_parts.append(f"Page structure:\n{axtree_truncated}")

    prompt_parts.append("What action do you take?")

    return "\n\n".join(prompt_parts)


def parse_action(response_text: str) -> str:
    """Parse BrowserGym action from model response."""
    # Extract first line that looks like an action
    for line in response_text.strip().split("\n"):
        line = line.strip()
        if "(" in line and ")" in line:
            return line

    # Fallback to noop if no valid action found
    return "noop()"

Define the reward functions

Reward functions quantify the model's performance in the environment and guide the GRPO optimization process.

In this setup, the reward_completion function assigns rewards based on task completion. It extracts the final reward for each episode, which indicates whether the agent successfully completed the task. If no reward information is available, it defaults to zero.

This modular approach allows additional reward functions to be added easily, enabling more granular feedback such as intermediate progress, efficiency, or correctness of actions, depending on the task requirements.

python
def reward_completion(completions: list[str], **kwargs) -> list[float]:
    """Reward for task completion."""
    rewards = kwargs.get("completion_reward") if kwargs else None
    if rewards is None:
        return [0.0 for _ in completions]
    return [float(r) for r in rewards]
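As a sketch of this extensibility, a hypothetical second reward function could favor shorter successful episodes. This assumes the rollout also forwards per-episode step_rewards to the reward functions (the rollout_func in this guide only forwards completion_reward, so this would require a small extension):

```python
def reward_efficiency(completions: list[str], **kwargs) -> list[float]:
    """Hypothetical reward: small penalty per step taken in the episode."""
    step_rewards = kwargs.get("step_rewards")
    if step_rewards is None:
        return [0.0 for _ in completions]
    # Each episode's reward shrinks by 0.05 per step, floored at 0.0.
    return [max(0.0, 1.0 - 0.05 * len(steps)) for steps in step_rewards]
```

Passing several reward functions to the trainer (e.g. reward_funcs=[reward_completion, reward_efficiency]) would then combine completion success with efficiency pressure.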

Load the custom dataset

The dataset is constructed with repeated prompts to control the total number of training episodes.

Each entry in the dataset triggers a single rollout episode during training. The dataset_prompt provides the initial instruction to the model at the start of each episode, ensuring consistent guidance for task execution.

python
from datasets import Dataset

dataset_prompt = "Complete the web task successfully."
dataset_size = 1000

dataset = Dataset.from_dict({"prompt": [dataset_prompt] * dataset_size})

Fine-tune using TRL and the GRPOTrainer

The next step is to define the GRPOConfig, which sets all key training parameters.

This configuration determines how the model interacts with vLLM, handles memory and computation, and records training metrics and logs for monitoring the fine-tuning process.

python
from trl import GRPOConfig
output_dir = "browsergym-grpo-functiongemma-270m-it"

grpo_config = GRPOConfig(
    # num_train_epochs=1,                                     # Number of times to iterate over the full dataset (use for full training runs)
    max_steps=100,                                            # Total number of optimizer steps (for shorter runs/testing). For full trainings, use `num_train_epochs` instead
    learning_rate=5e-6,                                       # Learning rate for the optimizer
    warmup_steps=10,                                          # Number of steps to linearly increase learning rate at the start of training

    per_device_train_batch_size=1,                            # Number of samples per device per step
    num_generations=4,                                        # Number of completions to generate per prompt
    generation_batch_size=4,                                  # Batch size used during generation (must be divisible by num_generations)
    max_completion_length=32,                                 # Maximum length of generated completions

    use_vllm=True,                                            # Use vLLM engine for fast inference
    vllm_mode="colocate",                                     # vLLM mode: "colocate" runs generation on the same GPU as training
    vllm_gpu_memory_utilization=0.1,                          # Fraction of GPU memory allocated to vLLM

    output_dir=str(output_dir),                               # Directory where checkpoints, logs, and outputs will be saved
    logging_steps=1,                                          # Log metrics every N steps
    report_to="trackio",                                      # Logging/reporting platform (e.g., "trackio")
    trackio_space_id=output_dir,                              # HF Space where the experiment tracking will be saved
    push_to_hub=True,                                         # Optionally push trained model to Hugging Face Hub

    use_liger_kernel=True,                                    # Enable Liger kernel optimizations for faster training
)

The next step is to initialize the GRPOTrainer, which manages the complete reinforcement learning loop.

It receives the model name, reward functions, rollout function, and dataset defined earlier. From the model name, the trainer automatically initializes the model and tokenizer. It then coordinates interactions between the model and the environment, applies the defined reward signals, and updates the policy during training.

Finally, calling trainer.train() starts the fine-tuning process, enabling the model to progressively improve its performance through iterative interaction and reinforcement learning.

Note: The training pipeline uses approximately 10.6 GB of GPU VRAM and can be adapted to different hardware configurations.

python
model_name = "google/functiongemma-270m-it"
python
trainer = GRPOTrainer(
    model=model_name,
    reward_funcs=[reward_completion],
    train_dataset=dataset,
    args=grpo_config,
    rollout_func=rollout_func,
)
python
trainer_stats = trainer.train()

In this step, the fine-tuned model is saved locally and uploaded to the Hugging Face Hub using the configured account credentials.

python
trainer.save_model(output_dir)
trainer.push_to_hub()

Load the Fine-Tuned Model and Run Inference

The fine-tuned model is loaded to perform inference and evaluate its behavior on the target task.
In this case, the model is tested within the BrowserGym environment using OpenEnv, focusing on the click-test task from the MiniWoB++ benchmark, which is included among the available BrowserGym tasks.

python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "sergiopaniego/browsergym-grpo-functiongemma-270m-it" # Replace with your HF username or organization

fine_tuned_model = AutoModelForCausalLM.from_pretrained(model_name, dtype="float32", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

With the fine-tuned model loaded, testing can be conducted on the BrowserGym environment. To streamline evaluation, a reusable function is defined that executes multiple rounds of the task. This function follows the same interaction logic as used during training, generating model actions from observations, executing them in the environment, and printing the results step by step.

python
def test_click_in_browsergym(env, model, tokenizer):
    result = env.reset()
    observation = result.observation

    for step_num in range(max_steps):
        if result.done:
            break

        # Create prompt from observation (text-only using accessibility tree)
        goal = observation.goal or dataset_prompt
        axtree = observation.axtree_txt or ""
        error = observation.error if observation.last_action_error else ""

        user_prompt = make_user_prompt(goal, step_num, axtree, error)
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt},
        ]
        prompt_text = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False,
        )

        # Generate action

        model_inputs = tokenizer([prompt_text], return_tensors="pt").to(model.device)

        generated_ids = model.generate(
            **model_inputs,
            max_new_tokens=512
        )
        output_ids = generated_ids[0][len(model_inputs.input_ids[0]):]

        # Decode and extract model response
        generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)

        action_str = parse_action(generated_text)
        print(f"Step {step_num + 1}: {action_str}")

        # Take action in environment
        result = env.step(BrowserGymAction(action_str=action_str))
        observation = result.observation

The test_click_in_browsergym function is called to run a full evaluation of the fine-tuned model on the BrowserGym click task.

The environment client is safely closed after testing using a try/finally block, ensuring that all resources are released even if an error occurs during execution.

python
try:
    test_click_in_browsergym(client, fine_tuned_model, tokenizer)
finally:
    client.close()

Summary and Next Steps

This tutorial demonstrated how to fine-tune a FunctionGemma model using TRL, GRPO, and the BrowserGym environment from OpenEnv. Check out the following docs next: