python/rollout.ipynb
This notebook provides a tutorial for MuJoCo physics, using the native Python bindings.
This notebook describes the rollout module included in the MuJoCo Python library. It performs simulation "rollouts" with an underlying C++ function. The rollouts can be multithreaded.
Below, the usage of each argument is explained with examples. Then some examples for advanced use cases are provided. Finally, rollout is benchmarked against pure python and MJX.
Note the benchmarks were designed to run on >16 thread CPU and an RTX 4090 or A100. They do not run in a reasonable amount of time on a typical free colab runtime.
<!-- Copyright 2025 DeepMind Technologies Limited Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. -->!pip install mujoco
!pip install mujoco_mjx
!pip install brax
# Set up GPU rendering.
#from google.colab import files
import distutils.util
import os
import subprocess
if subprocess.run('nvidia-smi').returncode:
raise RuntimeError(
'Cannot communicate with GPU. '
'Make sure you are using a GPU Colab runtime. '
'Go to the Runtime menu and select Choose runtime type.')
# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
# This is usually installed as part of an Nvidia driver package, but the Colab
# kernel doesn't install its driver via APT, and as a result the ICD is missing.
# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
f.write("""{
"file_format_version" : "1.0.0",
"ICD" : {
"library_path" : "libEGL_nvidia.so.0"
}
}
""")
# Configure MuJoCo to use the EGL rendering backend (requires GPU)
print('Setting environment variable to use GPU rendering:')
%env MUJOCO_GL=egl
# Check if installation was successful.
try:
print('Checking that the installation succeeded:')
import mujoco
from mujoco import rollout
from mujoco import mjx
mujoco.MjModel.from_xml_string('<mujoco/>')
except Exception as e:
raise e from RuntimeError(
'Something went wrong during installation. Check the shell output above '
'for more information.\n'
'If using a hosted Colab runtime, make sure you enable GPU acceleration '
'by going to the Runtime menu and selecting "Choose runtime type".')
print('Installation successful.')
# Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs
xla_flags = os.environ.get('XLA_FLAGS', '')
xla_flags += ' --xla_gpu_triton_gemm_any=True'
os.environ['XLA_FLAGS'] = xla_flags
# Other imports and helper functions
import copy
import time
from multiprocessing import cpu_count
import threading
import numpy as np
import jax
import jax.numpy as jp
# Graphics and plotting.
print('Installing mediapy:')
!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)
!pip install -q mediapy
import mediapy as media
import matplotlib
import matplotlib.pyplot as plt
# More legible printing from numpy.
np.set_printoptions(precision=3, suppress=True, linewidth=100)
# Set the number of threads to the number of cpu's that the multiprocessing module reports
nthread = cpu_count()
# Get MuJoCo's standard humanoid and humanoid_100 models.
print('Getting MuJoCo humanoid XML description from GitHub:')
!git clone https://github.com/google-deepmind/mujoco
humanoid_path = 'mujoco/model/humanoid/humanoid.xml'
humanoid100_path = 'mujoco/model/humanoid/humanoid100.xml'
print('Getting hopper XML description from GitHub:')
!git clone https://github.com/google-deepmind/dm_control
hopper_path ='dm_control/dm_control/suite/hopper.xml'
# clear installation printouts
from IPython.display import clear_output
clear_output()
def get_state(model, data, nbatch=1):
  """Extract the MJSTATE_FULLPHYSICS state vector of `data`, tiled `nbatch` times.

  Returns an array of shape (nbatch, nstate).
  """
  spec = mujoco.mjtState.mjSTATE_FULLPHYSICS
  buf = np.zeros(mujoco.mj_stateSize(model, spec))
  mujoco.mj_getState(model, data, buf, spec)
  return np.tile(buf, (nbatch, 1))
def xy_grid(nbatch, ncols=10, spacing=0.05):
  """Return (nbatch, 2) xy coordinates laid out on a grid centered at the origin.

  The grid has `ncols` columns and nbatch // ncols rows, with neighboring
  points `spacing` apart. nbatch must be an exact multiple of ncols.
  """
  nrows = nbatch // ncols
  assert nbatch == nrows * ncols
  half_x = 0.5 * spacing * (nrows - 1)
  half_y = 0.5 * spacing * (ncols - 1)
  xs = np.linspace(-half_x, half_x, nrows)
  ys = np.linspace(-half_y, half_y, ncols)
  gx, gy = np.meshgrid(xs, ys)
  return np.column_stack((gx.ravel(), gy.ravel()))
def benchmark(f, x_list=(None,), ntiming=1, f_init=None):
  """Time `f` for each x in `x_list`, returning mean wall-clock seconds.

  Args:
    f: callable invoked as f(x), or as f(x, x_init) when `f_init` is given.
    x_list: iterable of arguments to time `f` against. (Default is an
      immutable tuple to avoid the shared mutable-default-argument pitfall.)
    ntiming: number of timing repetitions to average over.
    f_init: optional per-repetition setup callable; it runs outside the
      timed region and its result is passed to `f` as a second argument.

  Returns:
    np.ndarray of mean times (seconds), one entry per element of x_list.
  """
  x_times_list = []
  for x in x_list:
    times = []
    for _ in range(ntiming):
      # Setup work (e.g. allocating an MjData) is excluded from the timing.
      if f_init is not None:
        x_init = f_init(x)
      start = time.perf_counter()
      if f_init is not None:
        f(x, x_init)
      else:
        f(x)
      end = time.perf_counter()
      times.append(end - start)
    x_times_list.append(np.mean(times))
  return np.array(x_times_list)
def render_many(model, data, state, framerate, camera=-1, shape=(480, 640),
                transparent=False, light_pos=None):
  """Render a batch of state trajectories into one list of composited frames.

  Args:
    model: a single MjModel, or a sequence of MjModels of length 1 or nbatch.
    data: an MjData used as scratch space for mj_setState/mj_forward.
    state: array of shape (nbatch, nstep, nstate) of MJSTATE_FULLPHYSICS states.
    framerate: output video framerate in Hz.
    camera: camera id/name/MjvCamera forwarded to Renderer.update_scene.
    shape: (height, width) of the rendered frames.
    transparent: if True, geoms are rendered with the transparency vis flag.
    light_pos: optional xyz position; if given, an extra spotlight is added to
      each rendered scene.

  Returns:
    A list of rendered frames (one pixel array per video frame).
  """
  nbatch = state.shape[0]
  # Normalize `model` into a list of length nbatch: a length-1 list is tiled,
  # a full list is validated, a single MjModel is replicated.
  if not isinstance(model, mujoco.MjModel):
    model = list(model)
  if isinstance(model, list) and len(model) == 1:
    model = model * nbatch
  elif isinstance(model, list):
    assert len(model) == nbatch
  else:
    model = [model] * nbatch
  # Visual options
  vopt = mujoco.MjvOption()
  vopt.flags[mujoco.mjtVisFlag.mjVIS_TRANSPARENT] = transparent
  pert = mujoco.MjvPerturb()  # Empty MjvPerturb object
  catmask = mujoco.mjtCatBit.mjCAT_DYNAMIC  # only composite dynamic geoms
  # Simulate and render.
  frames = []
  with mujoco.Renderer(model[0], *shape) as renderer:
    for i in range(state.shape[1]):
      # Emit a frame only when sim time has caught up with video time
      # (len(frames) / framerate); skips steps when timestep < frame period.
      if len(frames) < i * model[0].opt.timestep * framerate:
        for j in range(state.shape[0]):
          mujoco.mj_setState(model[j], data, state[j, i, :],
                             mujoco.mjtState.mjSTATE_FULLPHYSICS)
          mujoco.mj_forward(model[j], data)
          # Use first model to make the scene, add subsequent models
          if j == 0:
            renderer.update_scene(data, camera, scene_option=vopt)
          else:
            mujoco.mjv_addGeoms(model[j], data, vopt, pert, catmask,
                                renderer.scene)
        # Add light, if requested
        if light_pos is not None:
          # Write directly into the scene's next light slot, then bump nlight.
          light = renderer.scene.lights[renderer.scene.nlight]
          light.ambient = [0, 0, 0]
          light.attenuation = [1, 0, 0]
          light.castshadow = 1
          light.cutoff = 45
          light.diffuse = [0.8, 0.8, 0.8]
          light.dir = [0, 0, -1]
          light.type = mujoco.mjtLightType.mjLIGHT_SPOT
          light.exponent = 10
          light.headlight = 0
          light.specular = [0.3, 0.3, 0.3]
          light.pos = light_pos
          renderer.scene.nlight += 1
        # Render and add the frame.
        pixels = renderer.render()
        frames.append(pixels)
  return frames
