scientific-skills/stable-baselines3/references/callbacks.md
This document provides comprehensive information about the callback system in Stable Baselines3 for monitoring and controlling training.
Callbacks are functions called at specific points during training to:

- Monitor training metrics and log custom values
- Save model checkpoints and the best-performing model
- Stop training early when a condition is met
- Adjust hyperparameters or environment difficulty during training
## EvalCallback

Evaluates the agent periodically and saves the best model.
```python
from stable_baselines3.common.callbacks import EvalCallback

eval_callback = EvalCallback(
    eval_env,                                    # Separate evaluation environment
    best_model_save_path="./logs/best_model/",   # Where to save best model
    log_path="./logs/eval/",                     # Where to save evaluation logs
    eval_freq=10000,                             # Evaluate every N steps
    n_eval_episodes=5,                           # Number of episodes per evaluation
    deterministic=True,                          # Use deterministic actions
    render=False,                                # Render during evaluation
    verbose=1,
    warn=True,
)

model.learn(total_timesteps=100000, callback=eval_callback)
```
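After training, the best model can be reloaded from `best_model_save_path`. A minimal sketch, assuming the model is a PPO instance:

```python
from stable_baselines3 import PPO

# EvalCallback writes best_model.zip inside best_model_save_path
best_model = PPO.load("./logs/best_model/best_model")
```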
**Key Features:**

- Saves the best model (by mean evaluation reward) as `best_model.zip` in `best_model_save_path`
- Records evaluation results (`evaluations.npz`) in `log_path`
- Can trigger child callbacks via `callback_on_new_best` and `callback_after_eval` (see below)
**Important:** When using vectorized training environments, adjust `eval_freq`:

```python
# With 4 parallel environments, divide eval_freq by n_envs
eval_freq = 10000 // 4  # Evaluate every 10000 total environment steps
```
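For example, a minimal end-to-end sketch with 4 parallel environments (the environment id and timestep counts are placeholders):

```python
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.env_util import make_vec_env

n_envs = 4
train_env = make_vec_env("CartPole-v1", n_envs=n_envs)
eval_env = make_vec_env("CartPole-v1", n_envs=1)

eval_callback = EvalCallback(
    eval_env,
    eval_freq=max(10000 // n_envs, 1),  # ~10000 total environment steps between evaluations
    n_eval_episodes=5,
)

model = PPO("MlpPolicy", train_env)
model.learn(total_timesteps=100_000, callback=eval_callback)
```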
## CheckpointCallback

Saves model checkpoints at regular intervals.
```python
from stable_baselines3.common.callbacks import CheckpointCallback

checkpoint_callback = CheckpointCallback(
    save_freq=10000,                   # Save every N steps
    save_path="./logs/checkpoints/",   # Directory for checkpoints
    name_prefix="rl_model",            # Prefix for checkpoint files
    save_replay_buffer=True,           # Save replay buffer (off-policy only)
    save_vecnormalize=True,            # Save VecNormalize stats
    verbose=2,
)

model.learn(total_timesteps=100000, callback=checkpoint_callback)
```
**Output Files:**

- `rl_model_10000_steps.zip` - Model at 10k steps
- `rl_model_20000_steps.zip` - Model at 20k steps

**Important:** Adjust `save_freq` for vectorized environments (divide by `n_envs`).
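Checkpoints can be restored later. A minimal sketch, assuming an off-policy algorithm such as SAC (replay buffer and VecNormalize file names follow the `name_prefix` pattern; exact names can vary by SB3 version):

```python
from stable_baselines3 import SAC
from stable_baselines3.common.vec_env import VecNormalize

# Restore the model from a checkpoint
model = SAC.load("./logs/checkpoints/rl_model_10000_steps")

# Off-policy algorithms can also restore the saved replay buffer
model.load_replay_buffer("./logs/checkpoints/rl_model_replay_buffer_10000_steps")

# Restore VecNormalize statistics (venv is your training VecEnv)
# env = VecNormalize.load("./logs/checkpoints/rl_model_vecnormalize_10000_steps.pkl", venv)
```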
## StopTrainingOnRewardThreshold

Stops training when the mean evaluation reward exceeds a threshold.
```python
from stable_baselines3.common.callbacks import StopTrainingOnRewardThreshold

stop_callback = StopTrainingOnRewardThreshold(
    reward_threshold=200,  # Stop when mean reward >= 200
    verbose=1,
)

# Must be used with EvalCallback
eval_callback = EvalCallback(
    eval_env,
    callback_on_new_best=stop_callback,  # Triggered when a new best is found
    eval_freq=10000,
    n_eval_episodes=5,
)

model.learn(total_timesteps=1000000, callback=eval_callback)
```
## StopTrainingOnNoModelImprovement

Stops training if the model doesn't improve for N consecutive evaluations.
```python
from stable_baselines3.common.callbacks import StopTrainingOnNoModelImprovement

stop_callback = StopTrainingOnNoModelImprovement(
    max_no_improvement_evals=10,  # Stop after 10 evals with no improvement
    min_evals=20,                 # Minimum evaluations before stopping
    verbose=1,
)

# Use with EvalCallback
eval_callback = EvalCallback(
    eval_env,
    callback_after_eval=stop_callback,
    eval_freq=10000,
)

model.learn(total_timesteps=1000000, callback=eval_callback)
```
## StopTrainingOnMaxEpisodes

Stops training after a maximum number of episodes (counted per environment when using vectorized environments).
```python
from stable_baselines3.common.callbacks import StopTrainingOnMaxEpisodes

stop_callback = StopTrainingOnMaxEpisodes(
    max_episodes=1000,  # Stop after 1000 episodes
    verbose=1,
)

model.learn(total_timesteps=1000000, callback=stop_callback)
```
## ProgressBarCallback

Displays a progress bar during training (requires `tqdm` and `rich`).
```python
from stable_baselines3.common.callbacks import ProgressBarCallback

progress_callback = ProgressBarCallback()

model.learn(total_timesteps=100000, callback=progress_callback)
```
**Output:**

```
100%|██████████| 100000/100000 [05:23<00:00, 309.31it/s]
```
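Recent SB3 versions also expose this directly through `learn()`, without creating the callback yourself:

```python
model.learn(total_timesteps=100000, progress_bar=True)
```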
## Custom Callbacks

Create custom callbacks by subclassing `BaseCallback`:

```python
from stable_baselines3.common.callbacks import BaseCallback


class CustomCallback(BaseCallback):
    """
    Custom callback template.
    """

    def __init__(self, verbose=0):
        super().__init__(verbose)
        # Custom initialization
        pass

    def _init_callback(self) -> None:
        """
        Called once when training starts.
        Useful for initialization that requires access to model/env.
        """
        pass

    def _on_training_start(self) -> None:
        """
        Called before the first rollout starts.
        """
        pass

    def _on_rollout_start(self) -> None:
        """
        Called before collecting new samples (on-policy algorithms).
        """
        pass

    def _on_step(self) -> bool:
        """
        Called after every step in the environment.

        Returns:
            bool: If False, training will be stopped.
        """
        return True  # Continue training

    def _on_rollout_end(self) -> None:
        """
        Called after rollout ends (on-policy algorithms).
        """
        pass

    def _on_training_end(self) -> None:
        """
        Called at the end of training.
        """
        pass
```
### Available Attributes

Inside callbacks, you have access to:

- `self.model`: The RL algorithm instance
- `self.training_env`: The training environment
- `self.n_calls`: Number of times `_on_step()` was called
- `self.num_timesteps`: Total number of environment steps
- `self.locals`: Local variables from the algorithm (varies by algorithm)
- `self.globals`: Global variables from the algorithm
- `self.logger`: Logger for TensorBoard/CSV logging
- `self.parent`: Parent callback (if used in `CallbackList`)

### Example: Logging Custom Metrics

```python
import numpy as np

from stable_baselines3.common.callbacks import BaseCallback


class LogCustomMetricsCallback(BaseCallback):
    """
    Log custom metrics to TensorBoard.
    """

    def __init__(self, verbose=0):
        super().__init__(verbose)
        self.episode_rewards = []

    def _on_step(self) -> bool:
        # Check if episode ended (first environment only)
        if self.locals["dones"][0]:
            # Episode stats are added to info by the Monitor wrapper
            episode_reward = self.locals["infos"][0].get("episode", {}).get("r", 0)
            self.episode_rewards.append(episode_reward)
            # Log to TensorBoard
            self.logger.record("custom/episode_reward", episode_reward)
            self.logger.record("custom/mean_reward_last_100",
                               np.mean(self.episode_rewards[-100:]))
        return True
```
### Example: Learning Rate Schedule

```python
class LinearScheduleCallback(BaseCallback):
    """
    Linearly decrease the learning rate during training.
    """

    def __init__(self, initial_lr=3e-4, final_lr=3e-5, verbose=0):
        super().__init__(verbose)
        self.initial_lr = initial_lr
        self.final_lr = final_lr

    def _on_step(self) -> bool:
        # Calculate progress (0 to 1)
        progress = self.num_timesteps / self.locals["total_timesteps"]
        # Linear interpolation
        new_lr = self.initial_lr + (self.final_lr - self.initial_lr) * progress
        # Update learning rate
        for param_group in self.model.policy.optimizer.param_groups:
            param_group["lr"] = new_lr
        # Log learning rate
        self.logger.record("train/learning_rate", new_lr)
        return True
```
### Example: Early Stopping on Reward

```python
class EarlyStoppingCallback(BaseCallback):
    """
    Stop training once the moving average of episode rewards
    reaches a threshold.
    """

    def __init__(self, check_freq=10000, min_reward=200, window=100, verbose=0):
        super().__init__(verbose)
        self.check_freq = check_freq
        self.min_reward = min_reward
        self.window = window
        self.rewards = []

    def _on_step(self) -> bool:
        # Collect episode rewards (first environment only)
        if self.locals["dones"][0]:
            reward = self.locals["infos"][0].get("episode", {}).get("r", 0)
            self.rewards.append(reward)

        # Check every check_freq steps
        if self.n_calls % self.check_freq == 0 and len(self.rewards) >= self.window:
            mean_reward = np.mean(self.rewards[-self.window:])
            if self.verbose > 0:
                print(f"Mean reward: {mean_reward:.2f}")
            if mean_reward >= self.min_reward:
                if self.verbose > 0:
                    print("Stopping: reward threshold reached!")
                return False  # Stop training

        return True  # Continue training
```
### Example: Save on a Custom Metric

```python
import os


class SaveBestModelCallback(BaseCallback):
    """
    Save the model when a custom metric is at its best.
    """

    def __init__(self, check_freq=1000, save_path="./best_model/", verbose=0):
        super().__init__(verbose)
        self.check_freq = check_freq
        self.save_path = save_path
        self.best_score = -np.inf

    def _init_callback(self) -> None:
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self) -> bool:
        if self.n_calls % self.check_freq == 0:
            # Calculate custom metric (illustrative; available keys vary by algorithm)
            custom_metric = self.locals.get("entropy_losses", [0])[-1]
            if custom_metric > self.best_score:
                self.best_score = custom_metric
                if self.verbose > 0:
                    print(f"New best! Saving model to {self.save_path}")
                self.model.save(os.path.join(self.save_path, "best_model"))
        return True
```
### Example: Logging Environment Info

```python
class EnvironmentInfoCallback(BaseCallback):
    """
    Log custom info from the environment.
    """

    def _on_step(self) -> bool:
        # Access the info dict from the first environment
        info = self.locals["infos"][0]

        # Log custom metrics reported by the environment
        if "distance_to_goal" in info:
            self.logger.record("env/distance_to_goal", info["distance_to_goal"])
        if "success" in info:
            self.logger.record("env/success_rate", info["success"])

        return True
```
## Combining Multiple Callbacks

Use `CallbackList` to combine multiple callbacks:
```python
from stable_baselines3.common.callbacks import CallbackList

callback_list = CallbackList([
    eval_callback,
    checkpoint_callback,
    progress_callback,
    custom_callback,
])

model.learn(total_timesteps=100000, callback=callback_list)
```
Or pass a list directly:
```python
model.learn(
    total_timesteps=100000,
    callback=[eval_callback, checkpoint_callback, custom_callback],
)
```
## Event Callbacks

Callbacks can trigger other callbacks on specific events:
```python
from stable_baselines3.common.callbacks import EventCallback

# Stop training when the reward threshold is reached
stop_callback = StopTrainingOnRewardThreshold(reward_threshold=200)

# Evaluate periodically and trigger stop_callback when a new best is found
eval_callback = EvalCallback(
    eval_env,
    callback_on_new_best=stop_callback,  # Triggered when a new best model is found
    eval_freq=10000,
)
```
## TensorBoard Logging

Use `self.logger.record()` to log metrics:
```python
class TensorBoardCallback(BaseCallback):
    def _on_step(self) -> bool:
        # Log a scalar (value, value1, value2 are computed by your callback)
        self.logger.record("custom/my_metric", value)

        # Log multiple metrics
        self.logger.record("custom/metric1", value1)
        self.logger.record("custom/metric2", value2)

        # The logger automatically writes to TensorBoard when it dumps
        return True
```
View in TensorBoard:

```bash
tensorboard --logdir ./logs/
```
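The logger can also record non-scalar data such as matplotlib figures. A minimal sketch (requires matplotlib; `Figure` is SB3's logger wrapper, and the `exclude` tuple skips text-based outputs that cannot render figures):

```python
import matplotlib.pyplot as plt
import numpy as np

from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import Figure


class FigureCallback(BaseCallback):
    def _on_step(self) -> bool:
        if self.n_calls % 10000 == 0:
            fig = plt.figure()
            fig.add_subplot().plot(np.random.random(10))  # placeholder data
            # Only TensorBoard can render figures, so exclude text backends
            self.logger.record("custom/figure", Figure(fig, close=True),
                               exclude=("stdout", "log", "json", "csv"))
            plt.close(fig)
        return True
```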
## Advanced Patterns

### Curriculum Learning

```python
class CurriculumCallback(BaseCallback):
    """
    Increase task difficulty over time.
    Assumes the environment implements a set_difficulty() method.
    """

    def __init__(self, difficulty_schedule, verbose=0):
        super().__init__(verbose)
        # List of (progress_threshold, difficulty) pairs
        self.difficulty_schedule = difficulty_schedule

    def _on_step(self) -> bool:
        # Update environment difficulty based on training progress
        progress = self.num_timesteps / self.locals["total_timesteps"]
        for threshold, difficulty in self.difficulty_schedule:
            if progress >= threshold:
                self.training_env.env_method("set_difficulty", difficulty)
        return True
```
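The callback above reaches into the environment via `VecEnv.env_method`, so the environment must expose `set_difficulty()`. A hypothetical minimal environment sketch (the class name, spaces, and dynamics are placeholders):

```python
import gymnasium as gym
import numpy as np


class DifficultyEnv(gym.Env):
    """Hypothetical environment with adjustable difficulty."""

    def __init__(self):
        self.observation_space = gym.spaces.Box(-1.0, 1.0, shape=(4,), dtype=np.float32)
        self.action_space = gym.spaces.Discrete(2)
        self.difficulty = 0

    def set_difficulty(self, difficulty):
        # Invoked via training_env.env_method("set_difficulty", difficulty)
        self.difficulty = difficulty

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        return self.observation_space.sample(), {}

    def step(self, action):
        # Placeholder dynamics; difficulty would shape rewards/transitions here
        obs = self.observation_space.sample()
        return obs, 0.0, False, False, {}
```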
### Hyperparameter Adjustment

```python
class PopulationBasedCallback(BaseCallback):
    """
    Adjust hyperparameters based on performance.
    """

    def __init__(self, check_freq=10000, verbose=0):
        super().__init__(verbose)
        self.check_freq = check_freq
        self.performance_history = []

    def _on_step(self) -> bool:
        if self.n_calls % self.check_freq == 0:
            # Evaluate performance
            perf = self._evaluate_performance()
            self.performance_history.append(perf)

            # Adjust hyperparameters if performance plateaus
            if len(self.performance_history) >= 3:
                recent = self.performance_history[-3:]
                if max(recent) - min(recent) < 0.01:  # Plateau detected
                    self._adjust_hyperparameters()
        return True

    def _evaluate_performance(self):
        # Example metric: mean reward over recent episodes (requires Monitor wrapper)
        ep_info_buffer = self.model.ep_info_buffer
        if len(ep_info_buffer) == 0:
            return 0.0
        return float(np.mean([ep_info["r"] for ep_info in ep_info_buffer]))

    def _adjust_hyperparameters(self):
        # Example: increase learning rate
        for param_group in self.model.policy.optimizer.param_groups:
            param_group["lr"] *= 1.2
```
## Debugging Callbacks

Print what's available in `self.locals` on the first step:

```python
class DebugCallback(BaseCallback):
    def _on_step(self) -> bool:
        if self.n_calls == 1:
            print("Available in self.locals:")
            for key in self.locals.keys():
                print(f"  {key}: {type(self.locals[key])}")
        return True
```
## Troubleshooting

**Callback not being called:**
- Make sure the callback is passed to `model.learn()`
- Make sure `_on_step()` returns `True` (returning `False` stops training)

**AttributeError in callback:**
- Keys in `self.locals` vary by algorithm; use `self.locals.get("key", default)` for safety

**Memory leaks:**
- Avoid storing unbounded histories in the callback; trim lists to a fixed window

**Performance impact:**
- Keep `_on_step()` lightweight (it is called every step)
- Use `check_freq` to limit expensive operations

## Best Practices

- **Use appropriate callback timing:**
  - `_on_step()`: For metrics that change every step
  - `_on_rollout_end()`: For metrics computed over rollouts
  - `_init_callback()`: For one-time initialization
- **Log efficiently:** record values with `self.logger.record()` rather than printing, and avoid expensive computation on every step
- **Handle vectorized environments:** `dones`, `infos`, etc. are arrays; check `dones[i]` for each environment
- **Test callbacks independently:** verify a callback with a short run before a long experiment (see the sketch below)
- **Document custom callbacks:** include docstrings describing the callback's purpose and parameters
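For example, a quick smoke test for a custom callback (the environment id and step count are placeholders):

```python
from stable_baselines3 import PPO

# Short run on a toy environment to verify the callback behaves as expected
model = PPO("MlpPolicy", "CartPole-v1", verbose=0)
model.learn(total_timesteps=2048, callback=DebugCallback())
```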