# Custom Environments
This guide provides comprehensive information for creating custom Gymnasium environments compatible with Stable Baselines3.
Every custom environment must inherit from `gymnasium.Env` and implement the following methods:

```python
import gymnasium as gym
from gymnasium import spaces
import numpy as np


class CustomEnv(gym.Env):
    def __init__(self):
        """Initialize environment, define action_space and observation_space."""
        super().__init__()
        self.action_space = spaces.Discrete(4)
        self.observation_space = spaces.Box(low=0, high=1, shape=(4,), dtype=np.float32)

    def reset(self, seed=None, options=None):
        """Reset environment to initial state."""
        super().reset(seed=seed)
        observation = self.observation_space.sample()
        info = {}
        return observation, info

    def step(self, action):
        """Execute one timestep."""
        observation = self.observation_space.sample()
        reward = 0.0
        terminated = False  # Episode ended naturally
        truncated = False   # Episode ended due to time limit
        info = {}
        return observation, reward, terminated, truncated, info

    def render(self):
        """Visualize environment (optional)."""
        pass

    def close(self):
        """Cleanup resources (optional)."""
        pass
```
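A quick interaction check of the skeleton above:

```python
env = CustomEnv()
obs, info = env.reset(seed=0)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
```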
## Required Methods

### `__init__(self, ...)`

**Purpose:** Initialize the environment and define spaces.

**Requirements:**
- Call `super().__init__()`
- Define `self.action_space`
- Define `self.observation_space`

**Example:**
```python
def __init__(self, grid_size=10, max_steps=100):
    super().__init__()
    self.grid_size = grid_size
    self.max_steps = max_steps
    self.current_step = 0

    # Define spaces
    self.action_space = spaces.Discrete(4)
    self.observation_space = spaces.Box(
        low=0, high=grid_size - 1, shape=(2,), dtype=np.float32
    )
```
### `reset(self, seed=None, options=None)`

**Purpose:** Reset the environment to an initial state.

**Requirements:**
- Call `super().reset(seed=seed)`
- Return an `(observation, info)` tuple
- The observation must lie within `observation_space`

**Example:**
```python
def reset(self, seed=None, options=None):
    super().reset(seed=seed)

    # Initialize state
    self.agent_pos = self.np_random.integers(0, self.grid_size, size=2)
    self.goal_pos = self.np_random.integers(0, self.grid_size, size=2)
    self.current_step = 0

    observation = self._get_observation()
    info = {"episode": "started"}
    return observation, info
```
### `step(self, action)`

**Purpose:** Execute one timestep in the environment.

**Requirements:**
- Return an `(observation, reward, terminated, truncated, info)` tuple
- Accept any action within `action_space`; return an observation within `observation_space`

**Example:**
```python
def step(self, action):
    # Apply action
    self.agent_pos += self._action_to_direction(action)
    self.agent_pos = np.clip(self.agent_pos, 0, self.grid_size - 1)
    self.current_step += 1

    # Calculate reward
    distance = np.linalg.norm(self.agent_pos - self.goal_pos)
    goal_reached = distance < 1.0
    if goal_reached:
        reward = 100.0
    else:
        reward = -distance * 0.1

    # Check termination conditions
    terminated = goal_reached
    truncated = self.current_step >= self.max_steps

    observation = self._get_observation()
    info = {"distance": distance, "steps": self.current_step}
    return observation, reward, terminated, truncated, info
```
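The `reset` and `step` examples call two private helpers that are not shown; a minimal sketch of what they might look like (the exact direction mapping is an assumption):

```python
def _action_to_direction(self, action):
    # Hypothetical mapping: 0=right, 1=left, 2=down, 3=up
    directions = np.array([[1, 0], [-1, 0], [0, -1], [0, 1]])
    return directions[int(action)]

def _get_observation(self):
    # Agent position as float32, matching the declared Box space
    return self.agent_pos.astype(np.float32)
```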
## Spaces

### Discrete

For discrete actions (e.g., {0, 1, 2, 3}).

```python
self.action_space = spaces.Discrete(4)  # 4 actions: 0, 1, 2, 3
```

**Important:** SB3 does NOT support `Discrete` spaces with `start != 0`. Always start from 0.
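If an underlying simulator numbers its actions from 1, declare a 0-based space and shift inside `step()`. A minimal sketch (`ShiftedActionEnv` and the {1..4} action range are hypothetical):

```python
class ShiftedActionEnv(gym.Env):
    """Hypothetical simulator that expects actions in {1, 2, 3, 4}."""

    def __init__(self):
        super().__init__()
        self.action_space = spaces.Discrete(4)  # SB3 sees {0, 1, 2, 3}
        self.observation_space = spaces.Box(0, 1, shape=(2,), dtype=np.float32)

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        return np.zeros(2, dtype=np.float32), {}

    def step(self, action):
        sim_action = int(action) + 1  # shift into the simulator's {1, ..., 4}
        obs = self.observation_space.sample()
        return obs, 0.0, False, False, {"sim_action": sim_action}
```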
### Box

For continuous values within a range.

```python
# 1D continuous action in [-1, 1]
self.action_space = spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float32)

# 2D position observation
self.observation_space = spaces.Box(
    low=0, high=10, shape=(2,), dtype=np.float32
)

# 3D RGB image (channel-first format)
self.observation_space = spaces.Box(
    low=0, high=255, shape=(3, 84, 84), dtype=np.uint8
)
```
**Important for images:**
- Use `dtype=np.uint8` with values in range [0, 255]
- Set `normalize_images=False` in `policy_kwargs` if observations are pre-normalized

### MultiDiscrete

For multiple discrete variables.

```python
# Two discrete variables: first with 3 options, second with 4 options
self.action_space = spaces.MultiDiscrete([3, 4])
```
### MultiBinary

For binary vectors.

```python
# 5 binary flags
self.action_space = spaces.MultiBinary(5)  # e.g., [0, 1, 1, 0, 1]
```
### Dict

For dictionary observations (e.g., combining an image with sensor readings).

```python
self.observation_space = spaces.Dict({
    "image": spaces.Box(low=0, high=255, shape=(3, 64, 64), dtype=np.uint8),
    "vector": spaces.Box(low=-10, high=10, shape=(4,), dtype=np.float32),
    "discrete": spaces.Discrete(3),
})
```

**Important:** When using Dict observations, use `"MultiInputPolicy"` instead of `"MlpPolicy"`.

```python
model = PPO("MultiInputPolicy", env, verbose=1)
```
### Tuple

For tuple observations (less common). Note that SB3 does not support Tuple observation spaces; prefer an equivalent Dict (see the sketch below).

```python
self.observation_space = spaces.Tuple((
    spaces.Box(low=0, high=1, shape=(4,), dtype=np.float32),
    spaces.Discrete(3),
))
```
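An equivalent `Dict` is the usual workaround (the key names here are arbitrary):

```python
# Same components as the Tuple above, usable with "MultiInputPolicy"
self.observation_space = spaces.Dict({
    "continuous": spaces.Box(low=0, high=1, shape=(4,), dtype=np.float32),
    "discrete": spaces.Discrete(3),
})
```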
### Data Types

- Continuous observations: `np.float32`
- Images: `np.uint8` in range [0, 255]
- Continuous actions: `np.float32`

## Random Number Generation

Always use `self.np_random` for reproducibility:
```python
def reset(self, seed=None, options=None):
    super().reset(seed=seed)

    # Use self.np_random instead of np.random
    random_pos = self.np_random.integers(0, 10, size=2)
    random_float = self.np_random.random()
```
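As long as all randomness flows through `self.np_random`, resetting with the same seed reproduces the same initial state. A quick sanity check (assuming an environment whose `reset` draws its state from `self.np_random`, like the grid-world example above):

```python
env = CustomEnv()  # assumption: CustomEnv's reset uses self.np_random
obs_a, _ = env.reset(seed=42)
obs_b, _ = env.reset(seed=42)
assert np.allclose(obs_a, obs_b)  # identical seeds give identical initial observations
```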
## Terminated vs. Truncated

`terminated` signals a natural episode end (goal reached, failure state), while `truncated` signals an artificial cutoff such as a time limit. The distinction matters because algorithms may bootstrap the value of the final state on truncation but not on termination, so set the flags correctly:

```python
def step(self, action):
    # ... environment logic ...
    goal_reached = self._check_goal()
    time_limit_exceeded = self.current_step >= self.max_steps

    terminated = goal_reached          # Natural ending
    truncated = time_limit_exceeded    # Time limit

    return observation, reward, terminated, truncated, info
```
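Alternatively, Gymnasium's `TimeLimit` wrapper can handle truncation for you (and `gym.make` applies it automatically when `max_episode_steps` is registered):

```python
from gymnasium.wrappers import TimeLimit

env = TimeLimit(CustomEnv(), max_episode_steps=100)  # sets truncated=True at step 100
```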
## The Info Dictionary

Use the `info` dict for debugging and logging:

```python
info = {
    "episode_length": self.current_step,
    "distance_to_goal": distance,
    "success": goal_reached,
    "total_reward": self.cumulative_reward,
}
```
**Special keys:**
- `"terminal_observation"`: automatically added by VecEnv when an episode ends

## Rendering

Provide rendering information:
```python
class CustomEnv(gym.Env):
    metadata = {
        "render_modes": ["human", "rgb_array"],
        "render_fps": 30,
    }

    def __init__(self, render_mode=None):
        super().__init__()
        self.render_mode = render_mode
        # ...

    def render(self):
        if self.render_mode == "human":
            # Print or display for human viewing
            print(f"Agent at {self.agent_pos}")
        elif self.render_mode == "rgb_array":
            # Return numpy array (height, width, 3) for video recording
            canvas = np.zeros((500, 500, 3), dtype=np.uint8)
            # Draw environment on canvas
            return canvas
```
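With an `rgb_array` mode and `render_fps` declared in `metadata`, video recording works via Gymnasium's `RecordVideo` wrapper; a sketch (the output folder is arbitrary):

```python
from gymnasium.wrappers import RecordVideo

env = RecordVideo(CustomEnv(render_mode="rgb_array"), video_folder="videos")
```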
## Goal-Conditioned Environments (HER)

For Hindsight Experience Replay, the observation must be a Dict with this specific structure:

```python
self.observation_space = spaces.Dict({
    "observation": spaces.Box(low=-10, high=10, shape=(3,), dtype=np.float32),
    "achieved_goal": spaces.Box(low=-10, high=10, shape=(3,), dtype=np.float32),
    "desired_goal": spaces.Box(low=-10, high=10, shape=(3,), dtype=np.float32),
})
```

The environment must also expose `compute_reward()`. HER calls it with batched goals, so compute norms along the last axis:

```python
def compute_reward(self, achieved_goal, desired_goal, info):
    """Required for HER environments; must handle batched inputs."""
    distance = np.linalg.norm(achieved_goal - desired_goal, axis=-1)
    return -distance
```
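A goal-conditioned environment like this can then be trained with SB3's `HerReplayBuffer`; a minimal sketch with SAC (the hyperparameters are illustrative):

```python
from stable_baselines3 import SAC, HerReplayBuffer

model = SAC(
    "MultiInputPolicy",
    env,
    replay_buffer_class=HerReplayBuffer,
    replay_buffer_kwargs=dict(n_sampled_goal=4, goal_selection_strategy="future"),
    verbose=1,
)
model.learn(total_timesteps=10_000)
```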
## Environment Validation

Always validate your environment before training:

```python
from stable_baselines3.common.env_checker import check_env

env = CustomEnv()
check_env(env, warn=True)
```
**Common validation errors:**
- "Observation is not within bounds"
- "Reset should return tuple": return `(observation, info)`, not just `observation`
- "Step should return 5-tuple": return `(obs, reward, terminated, truncated, info)`
- "Action is out of bounds"
- "Observation/Action dtype mismatch"
## Registration

Register your environment with Gymnasium:

```python
import gymnasium as gym
from gymnasium.envs.registration import register

register(
    id="MyCustomEnv-v0",
    entry_point="my_module:CustomEnv",
    max_episode_steps=200,
    kwargs={"grid_size": 10},  # Default kwargs
)

# Now usable with gym.make
env = gym.make("MyCustomEnv-v0")
```
## Testing

Exercise the environment with random actions before training:

```python
def test_environment(env, n_episodes=5):
    """Test environment with random actions."""
    for episode in range(n_episodes):
        obs, info = env.reset()
        episode_reward = 0
        done = False
        steps = 0

        while not done:
            action = env.action_space.sample()
            obs, reward, terminated, truncated, info = env.step(action)
            episode_reward += reward
            steps += 1
            done = terminated or truncated

        print(f"Episode {episode + 1}: Reward={episode_reward:.2f}, Steps={steps}")
```
A quick training run then confirms the environment works end to end:

```python
from stable_baselines3 import PPO

def train_test(env, timesteps=10000):
    """Quick training test."""
    model = PPO("MlpPolicy", env, verbose=1)
    model.learn(total_timesteps=timesteps)

    # Evaluate
    obs, info = env.reset()
    for _ in range(100):
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, info = env.step(action)
        if terminated or truncated:
            break
```
## Common Patterns

### Grid World (discrete actions)

```python
class GridWorldEnv(gym.Env):
    def __init__(self, size=10):
        super().__init__()
        self.size = size
        self.action_space = spaces.Discrete(4)  # up, down, left, right
        self.observation_space = spaces.Box(0, size - 1, shape=(2,), dtype=np.float32)
```
### Continuous Control

```python
class ContinuousEnv(gym.Env):
    def __init__(self):
        super().__init__()
        self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(8,), dtype=np.float32)
```
### Vision-Based

```python
class VisionEnv(gym.Env):
    def __init__(self):
        super().__init__()
        self.action_space = spaces.Discrete(4)
        # Channel-first: (channels, height, width)
        self.observation_space = spaces.Box(
            low=0, high=255, shape=(3, 84, 84), dtype=np.uint8
        )
```
### Multi-Modal

```python
class MultiModalEnv(gym.Env):
    def __init__(self):
        super().__init__()
        self.action_space = spaces.Discrete(4)
        self.observation_space = spaces.Dict({
            "image": spaces.Box(0, 255, shape=(3, 64, 64), dtype=np.uint8),
            "sensors": spaces.Box(-10, 10, shape=(4,), dtype=np.float32),
        })
```
## Performance Tips

Pre-allocate arrays instead of creating new ones every step:

```python
def __init__(self):
    # ...
    self._obs_buffer = np.zeros(self.observation_space.shape, dtype=np.float32)

def _get_observation(self):
    # Reuse buffer instead of allocating a new array each step
    self._obs_buffer[0] = self.agent_x
    self._obs_buffer[1] = self.agent_y
    return self._obs_buffer
```
Make environment operations vectorizable:

```python
# Good: uses numpy operations
def step(self, action):
    direction = np.array([[0, 1], [0, -1], [1, 0], [-1, 0]])[action]
    self.pos = np.clip(self.pos + direction, 0, self.size - 1)

# Avoid: Python loops where a numpy operation would do
# for i in range(len(self.agents)):
#     self.agents[i].update()
```
**Debugging tips:**
- Assert rewards are finite: `assert np.isfinite(reward)`
- Use the `VecCheckNan` wrapper to catch NaN/inf issues (see the sketch below)
- Run `check_env()` before training
- See `scripts/custom_env_template.py` for a complete template
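For example, to fail fast on NaN/inf values, wrap the environment like this (a sketch using the `CustomEnv` above):

```python
from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan

vec_env = VecCheckNan(DummyVecEnv([lambda: CustomEnv()]), raise_exception=True)
```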