The `rollout.rollout` function in the mujoco Python library runs batches of simulations for a fixed number of steps. It can run in single or multi-threaded modes. The speedup over pure Python is significant because rollout users can easily enable the usage of a lightweight threadpool.
Below we load the "tippe top", "humanoid", and "humanoid100" models which will be used in the following usage examples and benchmarks.
The tippe top is copied from the tutorial notebook. The humanoid and humanoid100 models are distributed with MuJoCo.
#@title Benchmarked models
tippe_top = """
<mujoco model="tippe top">
<option integrator="RK4"/>
<asset>
<texture name="grid" type="2d" builtin="checker" rgb1=".1 .2 .3"
rgb2=".2 .3 .4" width="300" height="300"/>
<material name="grid" texture="grid" texrepeat="40 40" reflectance=".2"/>
</asset>
<worldbody>
<geom size="1 1 .01" type="plane" material="grid"/>
<light pos="0 0 .6"/>
<camera name="closeup" pos="0 -.1 .07" xyaxes="1 0 0 0 1 2"/>
<camera name="distant" pos="0 -.4 .4" xyaxes="1 0 0 0 1 1"/>
<body name="top" pos="0 0 .02">
<freejoint name="top"/>
<site name="top" pos="0 0 0"/>
<geom name="ball" type="sphere" size=".02" />
<geom name="stem" type="cylinder" pos="0 0 .02" size="0.004 .008"/>
<geom name="ballast" type="box" size=".023 .023 0.005" pos="0 0 -.015"
contype="0" conaffinity="0" group="3"/>
</body>
</worldbody>
<sensor>
<gyro name="gyro" site="top"/>
</sensor>
<keyframe>
<key name="spinning" qpos="0 0 0.02 1 0 0 0" qvel="0 0 0 0 1 200" />
</keyframe>
</mujoco>
"""
# Create and initialize top model
top_model = mujoco.MjModel.from_xml_string(tippe_top)
top_data = mujoco.MjData(top_model)
# Set to the state to a spinning top (keyframe 0)
mujoco.mj_resetDataKeyframe(top_model, top_data, 0)
top_state = get_state(top_model, top_data)
# Create and initialize humanoid model
humanoid_model = mujoco.MjModel.from_xml_path(humanoid_path)
humanoid_data = mujoco.MjData(humanoid_model)
humanoid_data.qvel[2] = 4 # Make the humanoid jump
humanoid_state = get_state(humanoid_model, humanoid_data)
# Create and initialize humanoid100 model
humanoid100_model = mujoco.MjModel.from_xml_path(humanoid100_path)
humanoid100_data = mujoco.MjData(humanoid100_model)
h100_state = get_state(humanoid100_model, humanoid100_data)
start = time.time()
top_nstep = int(6 / top_model.opt.timestep)
top_state, _ = rollout.rollout(top_model, top_data, top_state, nstep=top_nstep)
humanoid_nstep = int(3 / humanoid_model.opt.timestep)
humanoid_state, _ = rollout.rollout(humanoid_model, humanoid_data,
humanoid_state, nstep=humanoid_nstep)
humanoid100_nstep = int(3 / humanoid100_model.opt.timestep)
h100_state, _ = rollout.rollout(humanoid100_model, humanoid100_data,
h100_state, nstep=humanoid100_nstep)
end = time.time()
start_render = time.time()
top_frames = render_many(top_model, top_data, top_state, framerate=60, shape=(240, 320))
humanoid_frames = render_many(humanoid_model, humanoid_data, humanoid_state, framerate=120, shape=(240, 320))
humanoid100_frames = render_many(humanoid100_model, humanoid100_data, h100_state, framerate=120, shape=(240, 320))
# humanoid and humanoid100 are shown at half speed
media.show_video(np.concatenate((top_frames, humanoid_frames, humanoid100_frames), axis=2), fps=60)
end_render = time.time()
print(f'Rollout took {end-start:.1f} seconds')
print(f'Rendering took {end_render-start_render:.1f} seconds')
It is helpful to read rollout's docstring before beginning. The main takeaways are that rollout runs nbatch rollouts for nstep steps. Each MjModel can be different but should be the same up to parameter values. Passing multiple MjData enables multithreading, one thread per MjData.
Further documentation can be found here.
Next we give usage examples of the most common arguments. The more advanced arguments are discussed in the "Advanced Usage" section.
print(rollout.rollout.__doc__)
rollout is designed to run nbatch rollouts in parallel for nstep steps. Let's simulate 100 tippe tops with different initial rotation speeds.
Note: Using multithreading with rollout is enabled by passing one MjData per thread, as is done below.
nbatch = 100 # Simulate this many tops
# Get nbatch initial states and scale the initial speed of the tippe top using the batch index
top_data = mujoco.MjData(top_model)
mujoco.mj_resetDataKeyframe(top_model, top_data, 0)
initial_states = get_state(top_model, top_data, nbatch)
initial_states[:, -1] *= np.linspace(0.5, 1.5, num=nbatch)
# Run the rollout
start = time.time()
top_datas = [copy.copy(top_data) for _ in range(nthread)] # 1 MjData per thread
state, sensordata = rollout.rollout(top_model, top_datas, initial_states,
nstep=int(top_nstep*1.5))
end = time.time()
# Use state to render all the tops at once
start_render = time.time()
framerate = 60
frames = render_many(top_model, top_data, state, framerate, transparent=True)
media.show_video(frames, fps=framerate)
end_render = time.time()
print(f'Rollout time {end-start:.1f} seconds')
print(f'Rendering time {end_render-start_render:.1f} seconds')
Our model has an angular velocity sensor in the middle of the top. Let's plot the response using the sensordata array that rollout returns.
plt.figure(figsize=(12, 8))
plt.subplot(3,1,1)
for i in range(nbatch): plt.plot(sensordata[i, :, 0])
plt.subplot(3,1,2)
for i in range(nbatch): plt.plot(sensordata[i, :, 1])
plt.subplot(3,1,3)
for i in range(nbatch): plt.plot(sensordata[i, :, 2])
plt.show()
100 gray tops is kind of boring. It would be better if they were colorful and different sizes!
rollout supports using different models for each rollout, so long as their dimensions are the same (i.e., floating point parameters can be different). Let's simulate 100 tippe tops with the same initial condition, but different sizes and colors.
Note: Strictly speaking, the models must have the same number of states, controls, degrees of freedom, and sensor outputs. The most common use case is multiple models of the same thing, with different parameter values.
# Make 100 tippe tops with different colors and sizes
nbatch = 100
spec = mujoco.MjSpec.from_string(tippe_top)
spec.lights[0].pos[2] = 2
models = []
for i in range(nbatch):
for geom in spec.geoms:
if geom.name in ['ball', 'stem', 'ballast']:
geom.rgba[:3] = np.random.rand(3)
if geom.name == 'stem':
stem_geom = geom
if geom.name == 'ball':
ball_geom = geom
# Save original geom size
stem_geom_size = np.copy(stem_geom.size)
ball_geom_size = np.copy(ball_geom.size)
# Scale geoms and compile model
size_scale = 0.4*np.random.rand(1) + 0.75
stem_geom.size *= size_scale
ball_geom.size *= size_scale
models.append(spec.compile())
# Restore original geom size
stem_geom.size = stem_geom_size
ball_geom.size = ball_geom_size
# Set the initial states of all the tops, placing them on a grid
top_data = mujoco.MjData(top_model)
mujoco.mj_resetDataKeyframe(top_model, top_data, 0)
initial_states = get_state(top_model, top_data, nbatch)
# index 0 is time, so x and y qpos values are at 1 and 2
initial_states[:, 1:3] = xy_grid(nbatch, ncols=10, spacing=.05)
# Run the rollout
start = time.time()
top_datas = [copy.copy(top_data) for _ in range(nthread)]
nstep = int(9 / top_model.opt.timestep)
state, sensordata = rollout.rollout(models, top_datas, initial_states,
nstep=nstep)
end = time.time()
# Render video
start_render = time.time()
framerate = 60
cam = mujoco.MjvCamera()
mujoco.mjv_defaultCamera(cam)
cam.distance = 0.2
cam.azimuth = 135
cam.elevation = -25
cam.lookat = [.2, -.2, 0.07]
models[0].vis.global_.fovy = 60
frames = render_many(models, top_data, state, framerate, camera=cam)
media.show_video(frames, fps=framerate)
end_render = time.time()
print(f'Rollout time {end-start:.1f} seconds')
print(f'Rendering time {end_render-start_render:.1f} seconds')
Because the models are now different, the measurements of the gyro sensor are not consistent even though the initial state for each rollout was the same.
plt.figure(figsize=(12, 8))
plt.subplot(3,1,1)
for i in range(nbatch): plt.plot(sensordata[i, :, 0])
plt.subplot(3,1,2)
for i in range(nbatch): plt.plot(sensordata[i, :, 1])
plt.subplot(3,1,3)
for i in range(nbatch): plt.plot(sensordata[i, :, 2])
plt.show()
Open loop controls can be passed to rollout via the control argument. If passed, nstep no longer needs to be specified as it can be inferred from the size of control.
Below we simulate 100 of the flailing humanoids from the tutorial notebook. Each humanoid uses a different control signal.
# Episode parameters.
duration = 3 # (seconds)
framerate = 120 # (Hz)
# Generate 100 different control sequences
nbatch = 100
nstep = int(duration / humanoid_model.opt.timestep)
times = np.linspace(0.0, duration, nstep)
ctrl_phase = 2 * np.pi * np.random.rand(nbatch, 1, humanoid_model.nu)
control = np.sin((2 * np.pi * times).reshape(nstep, 1) + ctrl_phase)
# Make initial states
humanoid_data = mujoco.MjData(humanoid_model)
humanoid_data.qvel[2] = 4 # Make the humanoid jump
initial_states = get_state(humanoid_model, humanoid_data, nbatch)
# index 0 is time, so x and y qpos values are at 1 and 2
initial_states[:, 1:3] = xy_grid(nbatch, ncols=10, spacing=1.0)
# Run the rollout
start = time.time()
humanoid_datas = [copy.copy(humanoid_data) for _ in range(nthread)]
state, _ = rollout.rollout(humanoid_model, humanoid_datas,
initial_states, control)
end = time.time()
# Render the rollout
start_render = time.time()
framerate = 120
cam = mujoco.MjvCamera()
mujoco.mjv_defaultCamera(cam)
cam.distance = 10
cam.azimuth = 45
cam.elevation = -15
cam.lookat = [0, 0, 0]
humanoid_model.vis.global_.fovy = 60
frames = render_many(humanoid_model, humanoid_data, state, framerate,
camera=cam, light_pos=[0, 0, 10])
media.show_video(frames, fps=framerate/2) # Show the video at half speed
end_render = time.time()
print(f'Rollout time {end-start:.1f} seconds')
print(f'Render time {end_render-start_render:.1f} seconds')
rollout's control_spec argument can be used to indicate control contains values for actuators, generalized forces, cartesian forces, mocap poses, and/or the activation/deactivation of equality constraints. Internally, this is managed through mj_setState and control_spec corresponds to mj_setState's spec argument.
Let's try applying cartesian forces in addition to the control inputs. This will make the humanoids look like they are being dragged while waving their limbs.
xfrc_size = mujoco.mj_stateSize(humanoid_model, mujoco.mjtState.mjSTATE_XFRC_APPLIED)
xfrc = np.zeros((nbatch, nstep, xfrc_size))
head_id = humanoid_model.body('head').id
# Apply a constant but different force to each model
force = np.random.normal(scale=150.0, size=(nbatch, 1, 3))
force[:,:,2] = 150 # Fixed vertical force
xfrc[:, :, 3*head_id:3*head_id+3] = force
control_xfrc = np.concatenate((control, xfrc), axis=2)
control_spec = mujoco.mjtState.mjSTATE_XFRC_APPLIED.value
start = time.time()
state, _ = rollout.rollout(humanoid_model, humanoid_datas,
initial_states, xfrc, control_spec=control_spec)
end = time.time()
start_render = time.time()
frames = render_many(humanoid_model, humanoid_data, state, framerate,
camera=cam, light_pos=[0, 0, 10])
media.show_video(frames, fps=framerate/2) # Show the video at half speed
end_render = time.time()
print(f'Rollout time {end-start:.1f} seconds')
print(f'Render time {end_render-start_render:.1f} seconds')
By default rollout performs many checks on the dimensions of its arguments. This allows it to infer dimensions such as nbatch and nstep, tile arguments that were not fully specified, and allocate the returned state and sensordata arrays.
However, these checks take time, particularly if state and sensordata are large or if there are many models and nstep is low. So advanced users may want to use the skip_checks=True argument in order to achieve additional performance.
If used, certain arguments become non-optional, and all signals must be fully defined (no implicit tiling). In particular:
- model must be a list of length nbatch
- data must be a list of length nthread
- nstep must be specified
- initial_state must be an array of shape nbatch x nstate
- control is optional, but if passed must be an array of shape nbatch x nstep x ncontrol
- state is optional, but must be passed if state is to be returned and must be of shape nbatch x nstep x nstate
- sensordata is optional, but must be passed if sensor data is to be returned and must be of shape nbatch x nstep x nsensordata

