Back to ML-Agents

ML-Agents run with Stable Baselines 3

colab/Colab_UnityEnvironment_4_SB3VectorEnv.ipynb

0.15.18.7 KB
Original Source

ML-Agents run with Stable Baselines 3

Setup

python
#@title Install Rendering Dependencies { display-mode: "form" }
#@markdown (You only need to run this code when using Colab's hosted runtime)

import os
from IPython.display import HTML, display

def progress(value, max=100):
    """Build an HTML ``<progress>`` bar for display in the notebook output.

    Args:
        value: Current progress value (also shown as fallback text).
        max: Value that represents completion. The name shadows the builtin
            but is kept so existing keyword callers keep working.

    Returns:
        An IPython ``HTML`` object suitable for ``display(...)`` or a
        display handle's ``update(...)``.
    """
    return HTML("""
        <progress
            value='{value}'
            max='{max}'
            style='width: 100%'
        >
            {value}
        </progress>
    """.format(value=value, max=max))

# Show a progress bar and keep its display handle so later steps can update it
# in place via pro_bar.update(...).
pro_bar = display(progress(0, 100), display_id=True)

# Install the virtual frame buffer when running on Colab (google.colab is
# importable); elsewhere only if COLAB_ALWAYS_INSTALL_XVFB is set in the env.
try:
  import google.colab
  INSTALL_XVFB = True
except ImportError:
  INSTALL_XVFB = 'COLAB_ALWAYS_INSTALL_XVFB' in os.environ

if INSTALL_XVFB:
  # Write a small init-style shell script that starts/stops Xvfb on display :1
  # through start-stop-daemon, tracking the process in ./frame-buffer.pid.
  with open('frame-buffer', 'w') as writefile:
    writefile.write("""#taken from https://gist.github.com/jterrace/2911875
XVFB=/usr/bin/Xvfb
XVFBARGS=":1 -screen 0 1024x768x24 -ac +extension GLX +render -noreset"
PIDFILE=./frame-buffer.pid
case "$1" in
  start)
    echo -n "Starting virtual X frame buffer: Xvfb"
    /sbin/start-stop-daemon --start --quiet --pidfile $PIDFILE --make-pidfile --background --exec $XVFB -- $XVFBARGS
    echo "."
    ;;
  stop)
    echo -n "Stopping virtual X frame buffer: Xvfb"
    /sbin/start-stop-daemon --stop --quiet --pidfile $PIDFILE
    rm $PIDFILE
    echo "."
    ;;
  restart)
    $0 stop
    $0 start
    ;;
  *)
        echo "Usage: /etc/init.d/xvfb {start|stop|restart}"
        exit 1
esac
exit 0
    """)
  # The `!` lines below are IPython shell escapes: this cell only runs inside
  # a notebook, not as a plain Python script.
  !sudo apt-get update
  pro_bar.update(progress(10, 100))
  !sudo DEBIAN_FRONTEND=noninteractive apt install -y daemon wget gdebi-core build-essential libfontenc1 libfreetype6 xorg-dev xorg
  pro_bar.update(progress(20, 100))
  # NOTE(review): these .deb URLs are pinned to Ubuntu 16.04 (xenial) builds of
  # libxfont1 and xvfb; confirm they are still downloadable and compatible with
  # the current Colab base image.
  !wget http://security.ubuntu.com/ubuntu/pool/main/libx/libxfont/libxfont1_1.5.1-1ubuntu0.16.04.4_amd64.deb 2>&1
  pro_bar.update(progress(30, 100))
  !wget --output-document xvfb.deb http://security.ubuntu.com/ubuntu/pool/universe/x/xorg-server/xvfb_1.18.4-0ubuntu0.12_amd64.deb 2>&1
  pro_bar.update(progress(40, 100))
  !sudo dpkg -i libxfont1_1.5.1-1ubuntu0.16.04.4_amd64.deb 2>&1
  pro_bar.update(progress(50, 100))
  !sudo dpkg -i xvfb.deb 2>&1
  pro_bar.update(progress(70, 100))
  # Remove the downloaded packages once installed.
  !rm libxfont1_1.5.1-1ubuntu0.16.04.4_amd64.deb
  pro_bar.update(progress(80, 100))
  !rm xvfb.deb
  pro_bar.update(progress(90, 100))
  # Start Xvfb on display :1 and point X clients in this process at it.
  !bash frame-buffer start
  os.environ["DISPLAY"] = ":1"
pro_bar.update(progress(100, 100))

Installing ml-agents

python
# Install ml-agents only if the `mlagents` package is not already importable.
# NOTE(review): an existing install is accepted regardless of its version.
try:
  import mlagents
  print("ml-agents already installed")
except ImportError:
  # `!` is an IPython shell escape; -q keeps pip's output quiet.
  !python -m pip install -q mlagents==1.1.0
  print("Installed ml-agents")

Run the Environment

Import dependencies and set some high-level parameters.

python
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, Any

import gym
from gym import Env

from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import VecMonitor, VecEnv, SubprocVecEnv
from supersuit import observation_lambda_v0


