OpenEnv Sudoku with GRPO using TRL

examples/notebooks/openenv_sudoku_grpo.ipynb

With Transformer Reinforcement Learning (TRL), you can train a model that learns to play Sudoku, a logic-based number puzzle, through interaction and reinforcement.

An agentic environment is a setting where a model can take actions, observe outcomes, and adjust its behavior based on feedback, similar to how humans learn from trial and error. In this case, the agent interacts with the Sudoku environment through the OpenEnv framework, which standardizes multi-agent and RL-style text environments.

Sudoku is a classic logic-based puzzle where the objective is to fill a 9x9 grid so that each row, column, and 3x3 subgrid contains all digits from 1 to 9 exactly once. This structured yet challenging setup makes Sudoku an excellent benchmark for reasoning and decision-making tasks.
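
These constraints translate directly into code. The sketch below is a standalone illustration of the rules (not part of the training setup): it checks whether placing a digit at a given cell is consistent with the row, column, and box constraints.

```python
def is_consistent(grid: list[list[int]], row: int, col: int, number: int) -> bool:
    """Check whether placing `number` at (row, col) violates any Sudoku rule.

    `grid` is a 9x9 list of lists with 0 marking empty cells;
    `row` and `col` are 0-indexed here for simplicity.
    """
    if grid[row][col] != 0:
        return False  # cannot overwrite a filled cell
    if number in grid[row]:
        return False  # row constraint
    if any(grid[r][col] == number for r in range(9)):
        return False  # column constraint
    br, bc = 3 * (row // 3), 3 * (col // 3)
    if any(grid[r][c] == number for r in range(br, br + 3) for c in range(bc, bc + 3)):
        return False  # 3x3 box constraint
    return True


# A board with a single 5 placed in the top-left cell:
grid = [[0] * 9 for _ in range(9)]
grid[0][0] = 5
print(is_consistent(grid, 0, 1, 5))  # False: 5 is already in that row
print(is_consistent(grid, 0, 1, 3))  # True
```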

We'll fine-tune a model using GRPO (Group Relative Policy Optimization) via TRL. Using environment_factory, the trainer automatically handles:

  1. Creating environment instances for each rollout.
  2. Generating model completions and parsing tool calls.
  3. Stepping through the environment with the model's actions.
  4. Collecting rewards and managing the interaction loop.

This means you only need to define the environment class and reward functions -- the trainer takes care of the rest.
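
Conceptually, the loop the trainer runs for each rollout looks roughly like the sketch below. `ToyEnv` and `rollout` are stand-ins to make the four steps concrete, not the real Sudoku environment or TRL's actual internals.

```python
class ToyEnv:
    """Stand-in environment: the episode ends after 3 tool calls."""

    def __init__(self):
        self.done = False
        self._steps = 0

    def reset(self) -> str:
        self.done = False
        self._steps = 0
        return "initial observation"

    def act(self, move: str) -> str:
        self._steps += 1
        if self._steps >= 3:
            self.done = True
        return f"feedback after {move}"


def rollout(env_factory, policy):
    """Simplified version of steps 1-4: create, generate, step, collect."""
    env = env_factory()                # 1. fresh environment per episode
    observation = env.reset()
    transcript = [observation]
    while not env.done:
        action = policy(observation)   # 2. model completion -> parsed tool call
        observation = env.act(action)  # 3. step the environment
        transcript.append(observation)
    return transcript                  # 4. rewards are read off afterwards


transcript = rollout(ToyEnv, policy=lambda obs: "some move")
print(len(transcript))  # 4: initial observation + 3 feedback messages
```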

Install dependencies

We'll start by installing TRL (with vLLM support), the OpenEnv Sudoku environment, and trackio for logging.

python
!pip install -Uq trl[vllm] git+https://huggingface.co/spaces/openenv/sudoku trackio

Log in to Hugging Face

Log in to your Hugging Face account to save your fine-tuned model, track your experiment results directly on the Hub, or access gated models. You can find your access token on your account settings page.

python
from huggingface_hub import notebook_login

notebook_login()

Define the system prompt

This prompt instructs the model on how to play Sudoku. It includes the game rules, board reading format, strategic approaches, and importantly, tells the model to use the place tool to submit moves. The environment_factory pattern uses tool calling to interact with the environment, so the model needs to know which tool to call.

python
system_prompt = """You are an expert Sudoku player with deep knowledge of logical deduction strategies and number placement techniques.

## GAME RULES

1. The puzzle is a 9x9 grid divided into nine 3x3 subgrids (boxes)
2. Some cells are pre-filled with numbers 1-9
3. You must fill in the empty cells (shown as '.') with numbers 1-9
4. Each row must contain numbers 1-9 without repetition
5. Each column must contain numbers 1-9 without repetition
6. Each 3x3 subgrid must contain numbers 1-9 without repetition
7. You cannot overwrite pre-filled cells
8. Invalid moves result in penalties (-1 reward)

## HOW TO PLAY

Use the `place` tool to make a move. The tool takes three arguments:
- `row`: Row number (1-9)
- `col`: Column number (1-9)
- `number`: The digit to place (1-9)

## STRATEGIC APPROACH

Do not repeat the same move twice.

### Basic Strategies
- **Naked Singles**: If a cell has only one possible candidate, fill it in immediately.
- **Hidden Singles**: If a number can only go in one cell within a row, column, or box, place it there.
- **Scanning**: Look at each row, column, and box to find where specific numbers can go.

### Solving Process
1. Start by scanning the entire grid to identify easy fills (cells with few candidates)
2. Look for rows, columns, or boxes with many numbers already placed
3. Fill all naked singles first
4. Then look for hidden singles in each row, column, and box

### Common Pitfalls to Avoid
- Don't guess randomly - Sudoku is pure logic
- Don't overlook any constraint (row, column, or box)
- Don't try to overwrite pre-filled cells
- Don't place invalid numbers (must be 1-9)
- Don't use invalid coordinates (must be 1-9)
- Don't repeat a move that was already made

## BOARD READING

The board is displayed as a 9x9 grid:
- Numbers 1-9 are pre-filled or already placed
- Empty cells are shown as '.'
- Rows are labeled R1-R9 (top to bottom)
- Columns are labeled C1-C9 (left to right)

## IMPORTANT CONSTRAINTS

- Coordinates are 1-indexed (1-9 for both row and column)
- Numbers must be 1-9
- One move per response
- Must be a valid move (no rule violations)
- Never repeat a previous move
"""

Define the environment

The SudokuEnv class wraps the OpenEnv TextArena Sudoku environment into the interface expected by environment_factory.

When you pass environment_factory=SudokuEnv to the trainer, it will:

  1. Create a new SudokuEnv() instance for each rollout episode.
  2. Call reset() to start a new game (returns the initial board state).
  3. Automatically generate model completions, parse tool calls, and invoke the place(row, col, number) method.
  4. Repeat until the environment signals done=True or the max completion length is reached.

The environment tracks multiple reward signals as properties:

  • correct_reward: Did the puzzle get solved?
  • valid_move_reward: Average rate of valid moves.
  • empty_cell_reward: Did the model target empty cells?
  • repetition_reward: Penalty for repeating moves.
  • progress_reward: How many cells were filled (normalized 0-1).

For this example, we connect to the hosted environment at openenv/sudoku. For production use, we recommend duplicating the Space to your own account or running it locally via Docker, as the hosted versions have limited concurrency.

For more information, refer to the TRL-OpenEnv documentation.

python
# @title SudokuEnv class (click to expand)
from collections import defaultdict

from textarena_env import TextArenaAction, TextArenaEnv


def _is_valid_board_state(board_str: str) -> bool:
    return "R1" in board_str and "R9" in board_str and "|" in board_str


def _parse_board(board_str: str) -> list[list[int]]:
    grid = [[0] * 9 for _ in range(9)]
    if not _is_valid_board_state(board_str):
        return grid
    for line in board_str.split("\n"):
        line_stripped = line.strip()
        if line_stripped and line_stripped[0] == "R" and len(line_stripped) > 1 and line_stripped[1].isdigit():
            row = int(line_stripped[1]) - 1
            cell_part = line_stripped[2:]
            col = 0
            for char in cell_part:
                if char == ".":
                    grid[row][col] = 0
                    col += 1
                elif char.isdigit():
                    grid[row][col] = int(char)
                    col += 1
    return grid


def _count_filled_cells(board_str: str) -> int:
    if not _is_valid_board_state(board_str):
        return 0
    grid = _parse_board(board_str)
    return sum(1 for row in grid for cell in row if cell != 0)


def _get_valid_numbers(grid: list[list[int]], row: int, col: int) -> set[int]:
    if grid[row][col] != 0:
        return set()
    used = set()
    for c in range(9):
        if grid[row][c] != 0:
            used.add(grid[row][c])
    for r in range(9):
        if grid[r][col] != 0:
            used.add(grid[r][col])
    box_row, box_col = 3 * (row // 3), 3 * (col // 3)
    for r in range(box_row, box_row + 3):
        for c in range(box_col, box_col + 3):
            if grid[r][c] != 0:
                used.add(grid[r][c])
    return set(range(1, 10)) - used


def _extract_empty_cells_with_candidates(board_str: str, sort_by_difficulty: bool = True):
    grid = _parse_board(board_str)
    cells_with_candidates = []
    for row in range(9):
        for col in range(9):
            if grid[row][col] == 0:
                candidates = _get_valid_numbers(grid, row, col)
                cells_with_candidates.append((row + 1, col + 1, candidates))
    if sort_by_difficulty:
        cells_with_candidates.sort(key=lambda x: len(x[2]))
    return cells_with_candidates


def _extract_empty_cells(board_str: str) -> list[tuple[int, int]]:
    empty_cells = []
    if not _is_valid_board_state(board_str):
        return empty_cells
    for line in board_str.split("\n"):
        line_stripped = line.strip()
        if line_stripped and line_stripped[0] == "R" and len(line_stripped) > 1 and line_stripped[1].isdigit():
            row = int(line_stripped[1])
            cell_part = line_stripped[2:]
            col = 0
            for char in cell_part:
                if char == ".":
                    col += 1
                    empty_cells.append((row, col))
                elif char.isdigit():
                    col += 1
    return empty_cells


def _extract_board_only(text: str) -> str:
    if not text:
        return ""
    lines = text.split("\n")
    board_lines = []
    in_board = False
    for line in lines:
        stripped = line.strip()
        if stripped.startswith("C1") or (
            stripped and stripped[0] == "R" and len(stripped) > 1 and stripped[1].isdigit()
        ):
            in_board = True
        if in_board and (stripped.startswith("-") or stripped.startswith("R") or stripped.startswith("C1")):
            board_lines.append(line)
        elif (
            in_board
            and stripped
            and not stripped.startswith("-")
            and not (stripped[0] == "R" and len(stripped) > 1 and stripped[1].isdigit())
        ):
            break
    return "\n".join(board_lines) if board_lines else ""


def _make_hints(board, successful_moves, failed_moves, difficulty="easy"):
    """Generate hint text for the model."""
    parts = []
    all_tried = successful_moves + failed_moves
    if all_tried:
        parts.append(f"\nMOVES ALREADY TRIED (do not repeat): {', '.join(all_tried)}")
    if not board:
        return "\n".join(parts)
    if difficulty == "easy":
        cells = _extract_empty_cells_with_candidates(board, sort_by_difficulty=True)
        if cells:
            guaranteed = []
            other = []
            for r, c, candidates in cells[:10]:
                if len(candidates) == 1:
                    guaranteed.append(f"[{r} {c} {list(candidates)[0]}]")
                elif len(candidates) <= 3:
                    nums = ",".join(str(n) for n in sorted(candidates))
                    other.append(f"({r},{c})->{nums}")
            if guaranteed:
                parts.append(f"\nGUARANTEED MOVES: {', '.join(guaranteed[:5])}")
            if other:
                parts.append(f"Other options: {' | '.join(other[:5])}")
    elif difficulty == "medium":
        cells = _extract_empty_cells_with_candidates(board, sort_by_difficulty=False)
        if cells:
            cell_hints = []
            for r, c, candidates in cells[:10]:
                nums = ",".join(str(n) for n in sorted(candidates))
                cell_hints.append(f"({r},{c})->{nums}")
            parts.append(f"\nEmpty cells: {' | '.join(cell_hints)}")
    return "\n".join(parts)


class SudokuEnv:
    def __init__(self):
        self.client = TextArenaEnv(base_url="https://openenv-sudoku.hf.space")
        self.difficulty = "easy"
        self.max_turns = 100
        self._turn = 0
        self._move_counts = defaultdict(int)
        self._successful_moves = []
        self._failed_moves = []
        self._valid_move_scores = []
        self._empty_cell_scores = []
        self._correct_scores = []
        self._repetition_scores = []
        self._last_board_state = ""
        self._last_full_content = ""
        self._initial_filled = 0
        self._max_filled = 0
        self.done = False

    def reset(self, **kwargs) -> str | None:
        result = self.client.reset()
        observation = result.observation
        self.done = False
        self._turn = 0
        self._move_counts = defaultdict(int)
        self._successful_moves = []
        self._failed_moves = []
        self._valid_move_scores = []
        self._empty_cell_scores = []
        self._correct_scores = []
        self._repetition_scores = []
        self._last_board_state = ""
        self._initial_filled = 0
        self._max_filled = 0

        # Store full message content for diffing (messages are cumulative)
        self._last_full_content = observation.messages[0].content if observation.messages else ""

        for message in observation.messages:
            if message.content and _is_valid_board_state(message.content):
                self._last_board_state = message.content
                self._initial_filled = _count_filled_cells(self._last_board_state)
                self._max_filled = self._initial_filled
                break

        board = _extract_board_only(self._last_board_state) if self._last_board_state else "No board available."
        hints = _make_hints(self._last_board_state, [], [], self.difficulty)
        return f"Step 0. Progress: 0 cells filled.\n\nBoard:\n{board}{hints}"

    def place(self, row: int, col: int, number: int) -> str:
        """Place a number on the Sudoku board.

        Args:
            row: Row number (1-9).
            col: Column number (1-9).
            number: Number to place (1-9).

        Returns:
            The result of the move and updated board state.
        """
        if self.done:
            return "Game is over. No more moves allowed."

        self._turn += 1
        move = f"[{row} {col} {number}]"

        # Step environment
        result = self.client.step(TextArenaAction(message=move))
        observation = result.observation
        correct_score = float(result.reward or 0.0)
        self.done = result.done

        # Only check the NEW content for feedback (messages are cumulative)
        full_content = observation.messages[0].content if observation.messages else ""
        new_content = full_content[len(self._last_full_content):]
        self._last_full_content = full_content

        new_content_lower = new_content.lower()
        env_says_invalid = any(
            kw in new_content_lower for kw in ["invalid", "error", "cannot", "already", "violation", "lost"]
        )
        got_warning = "please resubmit" in new_content_lower or "avoid penalties" in new_content_lower

        # Also verify against our own board state: placing on a non-empty cell is always invalid
        if self._last_board_state:
            empty_cells = _extract_empty_cells(self._last_board_state)
            targets_empty = (row, col) in empty_cells
        else:
            empty_cells = []
            targets_empty = True  # Can't verify, assume valid

        is_valid = not env_says_invalid and targets_empty

        # Empty cell score: did the model target an empty cell?
        empty_cell_score = 1.0 if targets_empty else -1.0

        # Repetition tracking
        is_new_move = self._move_counts[move] == 0
        repetition_count = self._move_counts[move]
        self._move_counts[move] += 1
        repetition_score = -min(2 ** (repetition_count - 1), 10.0) if repetition_count > 0 else 0.0

        # Valid move score
        if is_valid and is_new_move:
            valid_move_score = 1.0
            self._successful_moves.append(move)
        elif got_warning:
            valid_move_score = -0.5
            self._failed_moves.append(move)
        else:
            valid_move_score = 0.0

        # Update board state from new content
        if is_valid and _is_valid_board_state(new_content):
            self._last_board_state = new_content
            current_filled = _count_filled_cells(self._last_board_state)
            if current_filled > self._max_filled:
                self._max_filled = current_filled

        self._valid_move_scores.append(valid_move_score)
        self._empty_cell_scores.append(empty_cell_score)
        self._correct_scores.append(correct_score)
        self._repetition_scores.append(repetition_score)

        # Enforce max turns
        if self._turn >= self.max_turns:
            self.done = True

        # Build response
        board = _extract_board_only(self._last_board_state) if self._last_board_state else "No board available."
        status = "valid" if is_valid else "invalid"
        cells_filled = len(self._successful_moves)
        progress = f"Step {self._turn}. Progress: {cells_filled} cells filled."
        hints = _make_hints(self._last_board_state, self._successful_moves, self._failed_moves, self.difficulty)

        if self.done:
            return f"Move {move}: {status}. Game over.\n{progress}\n\nFinal board:\n{board}"
        return f"Move {move}: {status}\n{progress}\n\nBoard:\n{board}{hints}"

    # ── Reward properties ──

    @property
    def correct_reward(self) -> float:
        return self._correct_scores[-1] if self._correct_scores else 0.0

    @property
    def valid_move_reward(self) -> float:
        return sum(self._valid_move_scores) / len(self._valid_move_scores) if self._valid_move_scores else 0.0

    @property
    def empty_cell_reward(self) -> float:
        return sum(self._empty_cell_scores) / len(self._empty_cell_scores) if self._empty_cell_scores else 0.0

    @property
    def repetition_reward(self) -> float:
        return sum(self._repetition_scores) / len(self._repetition_scores) if self._repetition_scores else 0.0

    @property
    def progress_reward(self) -> float:
        remaining_to_fill = 81 - self._initial_filled
        if remaining_to_fill > 0:
            return (self._max_filled - self._initial_filled) / remaining_to_fill
        return 1.0
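
The progress_reward normalization can be sanity-checked with quick arithmetic: a puzzle that starts with 30 clues has 51 cells left to fill, so filling 17 more yields 17/51 ≈ 0.33. The helper below just restates that formula in isolation:

```python
def progress_reward(initial_filled: int, max_filled: int) -> float:
    """Fraction of the initially empty cells that were filled (0 to 1)."""
    remaining = 81 - initial_filled
    if remaining > 0:
        return (max_filled - initial_filled) / remaining
    return 1.0  # board started fully solved


print(progress_reward(30, 47))  # 17/51 ≈ 0.333...
print(progress_reward(30, 81))  # 1.0 when fully solved
print(progress_reward(81, 81))  # 1.0 edge case: no empty cells to begin with
```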

Define the reward functions

The reward functions receive the list of environment instances after each episode completes. Since the SudokuEnv tracks multiple reward signals as properties, we simply read them out.

Each reward function captures a different aspect of play quality:

  • empty_cell_reward: Did the model target empty cells (vs. trying to overwrite filled ones)?
  • valid_move_reward: Were the moves accepted by the environment?
  • repetition_reward: Penalty for repeating the same move.
  • progress_reward: How much of the puzzle was filled (0 to 1).
  • correct_reward: Did the model solve the puzzle completely?

python
def reward_empty_cell(environments, **kwargs) -> list[float]:
    """Reward for targeting empty cells."""
    return [env.empty_cell_reward for env in environments]


def reward_valid_moves(environments, **kwargs) -> list[float]:
    """Reward for making valid moves."""
    return [env.valid_move_reward for env in environments]


def reward_repetition(environments, **kwargs) -> list[float]:
    """Penalty for repeating moves."""
    return [env.repetition_reward for env in environments]


def reward_progress(environments, **kwargs) -> list[float]:
    """Reward for filling more cells in the board."""
    return [env.progress_reward for env in environments]


def reward_correct(environments, **kwargs) -> list[float]:
    """Reward for solving the puzzle."""
    return [env.correct_reward for env in environments]
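
To see the shape of the data these functions receive, you can call one on a list of dummy objects that expose the same properties. This is a self-contained check (the `DummyEnv` class is a stand-in, not part of training):

```python
class DummyEnv:
    """Minimal stand-in exposing the same reward properties as SudokuEnv."""

    def __init__(self, progress: float, correct: float):
        self.progress_reward = progress
        self.correct_reward = correct


def reward_progress(environments, **kwargs) -> list[float]:
    """Reward for filling more cells in the board."""
    return [env.progress_reward for env in environments]


# Two "episodes": one partially solved, one fully solved.
envs = [DummyEnv(0.4, 0.0), DummyEnv(1.0, 1.0)]
print(reward_progress(envs))  # [0.4, 1.0]
```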

Create the dataset

We create a dataset with repeated prompts to control the number of training episodes. Each entry triggers one rollout episode during training. The prompt is formatted as a chat message with the system prompt.

python
from datasets import Dataset

dataset_size = 3000
dataset = Dataset.from_dict({
    "prompt": [[
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": "Play Sudoku like an expert."},
    ]] * dataset_size
})

Set GRPO Config

Next, we define the GRPOConfig, which controls all key training parameters. This configuration specifies how the model interacts with vLLM, manages memory, and logs results.

Note the chat_template_kwargs={"enable_thinking": False} parameter -- this disables Qwen3's thinking mode so the model responds directly with tool calls instead of generating internal reasoning tokens first.

python
from trl import GRPOConfig

model_name = "Qwen/Qwen3-1.7B"
output_dir = "sudoku-grpo-Qwen3-1.7B"

grpo_config = GRPOConfig(
    # Training schedule / optimization
    num_train_epochs=1,
    learning_rate=5e-6,
    gradient_accumulation_steps=64,
    per_device_train_batch_size=1,
    warmup_steps=20,
    optim="adamw_torch",
    max_grad_norm=1.0,

    # GRPO configuration
    num_generations=2,
    max_completion_length=16384,
    log_completions=True,
    num_completions_to_print=2,
    chat_template_kwargs={"enable_thinking": False},

    # vLLM configuration
    use_vllm=True,
    vllm_mode="colocate",
    vllm_gpu_memory_utilization=0.15,

    # Logging / reporting
    output_dir=output_dir,
    report_to="trackio",
    trackio_space_id=output_dir,
    logging_steps=1,
    save_steps=10,
    save_total_limit=1,

    # Hub integration
    push_to_hub=True,

    # Sampling
    temperature=0.8,
    top_k=10,
)
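
As a single-GPU back-of-the-envelope check of these settings (multi-GPU setups would multiply by the number of processes): one optimizer step accumulates 64 completions, grouped in pairs that share a prompt for GRPO's group-relative advantage.

```python
per_device_train_batch_size = 1
gradient_accumulation_steps = 64
num_generations = 2

# Completions contributing to one optimizer step on a single GPU:
completions_per_step = per_device_train_batch_size * gradient_accumulation_steps

# GRPO compares completions within groups that share the same prompt:
prompt_groups_per_step = completions_per_step // num_generations

print(completions_per_step)    # 64
print(prompt_groups_per_step)  # 32
```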

Create the GRPOTrainer and start training

Now we initialize the GRPOTrainer with environment_factory=SudokuEnv.

This tells the trainer to automatically handle the entire interaction loop:

  • It creates a SudokuEnv instance for each episode.
  • It generates model completions, parses tool calls (like place), and steps through the environment.
  • It collects rewards and manages the tool_mask (which tokens are model-generated vs environment-generated) automatically.

No need to write a custom rollout_func or manage tokenization manually.
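
The tool_mask idea can be illustrated with a toy token sequence: only model-generated tokens should contribute to the policy loss, while environment feedback tokens are masked out. The sketch below is hypothetical (not TRL's actual data structures), just to show the bookkeeping involved:

```python
# (token, generated_by_model) pairs for two interaction turns.
sequence = [
    ("place(", True), ("3,7,5", True), (")", True),   # model's tool call
    ("Move", False), ("valid", False),                # environment feedback
    ("place(", True), ("4,1,2", True), (")", True),   # next tool call
]

# 1 = model-generated token (enters the loss), 0 = environment token (masked).
tool_mask = [1 if is_model else 0 for _, is_model in sequence]
print(tool_mask)  # [1, 1, 1, 0, 0, 1, 1, 1]

# Only masked-in positions would contribute to the policy gradient:
trainable_tokens = [tok for (tok, _), m in zip(sequence, tool_mask) if m]
print(trainable_tokens)
```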

python
from trl import GRPOTrainer

trainer = GRPOTrainer(
    model=model_name,
    reward_funcs=[
        reward_empty_cell,
        reward_valid_moves,
        reward_repetition,
        reward_progress,
        reward_correct,
    ],
    train_dataset=dataset,
    args=grpo_config,
    environment_factory=SudokuEnv,
)

Show memory stats before training

python
import torch

gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)

print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

And train!

python
trainer_stats = trainer.train()

Show memory stats after training

python
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
used_memory_for_training = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory / max_memory * 100, 3)
training_memory_percentage = round(used_memory_for_training / max_memory * 100, 3)

print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
print(f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.")
print(f"Peak reserved memory = {used_memory} GB.")
print(f"Peak reserved memory for training = {used_memory_for_training} GB.")
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
print(f"Peak reserved memory for training % of max memory = {training_memory_percentage} %.")

Save and push to Hub

python
trainer.save_model(output_dir)
trainer.push_to_hub()

Load the fine-tuned model and run inference

Now let's test our fine-tuned model by loading it and playing a game of Sudoku. We use the same SudokuEnv class to interact with the environment, and generate model responses with standard Transformers inference.

python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "sergiopaniego/sudoku-grpo-Qwen3-1.7B"  # Replace with your HF username or organization

fine_tuned_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="float32", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

python
import json


def play_sudoku(model, tokenizer):
    env = SudokuEnv()
    initial_observation = env.reset()

    print("Initial observation:")
    print(initial_observation)
    print()

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": "Play Sudoku like an expert."},
    ]
    if initial_observation:
        messages.append({"role": "user", "content": initial_observation})

    for turn in range(20):  # Play up to 20 turns
        if env.done:
            break

        prompt_text = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False,
            enable_thinking=False,
        )
        model_inputs = tokenizer([prompt_text], return_tensors="pt").to(model.device)
        generated_ids = model.generate(**model_inputs, max_new_tokens=512)
        output_ids = generated_ids[0][len(model_inputs.input_ids[0]):]
        generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)

        print(f"Turn {turn + 1} - Model output: {generated_text}")

        try:
            # Try to parse tool call arguments
            if "place" in generated_text:
                start = generated_text.index("{")
                end = generated_text.rindex("}") + 1
                args = json.loads(generated_text[start:end])
                if "arguments" in args:
                    args = args["arguments"]
                row = int(args.get("row", 0))
                col = int(args.get("col", 0))
                number = int(args.get("number", 0))
            else:
                # Fallback: extract [row col number] pattern
                import re
                match = re.search(r"\[(\d)\s+(\d)\s+(\d)\]", generated_text)
                if match:
                    row, col, number = int(match.group(1)), int(match.group(2)), int(match.group(3))
                else:
                    print("         Could not parse move.")
                    break

            feedback = env.place(row, col, number)
            print(f"         Move: [{row} {col} {number}]")
            print(f"         Progress reward: {env.progress_reward:.2f}")
            print()

            messages.append({"role": "assistant", "content": generated_text})
            messages.append({"role": "user", "content": feedback})
        except Exception as e:
            print(f"         Error: {e}")
            break

    print(f"Game finished! Correct reward: {env.correct_reward}")
    print(f"Progress: {env.progress_reward:.2f}")
    print(f"Done: {env.done}")

Let's play the game!

python
play_sudoku(fine_tuned_model, tokenizer)