cloud_tpu_colabs/Pmap_Cookbook.ipynb
This notebook is an introduction to writing single-program multiple-data (SPMD) programs in JAX, and executing them synchronously in parallel on multiple devices, such as multiple GPUs or multiple TPU cores. The SPMD model is useful for computations like training neural networks with synchronous gradient descent algorithms, and can be used for data-parallel as well as model-parallel computations.
Note: To run this notebook with any parallelism, you'll need multiple XLA devices available, e.g. with a multi-GPU machine, a Colab TPU, a Google Cloud TPU or a Kaggle TPU VM.
The code in this notebook is simple. For an example of how to use these tools to do data-parallel neural network training, check out the SPMD MNIST example or the much more capable Trax library.
import jax.numpy as jnp
A basic starting point is expressing parallel maps with pmap:
from jax import pmap
result = pmap(lambda x: x ** 2)(jnp.arange(7))
print(result)
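The call above assumes at least 7 devices. Here is a minimal sketch that sizes the mapped axis with `jax.local_device_count()` instead, so it also runs on a machine with a single CPU device. The name `n_devices` is illustrative only.

```python
import jax
import jax.numpy as jnp
from jax import pmap

# Size the mapped axis to the number of devices attached to this process,
# so the example runs even when only one device is available.
n_devices = jax.local_device_count()
squares = pmap(lambda x: x ** 2)(jnp.arange(n_devices))
print(n_devices, squares)
```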
In terms of what values are computed, pmap is similar to vmap in that it transforms a function to map over an array axis:
from jax import vmap
x = jnp.array([1., 2., 3.])
y = jnp.array([2., 4., 6.])
print(vmap(jnp.add)(x, y))
print(pmap(jnp.add)(x, y))
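To check that the two transformations agree on values, you can compare them directly. This sketch assumes nothing beyond the number of local devices: it builds inputs whose length equals `jax.local_device_count()` so the `pmap` call always fits.

```python
import jax
import jax.numpy as jnp
from jax import pmap, vmap

n = jax.local_device_count()
x = jnp.arange(1., n + 1.)
y = 2 * x
# Same values, different execution strategy: vmap batches the primitives,
# while pmap runs one replica of jnp.add per device.
vmapped = vmap(jnp.add)(x, y)
pmapped = pmap(jnp.add)(x, y)
print(vmapped, pmapped)
```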
But pmap and vmap differ in how those values are computed: where vmap vectorizes a function by adding a batch dimension to every primitive operation in the function (e.g. turning matrix-vector multiplies into matrix-matrix multiplies), pmap instead replicates the function and executes each replica on its own XLA device in parallel.
from jax import make_jaxpr
def f(x, y):
    a = jnp.dot(x, y)
    b = jnp.tanh(a)
    return b
xs = jnp.ones((8, 2, 3))
ys = jnp.ones((8, 3, 4))
print("f jaxpr")
print(make_jaxpr(f)(xs[0], ys[0]))
print("vmap(f) jaxpr")
print(make_jaxpr(vmap(f))(xs, ys))
print("pmap(f) jaxpr")
print(make_jaxpr(pmap(f))(xs, ys))
Notice that applying vmap(f) to these arguments leads to a dot_general to express the batch matrix multiplication in a single primitive, while applying pmap(f) instead leads to a primitive that calls replicas of the original f in parallel.
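One way to confirm the vmap side of this programmatically, rather than by reading printed jaxprs, is to inspect the traced program. This sketch only uses `make_jaxpr` and the `out_avals` of the resulting closed jaxpr; it checks that the batched program still contains a `dot_general` and that its output carries the leading batch dimension of 8.

```python
import jax.numpy as jnp
from jax import make_jaxpr, vmap

def f(x, y):
    return jnp.tanh(jnp.dot(x, y))

xs = jnp.ones((8, 2, 3))
ys = jnp.ones((8, 3, 4))

batched_jaxpr = make_jaxpr(vmap(f))(xs, ys)
# The batched matmul is expressed as a single dot_general primitive.
has_dot_general = "dot_general" in str(batched_jaxpr)
# The output aval shows the batch axis added in front of the (2, 4) result.
batched_shape = batched_jaxpr.out_avals[0].shape
print(has_dot_general, batched_shape)
```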
An important constraint with using pmap is that the mapped axis size must be less than or equal to the number of XLA devices available (and for nested pmap functions, the product of the mapped axis sizes must be less than or equal to the number of XLA devices).
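A common way to respect this constraint is to reshape a flat batch so that its leading axis equals the device count, with the remaining examples stacked per device. The sketch below assumes the batch size is divisible by the device count; `shard_batch` is a hypothetical helper, not a JAX API.

```python
import jax
import jax.numpy as jnp
from jax import pmap

def shard_batch(batch, n_devices):
    """Hypothetical helper: reshape (B, ...) into (n_devices, B // n_devices, ...)."""
    if batch.shape[0] % n_devices != 0:
        raise ValueError(f"batch size {batch.shape[0]} is not divisible by {n_devices}")
    return batch.reshape((n_devices, -1) + batch.shape[1:])

n = jax.local_device_count()
batch = jnp.arange(4 * n, dtype=jnp.float32)  # 4 examples per device
sharded = shard_batch(batch, n)
# The mapped axis now has exactly n entries, one per device.
per_device_sums = pmap(jnp.sum)(sharded)
print(sharded.shape, per_device_sums)
```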
You can use the output of a pmap function just like any other value:
y = pmap(lambda x: x ** 2)(jnp.arange(8))
z = y / 2
print(z)
import matplotlib.pyplot as plt
plt.plot(y)
But while the output here acts just like a NumPy ndarray, if you look closely it has a different type:
y
A sharded Array behaves like an ndarray, but it's stored in pieces spread across the memory of multiple devices. Results from pmap functions are left sharded in device memory so that they can be operated on by subsequent pmap functions without moving data around, at least in some cases. But these results logically appear just like a single array.
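To see where the pieces live, you can inspect the Array's shards. This sketch uses the `jax.Array` attributes `addressable_shards` (each shard exposes `.device` and `.data`) and `devices()`; it assumes a single process, so every shard is addressable.

```python
import jax
import jax.numpy as jnp
from jax import pmap

n = jax.local_device_count()
y = pmap(lambda x: x ** 2)(jnp.arange(n))

# Each addressable shard holds the part of y that lives on one device.
for shard in y.addressable_shards:
    print(shard.device, shard.data)
```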
When you call a non-pmap function on an Array, like a standard jax.numpy function, communication happens behind the scenes to bring the values to one device (or back to the host in the case of the matplotlib function above):
y / 2
import numpy as np
np.sin(y)
Thinking about device memory is important to maximize performance by avoiding data transfers, but you can always fall back to treating array-like values as (read-only) NumPy ndarrays and your code will still work.
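If you prefer to make the device-to-host transfer explicit instead of relying on implicit conversion, one option is `jax.device_get`, which copies an Array into a NumPy ndarray in host memory. This is a minimal sketch sized to the local device count.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import pmap

n = jax.local_device_count()
y = pmap(lambda x: x ** 2)(jnp.arange(n))
host_y = jax.device_get(y)  # one explicit copy back to host memory
print(type(host_y), np.sin(host_y))
```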
Here's another example of a pure map which makes better use of our multiple-accelerator resources. We can generate several large random matrices in parallel, then perform parallel batch matrix multiplication without any cross-device movement of the large matrix data:
from jax import random
# create 8 random keys
keys = random.split(random.key(0), 8)
# create a 5000 x 6000 matrix on each device by mapping over keys
mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)
# the stack of matrices is represented logically as a single array
mats.shape
# run a local matmul on each device in parallel (no data transfer)
result = pmap(lambda x: jnp.dot(x, x.T))(mats)
result.shape
# compute the mean on each device in parallel and print the results
print(pmap(jnp.mean)(result))
In this example, the large matrices never had to be moved between devices or back to the host; only one scalar per device was pulled back to the host.
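The same pipeline can be tried on smaller hardware by scaling it down. This sketch assumes only that the matrix sizes (50 x 60 here, chosen arbitrarily) fit comfortably in device memory, and uses one key per local device so the mapped axis always fits.

```python
import jax
import jax.numpy as jnp
from jax import pmap, random

n = jax.local_device_count()
keys = random.split(random.key(0), n)  # one key per local device
# Each device generates and multiplies its own matrix; only the
# per-device means (one scalar each) need to come back to the host.
mats = pmap(lambda key: random.normal(key, (50, 60)))(keys)
grams = pmap(lambda m: jnp.dot(m, m.T))(mats)
means = pmap(jnp.mean)(grams)
print(mats.shape, grams.shape, means)
```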
In addition to expressing pure maps, where no communication happens between the replicated functions, with pmap you can also use special collective communication operations.
One canonical example of a collective, implemented on both GPU and TPU, is an all-reduce sum like lax.psum:
from jax import lax
normalize = lambda x: x / lax.psum(x, axis_name='i')
result = pmap(normalize, axis_name='i')(jnp.arange(4.))
print(result)
To use a collective operation like lax.psum, you need to supply an axis_name argument to pmap. The axis_name argument associates a name to the mapped axis so that collective operations can refer to it.
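A quick way to check what `lax.psum` computed is that the normalized values sum to 1 across the mapped axis. This sketch sizes the input to the local device count and starts at 1 so the normalizer is never zero.

```python
import jax
import jax.numpy as jnp
from jax import lax, pmap

n = jax.local_device_count()
# Start at 1 so the sum over the axis (the normalizer) is never zero.
x = jnp.arange(1., n + 1.)
normalized = pmap(lambda v: v / lax.psum(v, axis_name='i'), axis_name='i')(x)
print(normalized, normalized.sum())
```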
Another way to write this same code is to use pmap as a decorator:
from functools import partial
@partial(pmap, axis_name='i')
def normalize(x):
    return x / lax.psum(x, 'i')
print(normalize(jnp.arange(4.)))
Axis names are also important for nested use of pmap, where collectives can be applied to distinct mapped axes:
@partial(pmap, axis_name='rows')
@partial(pmap, axis_name='cols')
def f(x):
    row_normed = x / lax.psum(x, 'rows')
    col_normed = x / lax.psum(x, 'cols')
    doubly_normed = x / lax.psum(x, ('rows', 'cols'))
    return row_normed, col_normed, doubly_normed
x = jnp.arange(8.).reshape((4, 2))
a, b, c = f(x)
print(a)
print(a.sum(0))
When writing nested pmap functions in the decorator style, axis names are resolved according to lexical scoping.
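The nested example above needs 8 devices. This sketch keeps the same structure but uses an outer `'rows'` axis of size n and an inner `'cols'` axis of size 1, so the product of the axis sizes equals the local device count. It then checks each normalization: rows sum to 1 over the outer axis, the size-1 column normalization gives 1 everywhere, and the doubly normalized values sum to 1 overall.

```python
import jax
import jax.numpy as jnp
from functools import partial
from jax import lax, pmap

n = jax.local_device_count()

@partial(pmap, axis_name='rows')
@partial(pmap, axis_name='cols')
def normalize_both(x):
    row_normed = x / lax.psum(x, 'rows')
    col_normed = x / lax.psum(x, 'cols')
    doubly_normed = x / lax.psum(x, ('rows', 'cols'))
    return row_normed, col_normed, doubly_normed

# Outer axis size n times inner axis size 1 fits on n devices.
x = jnp.arange(1., n + 1.).reshape((n, 1))
a, b, c = normalize_both(x)
print(a, b, c)
```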
Check the JAX reference documentation for a complete list of the parallel operators. More are being added!
Here's how to use lax.ppermute to implement a simple halo exchange for a Rule 30 simulation:
import jax
device_count = jax.device_count()
def send_right(x, axis_name):
    # Device i sends x to device i + 1, so each device receives its left neighbor's x.
    perm = [(i, (i + 1) % device_count) for i in range(device_count)]
    return lax.ppermute(x, perm=perm, axis_name=axis_name)
def send_left(x, axis_name):
    # Device i + 1 sends x to device i, so each device receives its right neighbor's x.
    perm = [((i + 1) % device_count, i) for i in range(device_count)]
    return lax.ppermute(x, perm=perm, axis_name=axis_name)
def update_board(board):
    # Rule 30: new cell = left XOR (center OR right).
    left = board[:-2]
    right = board[2:]
    center = board[1:-1]
    return lax.bitwise_xor(left, lax.bitwise_or(center, right))
@partial(pmap, axis_name='i')
def step(board_slice):
    # Halo exchange: get the left neighbor's last cell and the right neighbor's first cell.
    first, last = board_slice[:1], board_slice[-1:]
    right, left = send_left(first, 'i'), send_right(last, 'i')
    enlarged_board_slice = jnp.concatenate([left, board_slice, right])
    return update_board(enlarged_board_slice)
def print_board(board):
    print(''.join('*' if x else ' ' for x in board.ravel()))
board = np.zeros(40, dtype=bool)
board[board.shape[0] // 2] = True
reshaped_board = board.reshape((device_count, -1))
print_board(reshaped_board)
for _ in range(20):
    reshaped_board = step(reshaped_board)
    print_board(reshaped_board)
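The halo exchange is built from a ring permutation. This minimal sketch isolates that building block: with `lax.ppermute` and the permutation `(i, (i + 1) % n)`, every device's value moves one position to the right, which matches `jnp.roll(x, 1)` on the logical array. It assumes only a single process with n local devices.

```python
import jax
import jax.numpy as jnp
from functools import partial
from jax import lax, pmap

n = jax.local_device_count()
# Device i sends its value to device (i + 1) % n: a rotation to the right.
ring_perm = [(i, (i + 1) % n) for i in range(n)]

@partial(pmap, axis_name='i')
def rotate_right(v):
    return lax.ppermute(v, axis_name='i', perm=ring_perm)

x = jnp.arange(n)
rotated = rotate_right(x)
print(x, rotated)
```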
As with all things in JAX, you should expect pmap to compose with other transformations, including differentiation.
from jax import grad
@pmap
def f(x):
    y = jnp.sin(x)
    @pmap
    def g(z):
        return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum()
    return grad(lambda w: jnp.sum(g(w)))(x)
f(x)
grad(lambda x: jnp.sum(f(x)))(x)
When reverse-mode differentiating a pmap function (e.g. with grad), the backward pass of the computation is parallelized just like the forward pass.
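A small sanity check of differentiation through pmap is to compare against a known derivative. In this sketch the forward pass maps `jnp.sin` over the local devices, so the gradient of the summed output should equal `jnp.cos(x)` elementwise; the input values from `jnp.linspace` are arbitrary.

```python
import jax
import jax.numpy as jnp
from jax import grad, pmap

n = jax.local_device_count()
x = jnp.linspace(0., 1., n)

# Forward pass runs sin on each device; reverse mode differentiates
# through pmap, so the backward pass (multiplying by cos) is mapped too.
g = grad(lambda v: jnp.sum(pmap(jnp.sin)(v)))(x)
print(g)
```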