As an example, we pass 1000 tippe top models to rollout and simulate them for varying numbers of steps, with and without checks.
nbatch = 1000
nstep = [1, 10, 100, 500]
ntiming = 5
top_data = mujoco.MjData(top_model)
mujoco.mj_resetDataKeyframe(top_model, top_data, 0)
top_datas = [copy.copy(top_data) for _ in range(nthread)]
initial_state = get_state(top_model, top_data)
initial_state_tiled = get_state(top_model, top_data, nbatch)
# Note: state, sensordata array automatically allocated and return
def rollout_with_checks(nstep):
  # With checks enabled, rollout infers nbatch/nstep, tiles the single
  # (non-tiled) initial_state across the batch, and allocates the returned
  # state and sensordata arrays itself.
  state, sensordata = rollout.rollout([top_model]*nbatch, top_datas, initial_state, nstep=nstep)
# Note: state, sensordata arrays have to be preallocated
state = None
sensordata = None
def rollout_skip_checks(nstep):
  # Note initial state must be tiled
  # With skip_checks=True no shape inference or tiling is performed, so the
  # pre-tiled initial_state_tiled is required. `state`/`sensordata` here are
  # the module-level None values, so no output arrays are allocated/returned.
  rollout.rollout([top_model]*nbatch, top_datas, initial_state_tiled, nstep=nstep,
                  state=state, sensordata=sensordata, skip_checks=True)