from mlagents_envs.environment import UnityEnvironment
from mlagents_envs.envs.unity_gym_env import UnityToGymWrapper
from mlagents_envs.registry import UnityEnvRegistry, default_registry
from mlagents_envs.side_channel.engine_configuration_channel import (
    EngineConfig,
    EngineConfigurationChannel,
)

# How many Unity environment instances are run side by side during training.
NUM_ENVS: int = 8

Environment and Engine Configurations

python
# Default values from CLI (See cli_utils.py)
DEFAULT_ENGINE_CONFIG = EngineConfig(
    width=84,
    height=84,
    quality_level=4,
    time_scale=20,
    target_frame_rate=-1,
    capture_frame_rate=60,
)


# Some config subset of an actual config.yaml file for MLA.
@dataclass
class LimitedConfig:
    """Minimal subset of an ML-Agents run config used to build SB3 vector envs."""

    # The local path to a Unity executable or the name of an entry in the registry.
    env_path_or_name: str
    # Passed to every UnityEnvironment as ``base_port``; workers are told apart
    # by their ``worker_id``.
    base_port: int
    # Worker ``i`` is seeded with ``base_seed + i``.
    base_seed: int = 0
    # Number of environment instances to create.
    num_env: int = 1
    # Engine settings sent to each environment through a configuration channel.
    engine_config: EngineConfig = DEFAULT_ENGINE_CONFIG
    # Forwarded to UnityToGymWrapper as ``uint8_visual``.
    visual_obs: bool = False
    # TODO: Decide if we should just tell users to always use MultiInputPolicy so we can simplify the user workflow.
    # WARNING: Make sure to use MultiInputPolicy if you turn this on.
    allow_multiple_obs: bool = False
    # Registry consulted when env_path_or_name is not an existing local path.
    # A default_factory is used because the registry behaves like a Mapping and
    # is therefore presumably unhashable; Python 3.11+ dataclasses reject
    # unhashable plain defaults. The factory returns the same shared instance.
    env_registry: UnityEnvRegistry = field(default_factory=lambda: default_registry)

Unity Environment SB3 Factory

python
def _unity_env_from_path_or_registry(
    env: str, registry: UnityEnvRegistry, **kwargs: Any
) -> UnityEnvironment:
    """Create a UnityEnvironment from a local build path or a registry entry.

    Args:
        env: Path to a local Unity executable (``~`` is expanded), or the name
            of an entry in ``registry``.
        registry: Registry in which to look up ``env`` when it is not an
            existing local path.
        **kwargs: Forwarded to ``UnityEnvironment`` or to the registry entry's
            ``make``.

    Returns:
        The created UnityEnvironment.

    Raises:
        ValueError: If ``env`` is neither an existing path nor a registry key.
    """
    local_path = Path(env).expanduser().absolute()
    if local_path.exists():
        # Hand over the same resolved path that passed the existence check, so
        # "~"-prefixed paths work for the launcher too.
        return UnityEnvironment(file_name=str(local_path), **kwargs)
    if env in registry:
        return registry.get(env).make(**kwargs)
    raise ValueError(f"Environment '{env}' wasn't a local path or registry entry")
        
def make_mla_sb3_env(config: LimitedConfig, **kwargs: Any) -> VecEnv:
    """Build an SB3 ``SubprocVecEnv`` of Unity environments described by ``config``.

    Each worker gets its own EngineConfigurationChannel configured with
    ``config.engine_config``, appended to any ``side_channels`` given in
    ``kwargs``. Tuple observation spaces are unwrapped when they hold a single
    space, or converted to a ``gym.spaces.Dict`` keyed "0", "1", ... otherwise,
    because SB3 handles Dict but not Tuple observation spaces.

    Args:
        config: Environment and engine settings.
        **kwargs: Extra keyword arguments forwarded to each Unity environment
            (e.g. ``no_graphics``, ``side_channels``).

    Returns:
        A ``SubprocVecEnv`` with ``config.num_env`` workers.
    """

    def handle_obs(obs, space):
        if isinstance(space, gym.spaces.Tuple):
            if len(space) == 1:
                return obs[0]
            # Turn the tuple into a dict (stable baselines can handle spaces.Dict but not spaces.Tuple).
            return {str(i): v for i, v in enumerate(obs)}
        return obs

    def handle_obs_space(space):
        if isinstance(space, gym.spaces.Tuple):
            if len(space) == 1:
                return space[0]
            # Turn the tuple into a dict (stable baselines can handle spaces.Dict but not spaces.Tuple).
            return gym.spaces.Dict({str(i): v for i, v in enumerate(space)})
        return space

    def create_env(env: str, worker_id: int) -> Callable[[], Env]:
        def _f() -> Env:
            engine_configuration_channel = EngineConfigurationChannel()
            engine_configuration_channel.set_configuration(config.engine_config)
            # Build a fresh kwargs dict per call rather than mutating the
            # closure's shared ``kwargs``; otherwise every factory call in the
            # same process would append yet another engine channel.
            env_kwargs = dict(kwargs)
            env_kwargs["side_channels"] = [
                *(kwargs.get("side_channels") or []),
                engine_configuration_channel,
            ]
            unity_env = _unity_env_from_path_or_registry(
                env=env,
                registry=config.env_registry,
                worker_id=worker_id,
                base_port=config.base_port,
                seed=config.base_seed + worker_id,
                **env_kwargs,
            )
            new_env = UnityToGymWrapper(
                unity_env=unity_env,
                uint8_visual=config.visual_obs,
                allow_multiple_obs=config.allow_multiple_obs,
            )
            new_env = observation_lambda_v0(new_env, handle_obs, handle_obs_space)
            return new_env

        return _f

    env_facts = [
        create_env(config.env_path_or_name, worker_id=x) for x in range(config.num_env)
    ]
    return SubprocVecEnv(env_facts)

