examples/notebooks/openenv_sudoku_grpo.ipynb
With Transformer Reinforcement Learning (TRL), you can train a model that learns to play Sudoku, a logic-based number puzzle, through interaction and reinforcement.
An agentic environment is a setting where a model can take actions, observe outcomes, and adjust its behavior based on feedback, similar to how humans learn from trial and error. In this case, the agent interacts with the Sudoku environment through the OpenEnv framework, which standardizes multi-agent and RL-style text environments.
Sudoku is a classic logic-based puzzle where the objective is to fill a 9x9 grid so that each row, column, and 3x3 subgrid contains all digits from 1 to 9 exactly once. This structured yet challenging setup makes Sudoku an excellent benchmark for reasoning and decision-making tasks.
We'll fine-tune a model using GRPO (Group Relative Policy Optimization) via TRL.
Using `environment_factory`, the trainer automatically handles the full interaction loop: creating an environment instance per rollout, executing the model's tool calls against it, and collecting the reward signals at the end of each episode.
This means you only need to define the environment class and reward functions -- the trainer takes care of the rest.
We'll start by installing TRL (with vLLM support), the OpenEnv Sudoku environment, and trackio for logging.
!pip install -Uq "trl[vllm]" git+https://huggingface.co/spaces/openenv/sudoku trackio
Log in to your Hugging Face account to save your fine-tuned model, track your experiment results directly on the Hub, or access gated models. You can find your access token on your account settings page.
from huggingface_hub import notebook_login
notebook_login()
This prompt instructs the model on how to play Sudoku. It includes the game rules, board reading format, strategic approaches, and importantly, tells the model to use the place tool to submit moves. The environment_factory pattern uses tool calling to interact with the environment, so the model needs to know which tool to call.
system_prompt = """You are an expert Sudoku player with deep knowledge of logical deduction strategies and number placement techniques.
## GAME RULES
1. The puzzle is a 9x9 grid divided into nine 3x3 subgrids (boxes)
2. Some cells are pre-filled with numbers 1-9
3. You must fill in the empty cells (shown as '.') with numbers 1-9
4. Each row must contain numbers 1-9 without repetition
5. Each column must contain numbers 1-9 without repetition
6. Each 3x3 subgrid must contain numbers 1-9 without repetition
7. You cannot overwrite pre-filled cells
8. Invalid moves result in penalties (-1 reward)
## HOW TO PLAY
Use the `place` tool to make a move. The tool takes three arguments:
- `row`: Row number (1-9)
- `col`: Column number (1-9)
- `number`: The digit to place (1-9)
## STRATEGIC APPROACH
Do not repeat the same move twice.
### Basic Strategies
- **Naked Singles**: If a cell has only one possible candidate, fill it in immediately.
- **Hidden Singles**: If a number can only go in one cell within a row, column, or box, place it there.
- **Scanning**: Look at each row, column, and box to find where specific numbers can go.
### Solving Process
1. Start by scanning the entire grid to identify easy fills (cells with few candidates)
2. Look for rows, columns, or boxes with many numbers already placed
3. Fill all naked singles first
4. Then look for hidden singles in each row, column, and box
### Common Pitfalls to Avoid
- Don't guess randomly - Sudoku is pure logic
- Don't overlook any constraint (row, column, or box)
- Don't try to overwrite pre-filled cells
- Don't place invalid numbers (must be 1-9)
- Don't use invalid coordinates (must be 1-9)
- Don't repeat a move that was already made
## BOARD READING
The board is displayed as a 9x9 grid:
- Numbers 1-9 are pre-filled or already placed
- Empty cells are shown as '.'
- Rows are labeled R1-R9 (top to bottom)
- Columns are labeled C1-C9 (left to right)
## IMPORTANT CONSTRAINTS
- Coordinates are 1-indexed (1-9 for both row and column)
- Numbers must be 1-9
- One move per response
- Must be a valid move (no rule violations)
- Never repeat a previous move
"""
The SudokuEnv class wraps the OpenEnv TextArena Sudoku environment into the interface expected by environment_factory.
When you pass environment_factory=SudokuEnv to the trainer, it will:
- Create a fresh `SudokuEnv()` instance for each rollout episode.
- Call `reset()` to start a new game (returns the initial board state).
- Route the model's tool calls to the `place(row, col, number)` method.
- Continue until `done=True` or the max completion length is reached.

The environment tracks multiple reward signals as properties:
For this example, we connect to the hosted environment at openenv/sudoku. For production use, we recommend duplicating the Space to your own account or running it locally via Docker, as the hosted versions have limited concurrency.
For more information, refer to the TRL-OpenEnv documentation.
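Before diving into the full `SudokuEnv`, here is a toy sketch of the interface shape the factory pattern relies on, as inferred from the class below (a `reset()` method, a `done` flag, and one or more tool methods); `EchoEnv` and its `echo` tool are hypothetical names for illustration, not part of the TRL or OpenEnv APIs.

```python
# Toy environment sketch: `reset` starts an episode, tool methods (here `echo`)
# take the model's tool-call arguments, and `done` signals the episode's end.
class EchoEnv:
    """Hypothetical minimal environment: the single tool echoes the move back."""

    def __init__(self):
        self.done = False
        self._turn = 0

    def reset(self, **kwargs) -> str:
        self.done = False
        self._turn = 0
        return "Step 0. Say something with the `echo` tool."

    def echo(self, text: str) -> str:
        """Echo the text back to the model as the next observation."""
        self._turn += 1
        if self._turn >= 3:
            self.done = True
        return f"Step {self._turn}: you said {text!r}"

env = EchoEnv()
print(env.reset())
print(env.echo("hello"))
```

The `SudokuEnv` below follows this same shape, with `place` as the tool and reward properties layered on top.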
# @title SudokuEnv class (click to expand)
from collections import defaultdict
from textarena_env import TextArenaAction, TextArenaEnv
def _is_valid_board_state(board_str: str) -> bool:
return "R1" in board_str and "R9" in board_str and "|" in board_str
def _parse_board(board_str: str) -> list[list[int]]:
grid = [[0] * 9 for _ in range(9)]
if not _is_valid_board_state(board_str):
return grid
for line in board_str.split("\n"):
line_stripped = line.strip()
if line_stripped and line_stripped[0] == "R" and len(line_stripped) > 1 and line_stripped[1].isdigit():
row = int(line_stripped[1]) - 1
cell_part = line_stripped[2:]
col = 0
for char in cell_part:
if char == ".":
grid[row][col] = 0
col += 1
elif char.isdigit():
grid[row][col] = int(char)
col += 1
return grid
def _count_filled_cells(board_str: str) -> int:
if not _is_valid_board_state(board_str):
return 0
grid = _parse_board(board_str)
return sum(1 for row in grid for cell in row if cell != 0)
def _get_valid_numbers(grid: list[list[int]], row: int, col: int) -> set[int]:
if grid[row][col] != 0:
return set()
used = set()
for c in range(9):
if grid[row][c] != 0:
used.add(grid[row][c])
for r in range(9):
if grid[r][col] != 0:
used.add(grid[r][col])
box_row, box_col = 3 * (row // 3), 3 * (col // 3)
for r in range(box_row, box_row + 3):
for c in range(box_col, box_col + 3):
if grid[r][c] != 0:
used.add(grid[r][c])
return set(range(1, 10)) - used
def _extract_empty_cells_with_candidates(board_str: str, sort_by_difficulty: bool = True):
grid = _parse_board(board_str)
cells_with_candidates = []
for row in range(9):
for col in range(9):
if grid[row][col] == 0:
candidates = _get_valid_numbers(grid, row, col)
cells_with_candidates.append((row + 1, col + 1, candidates))
if sort_by_difficulty:
cells_with_candidates.sort(key=lambda x: len(x[2]))
return cells_with_candidates
def _extract_empty_cells(board_str: str) -> list[tuple[int, int]]:
empty_cells = []
if not _is_valid_board_state(board_str):
return empty_cells
for line in board_str.split("\n"):
line_stripped = line.strip()
if line_stripped and line_stripped[0] == "R" and len(line_stripped) > 1 and line_stripped[1].isdigit():
row = int(line_stripped[1])
cell_part = line_stripped[2:]
col = 0
for char in cell_part:
if char == ".":
col += 1
empty_cells.append((row, col))
elif char.isdigit():
col += 1
return empty_cells
def _extract_board_only(text: str) -> str:
if not text:
return ""
lines = text.split("\n")
board_lines = []
in_board = False
for line in lines:
stripped = line.strip()
if stripped.startswith("C1") or (
stripped and stripped[0] == "R" and len(stripped) > 1 and stripped[1].isdigit()
):
in_board = True
if in_board and (stripped.startswith("-") or stripped.startswith("R") or stripped.startswith("C1")):
board_lines.append(line)
elif (
in_board
and stripped
and not stripped.startswith("-")
and not (stripped[0] == "R" and len(stripped) > 1 and stripped[1].isdigit())
):
break
return "\n".join(board_lines) if board_lines else ""
def _make_hints(board, successful_moves, failed_moves, difficulty="easy"):
"""Generate hint text for the model."""
parts = []
all_tried = successful_moves + failed_moves
if all_tried:
parts.append(f"\nMOVES ALREADY TRIED (do not repeat): {', '.join(all_tried)}")
if not board:
return "\n".join(parts)
if difficulty == "easy":
cells = _extract_empty_cells_with_candidates(board, sort_by_difficulty=True)
if cells:
guaranteed = []
other = []
for r, c, candidates in cells[:10]:
if len(candidates) == 1:
guaranteed.append(f"[{r} {c} {list(candidates)[0]}]")
elif len(candidates) <= 3:
nums = ",".join(str(n) for n in sorted(candidates))
other.append(f"({r},{c})->{nums}")
if guaranteed:
parts.append(f"\nGUARANTEED MOVES: {', '.join(guaranteed[:5])}")
if other:
parts.append(f"Other options: {' | '.join(other[:5])}")
elif difficulty == "medium":
cells = _extract_empty_cells_with_candidates(board, sort_by_difficulty=False)
if cells:
cell_hints = []
for r, c, candidates in cells[:10]:
nums = ",".join(str(n) for n in sorted(candidates))
cell_hints.append(f"({r},{c})->{nums}")
parts.append(f"\nEmpty cells: {' | '.join(cell_hints)}")
return "\n".join(parts)
class SudokuEnv:
def __init__(self):
self.client = TextArenaEnv(base_url="https://openenv-sudoku.hf.space")
self.difficulty = "easy"
self.max_turns = 100
self._turn = 0
self._move_counts = defaultdict(int)
self._successful_moves = []
self._failed_moves = []
self._valid_move_scores = []
self._empty_cell_scores = []
self._correct_scores = []
self._repetition_scores = []
self._last_board_state = ""
self._last_full_content = ""
self._initial_filled = 0
self._max_filled = 0
self.done = False
def reset(self, **kwargs) -> str | None:
result = self.client.reset()
observation = result.observation
self.done = False
self._turn = 0
self._move_counts = defaultdict(int)
self._successful_moves = []
self._failed_moves = []
self._valid_move_scores = []
self._empty_cell_scores = []
self._correct_scores = []
self._repetition_scores = []
self._last_board_state = ""
self._initial_filled = 0
self._max_filled = 0
# Store full message content for diffing (messages are cumulative)
self._last_full_content = observation.messages[0].content if observation.messages else ""
for message in observation.messages:
if message.content and _is_valid_board_state(message.content):
self._last_board_state = message.content
self._initial_filled = _count_filled_cells(self._last_board_state)
self._max_filled = self._initial_filled
break
board = _extract_board_only(self._last_board_state) if self._last_board_state else "No board available."
hints = _make_hints(self._last_board_state, [], [], self.difficulty)
return f"Step 0. Progress: 0 cells filled.\n\nBoard:\n{board}{hints}"
def place(self, row: int, col: int, number: int) -> str:
"""Place a number on the Sudoku board.
Args:
row: Row number (1-9).
col: Column number (1-9).
number: Number to place (1-9).
Returns:
The result of the move and updated board state.
"""
if self.done:
return "Game is over. No more moves allowed."
self._turn += 1
move = f"[{row} {col} {number}]"
# Step environment
result = self.client.step(TextArenaAction(message=move))
observation = result.observation
correct_score = float(result.reward or 0.0)
self.done = result.done
# Only check the NEW content for feedback (messages are cumulative)
full_content = observation.messages[0].content if observation.messages else ""
new_content = full_content[len(self._last_full_content):]
self._last_full_content = full_content
new_content_lower = new_content.lower()
env_says_invalid = any(
kw in new_content_lower for kw in ["invalid", "error", "cannot", "already", "violation", "lost"]
)
got_warning = "please resubmit" in new_content_lower or "avoid penalties" in new_content_lower
# Also verify against our own board state: placing on a non-empty cell is always invalid
if self._last_board_state:
empty_cells = _extract_empty_cells(self._last_board_state)
targets_empty = (row, col) in empty_cells
else:
empty_cells = []
targets_empty = True # Can't verify, assume valid
is_valid = not env_says_invalid and targets_empty
# Empty cell score: did the model target an empty cell?
empty_cell_score = 1.0 if targets_empty else -1.0
# Repetition tracking
is_new_move = self._move_counts[move] == 0
repetition_count = self._move_counts[move]
self._move_counts[move] += 1
repetition_score = -min(2 ** (repetition_count - 1), 10.0) if repetition_count > 0 else 0.0
# Valid move score
if is_valid and is_new_move:
valid_move_score = 1.0
self._successful_moves.append(move)
elif got_warning:
valid_move_score = -0.5
self._failed_moves.append(move)
else:
valid_move_score = 0.0
# Update board state from new content
if is_valid and _is_valid_board_state(new_content):
self._last_board_state = new_content
current_filled = _count_filled_cells(self._last_board_state)
if current_filled > self._max_filled:
self._max_filled = current_filled
self._valid_move_scores.append(valid_move_score)
self._empty_cell_scores.append(empty_cell_score)
self._correct_scores.append(correct_score)
self._repetition_scores.append(repetition_score)
# Enforce max turns
if self._turn >= self.max_turns:
self.done = True
# Build response
board = _extract_board_only(self._last_board_state) if self._last_board_state else "No board available."
status = "valid" if is_valid else "invalid"
cells_filled = len(self._successful_moves)
progress = f"Step {self._turn}. Progress: {cells_filled} cells filled."
hints = _make_hints(self._last_board_state, self._successful_moves, self._failed_moves, self.difficulty)
if self.done:
return f"Move {move}: {status}. Game over.\n{progress}\n\nFinal board:\n{board}"
return f"Move {move}: {status}\n{progress}\n\nBoard:\n{board}{hints}"
# ── Reward properties ──
@property
def correct_reward(self) -> float:
return self._correct_scores[-1] if self._correct_scores else 0.0
@property
def valid_move_reward(self) -> float:
return sum(self._valid_move_scores) / len(self._valid_move_scores) if self._valid_move_scores else 0.0
@property
def empty_cell_reward(self) -> float:
return sum(self._empty_cell_scores) / len(self._empty_cell_scores) if self._empty_cell_scores else 0.0
@property
def repetition_reward(self) -> float:
return sum(self._repetition_scores) / len(self._repetition_scores) if self._repetition_scores else 0.0
@property
def progress_reward(self) -> float:
remaining_to_fill = 81 - self._initial_filled
if remaining_to_fill > 0:
return (self._max_filled - self._initial_filled) / remaining_to_fill
return 1.0
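The repetition penalty inside `place` doubles with each repeat of the same move and is capped at 10. Pulling that one-liner out makes the schedule easy to inspect (a standalone sketch of the formula used above):

```python
# Repetition penalty schedule from SudokuEnv.place: the first play of a move
# is free, each repeat doubles the penalty, capped at -10.
def repetition_penalty(repetition_count: int) -> float:
    return float(-min(2 ** (repetition_count - 1), 10.0)) if repetition_count > 0 else 0.0

print([repetition_penalty(n) for n in range(6)])
```

So a model that loops on the same move quickly accumulates large negative reward, which is exactly the behavior the `repetition_reward` property averages over the episode.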
The reward functions receive the list of environment instances after each episode completes. Since the SudokuEnv tracks multiple reward signals as properties, we simply read them out.
Each reward function captures a different aspect of play quality:
def reward_empty_cell(environments, **kwargs) -> list[float]:
"""Reward for targeting empty cells."""
return [env.empty_cell_reward for env in environments]
def reward_valid_moves(environments, **kwargs) -> list[float]:
"""Reward for making valid moves."""
return [env.valid_move_reward for env in environments]
def reward_repetition(environments, **kwargs) -> list[float]:
"""Penalty for repeating moves."""
return [env.repetition_reward for env in environments]
def reward_progress(environments, **kwargs) -> list[float]:
"""Reward for filling more cells in the board."""
return [env.progress_reward for env in environments]
def reward_correct(environments, **kwargs) -> list[float]:
"""Reward for solving the puzzle."""
return [env.correct_reward for env in environments]
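To see the call shape these functions follow (one score per environment in the generation batch), here is a quick stub run; `StubEnv` is a hypothetical stand-in for a finished episode, since the reward functions only read properties:

```python
# Stub stand-in for a finished SudokuEnv episode: reward functions just read
# properties, so any object exposing them works for a shape check.
class StubEnv:
    progress_reward = 0.25

def reward_progress(environments, **kwargs) -> list[float]:
    return [env.progress_reward for env in environments]

# One score per rollout in the batch:
print(reward_progress([StubEnv(), StubEnv()]))  # [0.25, 0.25]
```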
We create a dataset with repeated prompts to control the number of training episodes. Each entry triggers one rollout episode during training. The prompt is formatted as a chat message with the system prompt.
from datasets import Dataset
dataset_size = 3000
dataset = Dataset.from_dict({
"prompt": [[
{"role": "system", "content": system_prompt},
{"role": "user", "content": "Play Sudoku like an expert."},
]] * dataset_size
})
Next, we define the GRPOConfig, which controls all key training parameters. This configuration specifies how the model interacts with vLLM, manages memory, and logs results.
Note the chat_template_kwargs={"enable_thinking": False} parameter -- this disables Qwen3's thinking mode so the model responds directly with tool calls instead of generating internal reasoning tokens first.
from trl import GRPOConfig
model_name = "Qwen/Qwen3-1.7B"
output_dir = "sudoku-grpo-Qwen3-1.7B"
grpo_config = GRPOConfig(
# Training schedule / optimization
num_train_epochs=1,
learning_rate=5e-6,
gradient_accumulation_steps=64,
per_device_train_batch_size=1,
warmup_steps=20,
optim="adamw_torch",
max_grad_norm=1.0,
# GRPO configuration
num_generations=2,
max_completion_length=16384,
log_completions=True,
num_completions_to_print=2,
chat_template_kwargs={"enable_thinking": False},
# vLLM configuration
use_vllm=True,
vllm_mode="colocate",
vllm_gpu_memory_utilization=0.15,
# Logging / reporting
output_dir=output_dir,
report_to="trackio",
trackio_space_id=output_dir,
logging_steps=1,
save_steps=10,
save_total_limit=1,
# Hub integration
push_to_hub=True,
# Sampling
temperature=0.8,
top_k=10,
)
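As rough effective-batch arithmetic for this config (assuming a single GPU; exact semantics depend on the TRL version), each optimizer step accumulates 64 completions, which GRPO groups into sets of `num_generations` for relative scoring:

```python
# Rough effective-batch arithmetic for the GRPOConfig above (single-GPU assumption).
per_device_train_batch_size = 1
gradient_accumulation_steps = 64
num_generations = 2

completions_per_step = per_device_train_batch_size * gradient_accumulation_steps
prompt_groups_per_step = completions_per_step // num_generations
print(completions_per_step, prompt_groups_per_step)  # 64 completions, 32 groups
```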
Now we initialize the `GRPOTrainer` with `environment_factory=SudokuEnv` and start training.
This tells the trainer to automatically handle the entire interaction loop:
- Creates a fresh `SudokuEnv` instance for each episode.
- Parses the model's tool calls (`place`) and steps through the environment.
- Computes the `tool_mask` (which tokens are model-generated vs environment-generated) automatically.

No need to write a custom `rollout_func` or manage tokenization manually.
from trl import GRPOTrainer
trainer = GRPOTrainer(
model=model_name,
reward_funcs=[
reward_empty_cell,
reward_valid_moves,
reward_repetition,
reward_progress,
reward_correct,
],
train_dataset=dataset,
args=grpo_config,
environment_factory=SudokuEnv,
)
Show memory stats before training
import torch
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")
And train!
trainer_stats = trainer.train()
Show memory stats after training
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
used_memory_for_training = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory / max_memory * 100, 3)
training_memory_percentage = round(used_memory_for_training / max_memory * 100, 3)
print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
print(f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.")
print(f"Peak reserved memory = {used_memory} GB.")
print(f"Peak reserved memory for training = {used_memory_for_training} GB.")
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
print(f"Peak reserved memory for training % of max memory = {training_memory_percentage} %.")
trainer.save_model(output_dir)
trainer.push_to_hub()
Now let's test our fine-tuned model by loading it and playing a game of Sudoku.
We use the same SudokuEnv class to interact with the environment, and generate model responses with standard Transformers inference.
from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "sergiopaniego/sudoku-grpo-Qwen3-1.7B" # Replace with your HF username or organization
fine_tuned_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="float32", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
import json
def play_sudoku(model, tokenizer):
env = SudokuEnv()
initial_observation = env.reset()
print("Initial observation:")
print(initial_observation)
print()
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": "Play Sudoku like an expert."},
]
if initial_observation:
messages.append({"role": "user", "content": initial_observation})
for turn in range(20): # Play up to 20 turns
if env.done:
break
prompt_text = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=False,
enable_thinking=False,
)
model_inputs = tokenizer([prompt_text], return_tensors="pt").to(model.device)
generated_ids = model.generate(**model_inputs, max_new_tokens=512)
output_ids = generated_ids[0][len(model_inputs.input_ids[0]):]
generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)
print(f"Turn {turn + 1} - Model output: {generated_text}")
try:
# Try to parse tool call arguments
if "place" in generated_text:
start = generated_text.index("{")
end = generated_text.rindex("}") + 1
args = json.loads(generated_text[start:end])
if "arguments" in args:
args = args["arguments"]
row = int(args.get("row", 0))
col = int(args.get("col", 0))
number = int(args.get("number", 0))
else:
# Fallback: extract [row col number] pattern
import re
match = re.search(r"\[(\d)\s+(\d)\s+(\d)\]", generated_text)
if match:
row, col, number = int(match.group(1)), int(match.group(2)), int(match.group(3))
else:
print(" Could not parse move.")
break
feedback = env.place(row, col, number)
print(f" Move: [{row} {col} {number}]")
print(f" Progress reward: {env.progress_reward:.2f}")
print()
messages.append({"role": "assistant", "content": generated_text})
messages.append({"role": "user", "content": feedback})
except Exception as e:
print(f" Error: {e}")
break
print(f"Game finished! Correct reward: {env.correct_reward}")
print(f"Progress: {env.progress_reward:.2f}")
print(f"Done: {env.done}")
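The move parsing inside `play_sudoku` has two paths: JSON tool-call arguments first, then a `[row col number]` fallback. Extracted into a standalone helper (a sketch mirroring the logic above, not a separate API), it is easy to test in isolation:

```python
import json
import re

def parse_move(text: str):
    """Parse a move from model output: JSON tool-call args, else [r c n] pattern."""
    if "place" in text and "{" in text:
        args = json.loads(text[text.index("{"): text.rindex("}") + 1])
        args = args.get("arguments", args)  # unwrap {"arguments": {...}} if present
        return int(args["row"]), int(args["col"]), int(args["number"])
    match = re.search(r"\[(\d)\s+(\d)\s+(\d)\]", text)
    if match:
        return tuple(int(g) for g in match.groups())
    return None

print(parse_move('place {"arguments": {"row": 3, "col": 5, "number": 7}}'))  # (3, 5, 7)
print(parse_move("I will play [1 2 9] next."))  # (1, 2, 9)
```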
Let's play the game!
play_sudoku(fine_tuned_model, tokenizer)