t_with_checks = benchmark(lambda x: rollout_with_checks(x), nstep, ntiming=ntiming)
t_skip_checks = benchmark(lambda x: rollout_skip_checks(x), nstep, ntiming=ntiming)
steps_per_second = (nbatch * np.array(nstep)) / np.array(t_with_checks)
steps_per_second_skip_checks = (nbatch * np.array(nstep)) / np.array(t_skip_checks)
plt.loglog(nstep, steps_per_second, label='with checks')
plt.loglog(nstep, steps_per_second_skip_checks, label='skip checks')
plt.ylabel('steps per second')
plt.xlabel('nstep')
ticker = matplotlib.ticker.FuncFormatter(lambda x, p: format(int(x), ','))
plt.gca().yaxis.set_minor_formatter(ticker)
plt.legend()
plt.grid(True, which="both", axis="both")
As expected, as nstep increases, the benefits of using skip checks fades quickly. However, at low nstep and high batch sizes, it can make a significant difference.
Notice that the version with checks can use the non-tiled initial_state; however, the skip-checks version must use the tiled version, initial_state_tiled.
The rollout module provides the class Rollout in addition to the function rollout. The Rollout class is designed to allow safe reuse of the internally managed threadpool.
Reuse can speed things up considerably when rollouts are short. Let's find out how the speedup changes for the tippe top model by rolling it out with increasing numbers of steps.
nbatch = 100
nsteps = [2**i for i in [2, 3, 4, 5, 6, 7]]
ntiming = 5
top_data = mujoco.MjData(top_model)
mujoco.mj_resetDataKeyframe(top_model, top_data, 0)
top_datas = [copy.copy(top_data) for _ in range(nthread)]
initial_states = get_state(top_model, top_data, nbatch)
def rollout_method(nstep):
  """Call the module-level rollout 20 times; each call builds its own threadpool."""
  for _ in range(20):
    rollout.rollout(top_model, top_datas, initial_states, nstep=nstep)
def rollout_class(nstep):
  """Call Rollout.rollout 20 times, reusing one threadpool via a context manager."""
  with rollout.Rollout(nthread=nthread) as persistent_rollout:
    for _ in range(20):
      persistent_rollout.rollout(top_model, top_datas, initial_states, nstep=nstep)
t_method = benchmark(lambda x: rollout_method(x), nsteps, ntiming)
t_class = benchmark(lambda x: rollout_class(x), nsteps, ntiming)
plt.loglog(nsteps, nbatch * np.array(nsteps) / t_method, label='recreating threadpools')
plt.loglog(nsteps, nbatch * np.array(nsteps) / t_class, label='reusing threadpool')
plt.xlabel('nstep')
plt.ylabel('steps per second')
ticker = matplotlib.ticker.FuncFormatter(lambda x, p: format(int(x), ','))
plt.gca().yaxis.set_minor_formatter(ticker)
plt.legend()
plt.grid(True, which="both", axis="both")
The rollout function will create and reuse a persistent threadpool if passed persistent_pool=True. However, there are some caveats.
First, because rollout is a function and does not know when the user is done calling it, the threadpool needs to be shut down manually like this:
nbatch = 1000
nstep = 1
top_data = mujoco.MjData(top_model)
mujoco.mj_resetDataKeyframe(top_model, top_data, 0)
top_datas = [copy.copy(top_data) for _ in range(nthread)]
initial_states = get_state(top_model, top_data, nbatch)
rollout.rollout(top_model, top_datas, initial_states, nstep=nstep, persistent_pool=True) # Creates a pool
rollout.rollout(top_model, top_datas, initial_states, nstep=nstep, persistent_pool=True) # Reuses the previously created pool
rollout.shutdown_persistent_pool() # Shutdown the pool manually when finished
Second, if rollout reuses the same threadpool between calls, it is no longer safe to call rollout from multiple threads. For example the following is not allowed (the offending lines are commented out to avoid crashing the interpreter):
thread1 = threading.Thread(target=lambda: rollout.rollout(top_model, top_datas, initial_states, nstep=nstep, persistent_pool=True))
thread2 = threading.Thread(target=lambda: rollout.rollout(top_model, top_datas, initial_states, nstep=nstep, persistent_pool=True))
thread1.start()
#thread2.start() # Do not do this! rollout will be using the same persistent threadpool from two threads and may crash the interpreter
thread1.join()
#thread2.join()
rollout.shutdown_persistent_pool()
To minimize communication overhead, rollout distributes rollouts to threads in groups of rollouts called chunks. By default, max(1, 0.1 * (nbatch / nthread)) rollouts are assigned to each chunk. While this chunking rule works well for most workloads it is not always optimal, especially when doing short rollouts with small models.
Below we plot the steps per second versus chunk_size when running 100 hoppers for 1 step each. In this case, the default chunk_size turns out to be quite a bit slower than using an increased chunk size.
nbatch = 100
nstep = 1
ntiming = 20
# Load model
hopper_model = mujoco.MjModel.from_xml_path(hopper_path)
hopper_data = mujoco.MjData(hopper_model)
hopper_datas = [copy.copy(hopper_data) for _ in range(nthread)]
# Get initial states
initial_states = get_state(hopper_model, hopper_data, nbatch)
def rollout_chunk_size(chunk_size=None):
  # chunk_size=None lets rollout fall back to its default chunking heuristic.
  rollout.rollout(hopper_model, hopper_datas, initial_states, nstep=nstep, chunk_size=chunk_size)