Start Environment from the registry

python
# -----------------
# Close an env left over from a previous run of this cell so its Unity
# processes and ports are released before new ones are started.
try:
  env.close()
except NameError:
  # No env has been created yet in this session.
  pass
except Exception as err:  # best-effort cleanup; don't block creating a new env
  print(f"Could not close previous environment: {err!r}")
# -----------------

# NOTE: allow_multiple_obs=True yields a Tuple observation space; when it holds a
# single observation, make_mla_sb3_env unwraps it, so "MlpPolicy" still works.
# Environments with several observations need "MultiInputPolicy" instead.
env = make_mla_sb3_env(
    config=LimitedConfig(
        env_path_or_name='Basic',  # Can use any name from a registry or a path to your own unity build.
        base_port=6006,
        base_seed=42,
        num_env=NUM_ENVS,
        allow_multiple_obs=True,
    ),
    no_graphics=True,  # Set to False if you are running locally and want to watch the environments move around as they train.
)

Create the model

python
# 250K steps should train to a reward ~= 0.90 for the "Basic" environment.
# We set the value lower here to demonstrate just a small amount of training.
BATCH_SIZE = 32
BUFFER_SIZE = 256  # environment steps collected (summed over all envs) per update
UPDATES = 50
TOTAL_TRAINING_STEPS_GOAL = BUFFER_SIZE * UPDATES
# Misspelled name kept as an alias for backward compatibility.
TOTAL_TAINING_STEPS_GOAL = TOTAL_TRAINING_STEPS_GOAL
BETA = 0.0005  # entropy coefficient (passed as ent_coef)
N_EPOCHS = 3
LEARNING_RATE = 0.0003
CLIP_RANGE = 0.2

# SB3's n_steps is counted per environment, so the buffer must split evenly
# across the NUM_ENVS workers.
if BUFFER_SIZE % NUM_ENVS:
    raise ValueError(
        f"BUFFER_SIZE ({BUFFER_SIZE}) must be divisible by NUM_ENVS ({NUM_ENVS})"
    )
STEPS_PER_UPDATE = BUFFER_SIZE // NUM_ENVS

# Helps gather stats for our eval() calls later so we can see reward stats.
# Skip wrapping when this cell is re-run, to avoid double-counting episodes.
if not isinstance(env, VecMonitor):
    env = VecMonitor(env)

# Policy and value function with 2 layers of 32 units each and no shared layers.
policy_kwargs = {"net_arch": [{"pi": [32, 32], "vf": [32, 32]}]}

# NOTE: SB3 calls schedule functions with ``progress_remaining`` (1.0 at the start
# of a learn() call, 0.0 at its end), so ``lambda p: lr * (1.0 - p)`` ramps *up*
# from 0. Because the training loop below calls learn() once per rollout,
# progress_remaining is 0 at every update, and such a schedule reduces to these
# constants anyway. For a real linear decay, call learn() once with the full
# budget and pass e.g. ``lambda p: LEARNING_RATE * p``.
model = PPO(
    "MlpPolicy",
    env,
    verbose=1,
    learning_rate=LEARNING_RATE,
    clip_range=CLIP_RANGE,
    clip_range_vf=CLIP_RANGE,
    # Uncomment this if you want to log tensorboard results when running this notebook locally.
    # tensorboard_log="results",
    policy_kwargs=policy_kwargs,
    n_steps=STEPS_PER_UPDATE,
    batch_size=BATCH_SIZE,
    n_epochs=N_EPOCHS,
    ent_coef=BETA,
)

Train the model

python
# A mean reward of about 0.93 is considered solved for the Basic environment.
# Train for UPDATES rounds, each collecting and learning from BUFFER_SIZE steps.
for round_number in range(1, UPDATES + 1):
    print(f"Training round {round_number}/{UPDATES}")
    # Only the first round resets the timestep counter, so tensorboard logs stay
    # continuous across rounds.
    is_first_round = round_number == 1
    model.learn(total_timesteps=BUFFER_SIZE, reset_num_timesteps=is_first_round)
    model.policy.eval()

Close the environment

Frees up the ports being used.

python
# Close the (monitored) vector environment created above, which frees the ports
# used by the Unity environments.
env.close()
print("Closed environment")