# Rollout with different chunk sizes
default_chunk_size = int(max(1.0, 0.1 * nbatch / nthread))
chunk_sizes = sorted([1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, default_chunk_size])
t_chunk_size = benchmark(lambda x: rollout_chunk_size(x), chunk_sizes, ntiming=ntiming)
# Get optimal chunk size
steps_per_second = nbatch * nstep / t_chunk_size
default_index = [i for i, c in enumerate(chunk_sizes) if c == default_chunk_size][0]
optimal_index = np.argmax(steps_per_second)
plt.loglog(chunk_sizes, steps_per_second, color='b')
plt.plot(chunk_sizes[default_index], steps_per_second[default_index], marker='o', color='r', label='default chunk size')
plt.plot(chunk_sizes[optimal_index], steps_per_second[optimal_index], marker='o', color='g', label='optimal chunk size')
plt.ylabel('steps per second')
plt.xlabel('chunk size')
ticker = matplotlib.ticker.FuncFormatter(lambda x, p: format(int(x), ','))
plt.gca().yaxis.set_minor_formatter(ticker)
plt.legend()
plt.grid(True, which="both", axis="both")
print(f'default chunk size: {default_chunk_size} \t steps per second: {steps_per_second[default_index]:0.1f}')
print(f'optimal chunk size: {chunk_sizes[optimal_index]} \t steps per second: {steps_per_second[optimal_index]:0.1f}')
The initial_warmstart parameter can be used to warmstart the constraint solver as described in the computation chapter of the documentation. This can be useful when rolling out models in chunks of steps. Without warmstarting, chaotic systems involving multi-body contact may diverge.
Below we demonstrate this with the tippe top model where the contact solver was changed to CG. This makes the contact force calculation less repeatable than if the default, Newton's method, were used and allows demonstrating the benefits of warmstarting.
The simulation is run three times. Once with a 6000 step rollout, once with 100 chunks of 60 steps with warmstarting, and once more in 100 chunks of 60 steps without warmstarting.
top_model_cg = copy.copy(top_model)
# Change to CG solver, the Newton solver converges too well for
# warmstarting to have an appreciable effect
top_model_cg.opt.solver = mujoco.mjtSolver.mjSOL_CG
chunks = 100
steps_per_chunk = 60
nstep = steps_per_chunk*chunks
# Get initial states
top_data_cg = mujoco.MjData(top_model_cg)
mujoco.mj_resetDataKeyframe(top_model_cg, top_data_cg, 0)
initial_state = get_state(top_model_cg, top_data_cg)
start = time.time()
# Rollout with nstep steps
state_all, _ = rollout.rollout(top_model_cg, top_data_cg, initial_state, nstep=nstep)
# Rollout in chunks with warmstarting
state_chunks = []
state_chunk, _ = rollout.rollout(top_model_cg, top_data_cg, initial_state, nstep=steps_per_chunk)
state_chunks.append(state_chunk)
for _ in range(chunks-1):
state_chunk, _ = rollout.rollout(top_model_cg, top_data_cg, state_chunks[-1][0, -1, :],
nstep=steps_per_chunk, initial_warmstart=top_data_cg.qacc_warmstart)
state_chunks.append(state_chunk)
state_all_chunked_warmstart = np.concatenate(state_chunks, axis=1)
# Rollout in chunks without warmstarting
state_chunks = []
state_chunk, _ = rollout.rollout(top_model_cg, top_data_cg, initial_state, nstep=steps_per_chunk)
state_chunks.append(state_chunk)
first_warmstart = None
for i in range(chunks-1):
state_chunk, _ = rollout.rollout(top_model_cg, top_data_cg, state_chunks[-1][0, -1, :], nstep=steps_per_chunk)
state_chunks.append(state_chunk)
state_all_chunked = np.concatenate(state_chunks, axis=1)
end = time.time()
# Render the rollouts
start_render = time.time()
framerate = 60
state_render = np.concatenate((state_all, state_all_chunked, state_all_chunked_warmstart), axis=0)
camera = 'distant'
frames1 = render_many(top_model_cg, top_data_cg, state_all, framerate, shape=(240, 320), transparent=False, camera=camera)
frames2 = render_many(top_model_cg, top_data_cg, state_all_chunked_warmstart, framerate, shape=(240, 320), transparent=False, camera=camera)
frames3 = render_many(top_model_cg, top_data_cg, state_all_chunked, framerate, shape=(240, 320), transparent=False, camera=camera)
media.show_video(np.concatenate((frames1, frames2, frames3), axis=2))
end_render = time.time()
print(f'Rollout took {end-start:.1f} seconds')
print(f'Rendering took {end_render-start_render:.1f} seconds')
As expected, the middle animation (with warmstarting) matches the continuous rollout on the left. However, the model that did not use warmstarting diverged.
The `rollout.rollout` function in the mujoco Python library runs batches of simulations for a fixed number of steps. It can run in single or multi-threaded modes. The speedup over pure Python is significant because rollout can be easily configured to use multithreading.
To show the speedup, we will run benchmarks with the "tippe top", "humanoid", and "humanoid100" models.
The benchmark runs the three models with varying batch and step counts.
The Python code for nbatch rollouts of nstep steps is:
def python_rollout(model, data, nbatch, nstep):
  """Step `model`/`data` in pure Python: nbatch rollouts of nstep steps each.

  Fix: the original reused the loop variable `i` for both loops (the inner
  loop shadowed the outer one). Neither index is used, so both are `_`.

  Note: `data` is not reset between rollouts; this measures stepping cost only.
  """
  for _ in range(nbatch):
    for _ in range(nstep):
      mujoco.mj_step(model, data)
To run nbatch rollouts with rollout, we need to make an array of nbatch initial states to start the rollouts from.
Additionally, to use rollout's parallelism, we must pass one MjData per thread.
The resulting rollout call parameterized by nbatch, nstep, and nthread is:
def nthread_rollout(model, data, nbatch, nstep, nthread, rollout_):
  """Run nbatch rollouts of nstep steps through a Rollout object with skip_checks.

  skip_checks=True requires fully-specified arguments: a model list of length
  nbatch, one MjData per thread, and a pre-tiled initial-state array.
  """
  models = [model] * nbatch
  datas = [copy.copy(data) for _ in range(nthread)]  # one MjData per thread
  states = np.tile(get_state(model, data), (nbatch, 1))  # tile initial state
  rollout_.rollout(models, datas, states, nstep=nstep, skip_checks=True)
Next, we benchmark the Python loop and rollout in both single threaded and multithreaded modes. The three benchmarks take about 2.5 minutes in total to run on an AMD 5800X3D.
#@title Benchmarking utilities
top_model = mujoco.MjModel.from_xml_string(tippe_top)
def init_top(model):
  """Create an MjData for `model` initialized to keyframe 0 (spinning top)."""
  top_data = mujoco.MjData(model)
  mujoco.mj_resetDataKeyframe(model, top_data, 0)
  return top_data
# Create and initialize humanoid model
# Step for 2 seconds to get a stable set of contacts to benchmark
humanoid_model = mujoco.MjModel.from_xml_path(humanoid_path)
humanoid_data = mujoco.MjData(humanoid_model)
humanoid_data.qvel[2] = 4 # Make the humanoid jump
while humanoid_data.time < 2.0:
mujoco.mj_step(humanoid_model, humanoid_data)
humanoid_initial_state = get_state(humanoid_model, humanoid_data)
def init_humanoid(model):
  """Create an MjData for the humanoid, set to the pre-stepped initial state."""
  d = mujoco.MjData(model)
  # humanoid_initial_state was captured after stepping for 2 seconds above.
  mujoco.mj_setState(model, d, humanoid_initial_state.flatten(),
                     mujoco.mjtState.mjSTATE_FULLPHYSICS)
  return d
# Create and initialize humanoid100 model
# Step for 4 seconds to get a stable set of contacts to benchmark
humanoid100_model = mujoco.MjModel.from_xml_path(humanoid100_path)
humanoid100_data = mujoco.MjData(humanoid100_model)
while humanoid100_data.time < 4.0:
mujoco.mj_step(humanoid100_model, humanoid100_data)
humanoid100_initial_state = get_state(humanoid100_model, humanoid100_data)
def init_humanoid100(model):
  """Create an MjData for humanoid100, set to the pre-stepped initial state."""
  d = mujoco.MjData(model)
  # humanoid100_initial_state was captured after stepping for 4 seconds above.
  mujoco.mj_setState(model, d, humanoid100_initial_state.flatten(),
                     mujoco.mjtState.mjSTATE_FULLPHYSICS)
  return d
def benchmark_rollout(model, init_model, nbatch, nstep, nominal_nbatch, nominal_nstep, ntiming=1):
    """Time pure-Python stepping against rollout (single- and multi-threaded).

    Each backend is timed over two sweeps: the batch sizes in `nbatch`
    (at `nominal_nstep` steps) and the step counts in `nstep` (at
    `nominal_nbatch` batches). The multithreaded run uses the notebook-level
    `nthread`.

    Returns a 6-tuple:
    (python nbatch, rollout single nbatch, rollout multi nbatch,
     python nstep, rollout single nstep, rollout multi nstep)
    """
    def fresh_data(x):
        # A new MjData per timing sample, independent of the swept value.
        return init_model(model)

    print('Benchmarking pure python', end='\r')
    t0 = time.time()
    t_python_nbatch = benchmark(
        lambda x, data: python_rollout(model, data, x, nominal_nstep),
        nbatch, ntiming, f_init=fresh_data)
    t_python_nstep = benchmark(
        lambda x, data: python_rollout(model, data, nominal_nbatch, x),
        nstep, ntiming, f_init=fresh_data)
    t1 = time.time()
    print(f'Benchmarking pure python took {t1-t0:0.1f} seconds')

    print('Benchmarking single threaded rollout', end='\r')
    with rollout.Rollout(nthread=0) as rollout_:
        t0 = time.time()
        t_rollout_single_nbatch = benchmark(
            lambda x, data: nthread_rollout(model, data, x, nominal_nstep,
                                            nthread=1, rollout_=rollout_),
            nbatch, ntiming, f_init=fresh_data)
        t_rollout_single_nstep = benchmark(
            lambda x, data: nthread_rollout(model, data, nominal_nbatch, x,
                                            nthread=1, rollout_=rollout_),
            nstep, ntiming, f_init=fresh_data)
        t1 = time.time()
    print(f'Benchmarking single threaded rollout took {t1-t0:0.1f} seconds')

    print(f'Benchmarking multithreaded rollout using {nthread} threads', end='\r')
    with rollout.Rollout(nthread=nthread) as rollout_:
        t0 = time.time()
        t_rollout_multi_nbatch = benchmark(
            lambda x, data: nthread_rollout(model, data, x, nominal_nstep,
                                            nthread, rollout_=rollout_),
            nbatch, ntiming, f_init=fresh_data)
        t_rollout_multi_nstep = benchmark(
            lambda x, data: nthread_rollout(model, data, nominal_nbatch, x,
                                            nthread, rollout_=rollout_),
            nstep, ntiming, f_init=fresh_data)
        t1 = time.time()
    print(f'Benchmarking multithreaded rollout using {nthread} threads took {t1-t0:0.1f} seconds')

    return (t_python_nbatch, t_rollout_single_nbatch, t_rollout_multi_nbatch,
            t_python_nstep, t_rollout_single_nstep, t_rollout_multi_nstep)
def plot_benchmark(results, nbatch, nstep, nominal_nbatch, nominal_nstep, title):
    """Bar-chart steps-per-second for python vs rollout (single/multi-threaded).

    Left panel sweeps `nbatch` at fixed `nominal_nstep`; right panel sweeps
    `nstep` at fixed `nominal_nbatch`. `results` is the 6-tuple returned by
    `benchmark_rollout`.
    """
    (t_python_nbatch, t_rollout_single_nbatch, t_rollout_multi_nbatch,
     t_python_nstep, t_rollout_single_nstep, t_rollout_multi_nstep) = results
    width = 0.25
    ticker = matplotlib.ticker.EngFormatter(unit='')
    fig, (ax1, ax2) = plt.subplots(1, 2, sharey=True)

    # Left panel: nbatch varied at fixed nominal_nstep.
    x = np.arange(len(nbatch))  # idiomatic replacement for np.array([i for i in range(...)])
    steps_per_t = np.array(nbatch) * nominal_nstep  # total steps per timing sample
    steps_per_t_python = steps_per_t / t_python_nbatch
    steps_per_t_single = steps_per_t / t_rollout_single_nbatch
    steps_per_t_multi = steps_per_t / t_rollout_multi_nbatch
    ax1.bar(x + 0*width, steps_per_t_python, width=width, label='python')
    ax1.bar(x + 1*width, steps_per_t_single, width=width, label='rollout single threaded')
    ax1.bar(x + 2*width, steps_per_t_multi, width=width, label='rollout multithreaded')
    ax1.set_xticks(x + width, nbatch)
    ax1.yaxis.set_major_formatter(ticker)
    ax1.grid()
    ax1.set_axisbelow(True)  # draw grid lines behind the bars
    ax1.set_xlabel('nbatch')
    ax1.set_ylabel('steps per second')
    ax1.set_title(f'nbatch varied, nstep = {nominal_nstep}')

    # Right panel: nstep varied at fixed nominal_nbatch.
    x = np.arange(len(nstep))
    steps_per_t = np.array(nstep) * nominal_nbatch
    steps_per_t_python = steps_per_t / t_python_nstep
    steps_per_t_single = steps_per_t / t_rollout_single_nstep
    steps_per_t_multi = steps_per_t / t_rollout_multi_nstep
    ax2.bar(x + 0*width, steps_per_t_python, width=width, label='python')
    ax2.bar(x + 1*width, steps_per_t_single, width=width, label='rollout single threaded')
    ax2.bar(x + 2*width, steps_per_t_multi, width=width, label='rollout multithreaded')
    ax2.set_xticks(x + width, nstep)
    ax2.yaxis.set_major_formatter(ticker)
    ax2.grid()
    ax2.set_axisbelow(True)
    ax2.set_xlabel('nstep')
    ax2.set_title(f'nstep varied, nbatch = {nominal_nbatch}')

    ax1.legend(loc=(0.03, 0.8))
    fig.set_size_inches(10, 5)
    plt.suptitle(title)
    plt.tight_layout()
# Benchmark the tippe top model and plot the results.
nominal_nbatch = 256 # Batch size to use when testing different nstep
nominal_nstep = 5 # Step count to use when testing different nbatch
nbatch = [1, 256, 2048, 8192]  # Batch sizes to benchmark
nstep = [1, 10, 100, 1000]  # Step counts to benchmark
top_benchmark_results = benchmark_rollout(top_model, init_top,
                                          nbatch, nstep,
                                          nominal_nbatch, nominal_nstep)
plot_benchmark(top_benchmark_results, nbatch, nstep,
               nominal_nbatch, nominal_nstep,
               title='Tippe Top')
# Benchmark the humanoid model and plot the results.
nominal_nbatch = 256 # Batch size to use when testing different nstep
nominal_nstep = 5 # Step count to use when testing different nbatch
nbatch = [1, 256, 2048, 8192] # Batch sizes to benchmark
nstep = [1, 10, 100, 1000] # Step counts to benchmark
humanoid_benchmark_results = benchmark_rollout(humanoid_model, init_humanoid,
                                               nbatch, nstep,
                                               nominal_nbatch, nominal_nstep)
plot_benchmark(humanoid_benchmark_results, nbatch, nstep,
               nominal_nbatch, nominal_nstep,
               title='Humanoid')
# Benchmark the humanoid100 model (smaller batches: it is much heavier) and plot.
nominal_nbatch = 128 # Batch size to use when testing different nstep
nominal_nstep = 5 # Step count to use when testing different nbatch
nbatch = [1, 64, 128, 256] # Batch sizes to benchmark
nstep = [1, 10, 100, 1000] # Step counts to benchmark
humanoid100_benchmark_results = benchmark_rollout(
    humanoid100_model,
    init_humanoid100,
    nbatch,
    nstep,
    nominal_nbatch,
    nominal_nstep,
)
plot_benchmark(humanoid100_benchmark_results, nbatch, nstep,
               nominal_nbatch, nominal_nstep,
               title='Humanoid100')
Next we will benchmark rollout and MJX using the tippe top and humanoid models (humanoid100 is not supported by MJX).
The next two benchmarks take about 16.5 minutes total on an AMD 5800X3D and an NVIDIA 4090. Most of the time is spent JIT compiling the MJX functions. The JIT functions are cached so that subsequent runs of the benchmark run much faster.
Note: MJX is most useful when coupled with something else that runs best on a GPU, like a neural network. Without any such additional workload, CPU based simulation will sometimes be faster, especially when using less than state-of-the-art GPUs.
#@title MJX helper functions
def init_mjx_batch(model, init_model, nbatch, nstep, skip_jit=False):
    """Build an MJX model/data pair, a batch of size `nbatch`, and a jitted unroll.

    Args:
      model: MjModel to simulate.
      init_model: callable returning an initialized MjData for `model`.
      nbatch: number of parallel simulations in the batch.
      nstep: number of physics steps performed by the compiled unroll.
      skip_jit: if True, skip compiling the unroll (jit_unroll is None and
        jit_time is 0.0).

    Returns:
      (mjx_model, mjx_data, jit_unroll, batch, jit_time) where jit_time is the
      wall-clock seconds spent compiling jit_unroll.
    """
    data = init_model(model)
    # Make MJX versions of model and data
    mjx_model = mjx.put_model(model)
    mjx_data = mjx.put_data(model, data)
    # Tile mjx_data into a batch; the vmapped lambda ignores its argument, so
    # only the length of the arange matters (jp.arange replaces the
    # non-idiomatic jp.array(list(range(nbatch)))).
    batch = jax.vmap(lambda x: mjx_data)(jp.arange(nbatch))
    jax.block_until_ready(batch)

    if skip_jit:
        return mjx_model, mjx_data, None, batch, 0.0

    # Compile a scan of nstep vmapped steps, timing the compilation.
    start = time.time()
    jit_step = jax.vmap(mjx.step, in_axes=(None, 0))
    def unroll(d, _):
        d = jit_step(mjx_model, d)
        return d, None
    jit_unroll = jax.jit(
        lambda d: jax.lax.scan(unroll, d, None, length=nstep, unroll=4)[0])
    jit_unroll = jit_unroll.lower(batch).compile()
    jit_time = time.time() - start
    return mjx_model, mjx_data, jit_unroll, batch, jit_time
def mjx_rollout(batch, jit_unroll):
    """Run the compiled unroll on `batch` and wait for it to finish."""
    result = jit_unroll(batch)
    jax.block_until_ready(result)
def benchmark_mjx(model, init_model, nbatch, nstep, nominal_nbatch, nominal_nstep, ntiming=1, jit_unroll_cache=None):
    """Time multithreaded rollout against MJX over `nbatch` and `nstep` sweeps.

    Args:
      model: MjModel to simulate.
      init_model: callable returning an initialized MjData for `model`.
      nbatch: batch sizes to sweep (each run with `nominal_nstep` steps).
      nstep: step counts to sweep (each run with `nominal_nbatch` batches).
      nominal_nbatch: fixed batch size for the nstep sweep.
      nominal_nstep: fixed step count for the nbatch sweep.
      ntiming: number of timing repetitions passed to `benchmark`.
      jit_unroll_cache: optional dict of compiled MJX unrolls; it is mutated
        in place so the caller can reuse compilations across invocations.

    Returns:
      (t_rollout_multi_nbatch, t_rollout_multi_nstep, t_mjx_nbatch, t_mjx_nstep)

    Note: uses the notebook-level `nthread` for the rollout runs.
    """
    # First time multithreaded rollout (same procedure as benchmark_rollout).
    print(f'Benchmarking multithreaded rollout using {nthread} threads', end="\r")
    with rollout.Rollout(nthread=nthread) as rollout_:
        start = time.time()
        t_rollout_multi_nbatch = benchmark(lambda x, data: nthread_rollout(model, data, x, nominal_nstep, nthread, rollout_),
                                           nbatch, ntiming, f_init=lambda x: init_model(model))
        t_rollout_multi_nstep = benchmark(lambda x, data: nthread_rollout(model, data, nominal_nbatch, x, nthread, rollout_),
                                          nstep, ntiming, f_init=lambda x: init_model(model))
        end = time.time()
    print(f'Benchmarking multithreaded rollout using {nthread} threads took {end-start:0.1f} seconds')

    # Pre-compile every MJX unroll needed below, so the timed section excludes
    # JIT time. Cache keys encode the fixed parameter of each sweep.
    print('Running JIT for MJX', end='\r')
    total_jit = 0.0
    if jit_unroll_cache is None:
        jit_unroll_cache = {}
    if f'nbatch_{nominal_nstep}' not in jit_unroll_cache:
        jit_unroll_cache[f'nbatch_{nominal_nstep}'] = {}
    if f'nstep_{nominal_nbatch}' not in jit_unroll_cache:
        jit_unroll_cache[f'nstep_{nominal_nbatch}'] = {}
    for n in nbatch:
        if n not in jit_unroll_cache[f'nbatch_{nominal_nstep}']:
            _, _, jit_unroll_cache[f'nbatch_{nominal_nstep}'][n], _, jit_time = init_mjx_batch(model, init_model, n, nominal_nstep)
            total_jit += jit_time
    for n in nstep:
        if n not in jit_unroll_cache[f'nstep_{nominal_nbatch}']:
            _, _, jit_unroll_cache[f'nstep_{nominal_nbatch}'][n], _, jit_time = init_mjx_batch(model, init_model, nominal_nbatch, n)
            total_jit += jit_time
    print(f'Running JIT for MJX took {total_jit:0.1f} seconds')

    # Time MJX itself: f_init builds the batch (skip_jit=True, compilation was
    # done above); x_init[3] is the `batch` element of init_mjx_batch's return.
    print('Benchmarking MJX', end='\r')
    start = time.time()
    t_mjx_nbatch = benchmark(lambda x, x_init: mjx_rollout(x_init[3], jit_unroll_cache[f'nbatch_{nominal_nstep}'][x]),
                             nbatch, ntiming, f_init=lambda x: init_mjx_batch(model, init_model, x, nominal_nstep, skip_jit=True))
    t_mjx_nstep = benchmark(lambda x, x_init: mjx_rollout(x_init[3], jit_unroll_cache[f'nstep_{nominal_nbatch}'][x]),
                            nstep, ntiming, f_init=lambda x: init_mjx_batch(model, init_model, nominal_nbatch, x, skip_jit=True))
    end = time.time()
    print(f'Benchmarking MJX took {end-start:0.1f} seconds')
    return t_rollout_multi_nbatch, t_rollout_multi_nstep, t_mjx_nbatch, t_mjx_nstep
def plot_mjx_benchmark(results, nbatch, nstep, nominal_nbatch, nominal_nstep, title):
t_rollout_multi_nbatch, t_rollout_multi_nstep, t_mjx_nbatch, t_mjx_nstep = results
width = 0.333
x = np.array([i for i in range(len(nbatch))])
ticker = matplotlib.ticker.EngFormatter(unit='')
fig, (ax1, ax2) = plt.subplots(1, 2, sharey=True)
steps_per_t = np.array(nbatch) * nominal_nstep
steps_per_t_mjx = steps_per_t / t_mjx_nbatch
steps_per_t_multi = steps_per_t / t_rollout_multi_nbatch
ax1.bar(x + 0*width, steps_per_t_mjx, width=width, label='mjx')
ax1.bar(x + 1*width, steps_per_t_multi, width=width, label='rollout multithreaded')
ax1.set_xticks(x + width / 2, nbatch)
ax1.yaxis.set_major_formatter(ticker)
ax1.grid()
ax1.set_xlabel('nbatch')
ax1.set_ylabel('steps per second')
ax1.set_title(f'nbatch varied, nstep = {nominal_nstep}')
x = np.array([i for i in range(len(nstep))])
steps_per_t = np.array(nstep) * nominal_nbatch
steps_per_t_mjx = steps_per_t / t_mjx_nstep
steps_per_t_multi = steps_per_t / t_rollout_multi_nstep
ax2.bar(x + 0*width, steps_per_t_mjx, width=width, label='mjx')
ax2.bar(x + 1*width, steps_per_t_multi, width=width, label='rollout multithreaded')
ax2.set_xticks(x + width / 2, nstep)
ax2.yaxis.set_major_formatter(ticker)
ax2.grid()
ax2.set_xlabel('nstep')
ax2.set_title(f'nstep varied, nbatch = {nominal_nbatch}')
ax2.legend(loc=(1.04, 0.0))
fig.set_size_inches(10, 4)
plt.suptitle(title)
plt.tight_layout()
# Caches for jit_step functions, they take a long time to compile
top_jit_unroll_cache = {}
humanoid_jit_unroll_cache = {}

# Benchmark MJX vs rollout on the tippe top model and plot.
nominal_nbatch = 16384 # Batch size to use when testing different nstep
nominal_nstep = 5 # Step count to use when testing different nbatch
nbatch = [4096, 16384, 65536, 131072] # Batch sizes to benchmark
nstep = [1, 10, 100, 200] # Step counts to benchmark
mjx_top_results = benchmark_mjx(top_model, init_top, nbatch, nstep, nominal_nbatch, nominal_nstep,
                                jit_unroll_cache=top_jit_unroll_cache)
plot_mjx_benchmark(mjx_top_results, nbatch, nstep, nominal_nbatch, nominal_nstep, title='MJX Tippe Top')
# Benchmark MJX vs rollout on the humanoid model and plot.
nominal_nbatch = 4096 # Batch size to use when testing different nstep
nominal_nstep = 5 # Step count to use when testing different nbatch
nbatch = [1024, 4096, 16384, 32768] # Batch sizes to benchmark
nstep = [1, 10, 100, 200] # Step counts to benchmark
mjx_humanoid_results = benchmark_mjx(humanoid_model, init_humanoid, nbatch, nstep, nominal_nbatch, nominal_nstep,
                                     jit_unroll_cache=humanoid_jit_unroll_cache)
plot_mjx_benchmark(mjx_humanoid_results, nbatch, nstep, nominal_nbatch, nominal_nstep, title='MJX Humanoid')
The MJX documentation contains a chart comparing the speed of native MuJoCo vs MJX on a variety of devices.
Here we will produce a similar plot to compare MJX with rollout. On a 5800X3D and 4090 the benchmark takes about 16.5 minutes to run.
Note: These results are not directly comparable with the plot in the documentation because, in particular, the batch size was reduced from 8192 to 4096 in order to fit the batch on a 4090.
# Sweep the number of humanoids per scene, timing one rollout and one MJX
# unroll per model size.
max_humanoids = 10
nbatch = 8192 // 2 # The original benchmark ran with a batch size of 8192, but on a 4090 we can only fit about 4096 humanoids
nstep = 200
jit_step = jax.vmap(mjx.step, in_axes=(None, 0))
t_rollout = []  # rollout wall-clock seconds, one entry per humanoid count
t_mjx = []  # MJX wall-clock seconds, one entry per humanoid count
for i in range(1, max_humanoids+1):
    print(f'Running benchmark on {i} humanoids')
    nhumanoid_model = mujoco.MjModel.from_xml_path(
        f'mujoco/mjx/mujoco/mjx/test_data/humanoid/{i:02d}_humanoids.xml'
    )
    nhumanoid_data = mujoco.MjData(nhumanoid_model)
    mjx_model = mjx.put_model(nhumanoid_model)
    mjx_data = mjx.put_data(nhumanoid_model, nhumanoid_data)
    # Tile the single MjData into a batch of nbatch copies.
    batch = jax.vmap(lambda x: mjx_data)(jp.array(list(range(nbatch))))
    jax.block_until_ready(batch)
    # Time the multithreaded rollout.
    with rollout.Rollout(nthread=nthread) as rollout_:
        initial_state = get_state(nhumanoid_model, nhumanoid_data, nbatch)
        start = time.perf_counter()
        rollout_.rollout([nhumanoid_model]*nbatch,
                         [copy.copy(nhumanoid_data) for _ in range(nthread)],
                         initial_state=initial_state,
                         nstep=nstep, skip_checks=True)
        end = time.perf_counter()
        t_rollout.append(end-start)
    # Trigger JIT for model/batch so as not to include JIT time in benchmarking information
    def unroll(d, _):
        d = jit_step(mjx_model, d)
        return d, None
    jit_unroll = jax.jit(lambda d: jax.lax.scan(unroll, d, None, length=nstep, unroll=4)[0])
    jit_unroll = jit_unroll.lower(batch).compile()
    # Time the compiled MJX unroll.
    start = time.perf_counter()
    jit_unroll(batch)
    # NOTE(review): this blocks on the input `batch`, not on jit_unroll's
    # (discarded) output — the timed span may not cover the full computation;
    # verify against mjx_rollout, which blocks on the result.
    jax.block_until_ready(batch)
    end = time.perf_counter()
    t_mjx.append(end-start)
#@title Plot MJX nhumanoid benchmark
def plot_mjx_nhumanoid_benchmark(t_rollout, t_mjx, nbatch, nstep, max_humanoids):
    """Bar-chart steps-per-second (log scale) for MJX vs multithreaded rollout
    as the number of humanoids per scene grows from 1 to `max_humanoids`.

    `t_rollout` and `t_mjx` are per-humanoid-count timings in seconds; each
    run simulated `nbatch` scenes for `nstep` steps.
    """
    nhumanoids = list(range(1, max_humanoids + 1))  # idiomatic list(range) over a comprehension
    width = 0.333
    x = np.arange(len(nhumanoids))  # idiomatic np.arange over np.array([i for i in ...])
    ticker = matplotlib.ticker.EngFormatter(unit='')
    fig, ax1 = plt.subplots(1, 1, sharey=True)
    steps_per_t = nbatch * nstep  # total steps in each timed run
    steps_per_t_mjx = steps_per_t / np.array(t_mjx)
    steps_per_t_multi = steps_per_t / np.array(t_rollout)
    ax1.bar(x + 0*width, steps_per_t_mjx, width=width, label='mjx')
    ax1.bar(x + 1*width, steps_per_t_multi, width=width, label='rollout multithreaded')
    ax1.set_xticks(x + width / 2, nhumanoids)
    ax1.yaxis.set_major_formatter(ticker)
    ax1.set_yscale('log')  # throughputs span orders of magnitude
    ax1.grid()
    ax1.set_xlabel('number of humanoids')
    ax1.set_ylabel('steps per second')
    ax1.set_title(f'nhumanoids varied, nbatch = {nbatch}, nstep = {nstep}')
    ax1.legend(loc=(1.04, 0.0))
    fig.set_size_inches(8, 4)
    plt.tight_layout()
# Render the nhumanoid sweep collected above.
plot_mjx_nhumanoid_benchmark(t_rollout, t_mjx, nbatch, nstep, max_humanoids)