Distributed arrays and automatic parallelization

<!--* freshness: { reviewed: '2024-04-16' } *-->

This tutorial discusses parallelism via jax.Array, the unified array object model available in JAX v0.4.1 and newer. See {doc}`../the-training-cookbook` for a real-world machine learning training example that uses this API.

python
from typing import Optional

import numpy as np

import jax
import jax.numpy as jnp

⚠️ WARNING: The notebook requires 8 devices to run.

python
if len(jax.local_devices()) < 8:
  raise Exception("Notebook requires 8 devices to run")
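If you don't have 8 accelerators at hand, one way to follow along is to ask XLA to expose several virtual CPU devices. The sketch below assumes a CPU-only machine and that nothing has initialized JAX yet. `--xla_force_host_platform_device_count` is an XLA flag, and using `setdefault` leaves any `XLA_FLAGS` value you already set untouched.

```python
import os

# Assumption: CPU-only host. The flag only takes effect if it is set before
# JAX initializes its backends, so it must come before the first JAX call.
# setdefault keeps any XLA_FLAGS value that is already present.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import jax

devices = jax.devices()
print(f"{len(devices)} device(s) on platform {devices[0].platform!r}")
```

Virtual CPU devices are handy for checking shardings and program structure, but timings on them say little about real accelerator speedups.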

Intro and a quick example

By reading this tutorial notebook, you'll learn about jax.Array, a unified datatype for representing arrays, even with physical storage spanning multiple devices. You'll also learn about how using jax.Arrays together with jax.jit can provide automatic compiler-based parallelization.

Before we go through the details step by step, here's a quick example. First, we'll create a jax.Array sharded across multiple devices:

python
from jax.sharding import PartitionSpec as P, NamedSharding
python
# Arrange the devices into a 4x2 Mesh with axis names 'x' and 'y'; a Sharding built on it distributes values across devices:
mesh = jax.make_mesh((4, 2), ('x', 'y'))
python
# Create an array of random values:
x = jax.random.normal(jax.random.key(0), (8192, 8192))
# and use jax.device_put to distribute it across devices:
y = jax.device_put(x, NamedSharding(mesh, P('x', 'y')))
jax.debug.visualize_array_sharding(y)

Next, we'll apply a computation to it and visualize how the result values are stored across multiple devices too:

python
z = jnp.sin(y)
jax.debug.visualize_array_sharding(z)

The evaluation of the jnp.sin application was automatically parallelized across the devices on which the input values (and output values) are stored:

python
# `x` is present on a single device
%timeit -n 5 -r 5 jnp.sin(x).block_until_ready()
python
# `y` is sharded across 8 devices.
%timeit -n 5 -r 5 jnp.sin(y).block_until_ready()

Now let's look at each of these pieces in more detail!

Sharding describes how array values are laid out in memory across devices

Sharding basics, and the NamedSharding subclass

To parallelize computation across multiple devices, we first must lay out input data across multiple devices.

In JAX, Sharding objects describe distributed memory layouts. They can be used with jax.device_put to produce a value with distributed layout.

For example, here's a value with a single-device Sharding:

python
import jax
x = jax.random.normal(jax.random.key(0), (8192, 8192))
python
jax.debug.visualize_array_sharding(x)

Here, we're using the jax.debug.visualize_array_sharding function to show where the value x is stored in memory. All of x is stored on a single device, so the visualization is pretty boring!

But we can shard x across multiple devices by using jax.device_put and a Sharding object. First, we use jax.make_mesh to build a Mesh, which arranges the devices into a grid with named axes (backed by a numpy.ndarray of Devices) and takes hardware topology into account for the device order:

python
from jax.sharding import Mesh, PartitionSpec, NamedSharding

P = PartitionSpec

mesh = jax.make_mesh((4, 2), ('a', 'b'))
y = jax.device_put(x, NamedSharding(mesh, P('a', 'b')))
jax.debug.visualize_array_sharding(y)
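Besides visualizing it, we can inspect the layout programmatically: every jax.Array exposes `addressable_shards`, each recording a device, the slice of the global array it covers, and its local data. The sketch below builds the mesh with `Mesh(...)` from whatever devices are present instead of `jax.make_mesh` (an assumption so it runs on any device count) and uses a small hypothetical array.

```python
import os

# Assumption: on a CPU-only machine, expose 8 virtual devices (only effective
# if set before JAX initializes; an existing XLA_FLAGS value is left alone).
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Arrange whatever devices exist into a 2D mesh (4x2 when there are 8 devices).
devices = np.array(jax.devices())
n = devices.size
mesh = Mesh(devices.reshape((n // 2, 2) if n % 2 == 0 else (n, 1)), ('a', 'b'))

# A small array whose dimensions divide evenly over both mesh axes.
x = jnp.arange(16 * n * n, dtype=jnp.float32).reshape(4 * n, 4 * n)
y = jax.device_put(x, NamedSharding(mesh, P('a', 'b')))

# Each shard records which device holds it and which slice of the global array it is.
for shard in y.addressable_shards:
    print(shard.device, shard.index, shard.data.shape)

# Writing every shard back at its index reassembles the original value.
rebuilt = np.zeros(x.shape, dtype=np.float32)
for shard in y.addressable_shards:
    rebuilt[shard.index] = np.asarray(shard.data)
```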

We can define a helper function to make things simpler:

python
default_mesh = jax.make_mesh((4, 2), ('a', 'b'))

def mesh_sharding(
    pspec: PartitionSpec, mesh: Optional[Mesh] = None,
  ) -> NamedSharding:
  if mesh is None:
    mesh = default_mesh
  return NamedSharding(mesh, pspec)
python
y = jax.device_put(x, mesh_sharding(P('a', 'b')))
jax.debug.visualize_array_sharding(y)

Here, we use P('a', 'b') to express that the first and second axes of x should be sharded over the device mesh axes 'a' and 'b', respectively. We can easily switch to P('b', 'a') to shard the axes of x over different devices:

python
y = jax.device_put(x, mesh_sharding(P('b', 'a')))
jax.debug.visualize_array_sharding(y)
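Swapping the mesh axis names changes the size of each device's block. `Sharding.shard_shape` reports that block shape for a given global shape without allocating any data. This sketch assumes the same device-count-agnostic mesh construction as above and a hypothetical global shape that divides evenly.

```python
import os

# Assumption: CPU-only host; request 8 virtual devices before JAX starts.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
n = devices.size
mesh = Mesh(devices.reshape((n // 2, 2) if n % 2 == 0 else (n, 1)), ('a', 'b'))

global_shape = (8 * n, 8 * n)  # hypothetical size, divisible by every mesh axis

# With a 4x2 mesh, P('a', 'b') cuts rows 4 ways and columns 2 ways;
# P('b', 'a') does the opposite.
for spec in (P('a', 'b'), P('b', 'a')):
    print(spec, NamedSharding(mesh, spec).shard_shape(global_shape))
```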
python
# This `None` means that `x` is not sharded on its second dimension,
# and since the Mesh axis name 'b' is not mentioned, shards are
# replicated across it.
y = jax.device_put(x, mesh_sharding(P('a', None)))
jax.debug.visualize_array_sharding(y)

Here, because P('a', None) doesn't mention the Mesh axis name 'b', we get replication over the axis 'b'. The None here is just acting as a placeholder to line up against the second axis of the value x, without expressing sharding over any mesh axis. (As a shorthand, trailing Nones can be omitted, so that P('a', None) means the same thing as P('a'). But it doesn't hurt to be explicit!)
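Both claims can be checked directly. `is_equivalent_to` compares two shardings for a given rank, and `devices_indices_map` maps each device to the slice it holds, so counting devices per slice exposes the replication over 'b'. A sketch under the same mesh assumptions as above:

```python
import os
from collections import Counter

# Assumption: CPU-only host; request 8 virtual devices before JAX starts.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
n = devices.size
mesh = Mesh(devices.reshape((n // 2, 2) if n % 2 == 0 else (n, 1)), ('a', 'b'))
global_shape = (8 * n, 8 * n)

explicit = NamedSharding(mesh, P('a', None))
short = NamedSharding(mesh, P('a'))

# Dropping the trailing None describes the same layout.
print(explicit.is_equivalent_to(short, len(global_shape)))

# Count how many devices hold each distinct slice: there is one row block per
# position along 'a', and each block is repeated across the devices along 'b'.
index_map = explicit.devices_indices_map(global_shape)
copies = Counter(
    tuple((s.start, s.stop) for s in index) for index in index_map.values()
)
print(copies)
```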

To shard only over the second axis of x, we can use a None placeholder in the PartitionSpec:

python
y = jax.device_put(x, mesh_sharding(P(None, 'b')))
jax.debug.visualize_array_sharding(y)
python
y = jax.device_put(x, mesh_sharding(P(None, 'a')))
jax.debug.visualize_array_sharding(y)

For a fixed mesh, we can even partition one logical axis of x over multiple device mesh axes:

python
y = jax.device_put(x, mesh_sharding(P(('a', 'b'), None)))
jax.debug.visualize_array_sharding(y)
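With P(('a', 'b'), None) the rows are split over the product of both mesh axes, so every device gets a distinct row block and the columns stay whole. A sketch, again assuming the device-count-agnostic mesh and a hypothetical global shape:

```python
import os

# Assumption: CPU-only host; request 8 virtual devices before JAX starts.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
n = devices.size
mesh = Mesh(devices.reshape((n // 2, 2) if n % 2 == 0 else (n, 1)), ('a', 'b'))
global_shape = (8 * n, 8 * n)

s = NamedSharding(mesh, P(('a', 'b'), None))
print(s.shard_shape(global_shape))  # rows split a*b = n ways, columns whole

# Collect the starting row of each device's block: one distinct block per device.
row_blocks = {index[0].start for index in s.devices_indices_map(global_shape).values()}
print(len(row_blocks))
```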

Using NamedSharding makes it easy to define a device mesh once and give its axes names, then just refer to those names in PartitionSpecs for each device_put as needed.

Computation follows data sharding and is automatically parallelized

With sharded input data, the compiler can give us parallel computation. In particular, functions decorated with jax.jit can operate over sharded arrays without copying data onto a single device. Instead, computation follows sharding: based on the sharding of the input data, the compiler decides shardings for intermediates and output values, and parallelizes their evaluation, even inserting communication operations as necessary.

For example, the simplest computation is an elementwise one:

python
mesh = jax.make_mesh((4, 2), ('a', 'b'))
python
x = jax.device_put(x, NamedSharding(mesh, P('a', 'b')))
print('input sharding:')
jax.debug.visualize_array_sharding(x)

y = jnp.sin(x)
print('output sharding:')
jax.debug.visualize_array_sharding(y)

Here for the elementwise operation jnp.sin the compiler chose the output sharding to be the same as the input. Moreover, the compiler automatically parallelized the computation, so that each device computed its output shard from its input shard in parallel.

In other words, even though we wrote the jnp.sin computation as if a single machine were to execute it, the compiler splits up the computation for us and executes it on multiple devices.
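We can confirm that the output of an elementwise op keeps its input's layout by comparing shardings with `is_equivalent_to`, and that the values match a plain NumPy evaluation. This sketch uses a small hypothetical array and the same mesh assumptions as earlier.

```python
import os

# Assumption: CPU-only host; request 8 virtual devices before JAX starts.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
n = devices.size
mesh = Mesh(devices.reshape((n // 2, 2) if n % 2 == 0 else (n, 1)), ('a', 'b'))

x = jax.device_put(
    jnp.linspace(0.0, 1.0, 16 * n * n).reshape(4 * n, 4 * n),
    NamedSharding(mesh, P('a', 'b')),
)
y = jnp.sin(x)

# The elementwise result lives on the same devices, in the same layout, as its input.
same_layout = y.sharding.is_equivalent_to(x.sharding, x.ndim)
print(same_layout)
```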

We can do the same for more than just elementwise operations too. Consider a matrix multiplication with sharded inputs:

python
y = jax.device_put(x, NamedSharding(mesh, P('a', None)))
z = jax.device_put(x, NamedSharding(mesh, P(None, 'b')))
print('lhs sharding:')
jax.debug.visualize_array_sharding(y)
print('rhs sharding:')
jax.debug.visualize_array_sharding(z)

w = jnp.dot(y, z)
print('out sharding:')
jax.debug.visualize_array_sharding(w)

Here the compiler chose the output sharding so that it could maximally parallelize the computation: without needing communication, each device already has the input shards it needs to compute its output shard.
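A small-scale version of this experiment: a row-sharded left operand times a column-sharded right operand. The sketch prints whether the compiler's chosen output layout is equivalent to P('a', 'b') (it prints rather than asserts, since the layout is the compiler's decision) and checks the values against NumPy. Small integer entries are an assumption chosen so the float32 result is exact.

```python
import os

# Assumption: CPU-only host; request 8 virtual devices before JAX starts.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
n = devices.size
mesh = Mesh(devices.reshape((n // 2, 2) if n % 2 == 0 else (n, 1)), ('a', 'b'))

x = (jnp.arange(16 * n * n) % 7).astype(jnp.float32).reshape(4 * n, 4 * n)
lhs = jax.device_put(x, NamedSharding(mesh, P('a', None)))  # row blocks
rhs = jax.device_put(x, NamedSharding(mesh, P(None, 'b')))  # column blocks
out = jnp.dot(lhs, rhs)

# Is the output laid out as P('a', 'b'), as described in the text?
print(out.sharding.is_equivalent_to(NamedSharding(mesh, P('a', 'b')), out.ndim))

# Values match a plain NumPy matmul regardless of how the work was split.
expected = np.asarray(x) @ np.asarray(x)
```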

How can we be sure it's actually running in parallel? We can do a simple timing experiment:

python
x_single = jax.device_put(x, jax.devices()[0])
jax.debug.visualize_array_sharding(x_single)
python
np.allclose(jnp.dot(x_single, x_single),
            jnp.dot(y, z))
python
%timeit -n 5 -r 5 jnp.dot(x_single, x_single).block_until_ready()
python
%timeit -n 5 -r 5 jnp.dot(y, z).block_until_ready()

Even copying a sharded Array produces a result with the sharding of the input:

python
w_copy = jnp.copy(w)
jax.debug.visualize_array_sharding(w_copy)

So computation follows data placement: when we explicitly shard data with jax.device_put, and apply functions to that data, the compiler attempts to parallelize the computation and decide the output sharding. This policy for sharded data is a generalization of JAX's policy of following explicit device placement.

When explicit shardings disagree, JAX errors

But what if two arguments to a computation are explicitly placed on different sets of devices, or with incompatible device orders? In these ambiguous cases, an error is raised:

python
import textwrap
from termcolor import colored

def print_exception(e):
  name = colored(f'{type(e).__name__}', 'red', force_color=True)
  print(textwrap.fill(f'{name}: {str(e)}'))
python
sharding1 = NamedSharding(Mesh(jax.devices()[:4], 'x'), P('x'))
sharding2 = NamedSharding(Mesh(jax.devices()[4:], 'x'), P('x'))

y = jax.device_put(x, sharding1)
z = jax.device_put(x, sharding2)
try: y + z
except ValueError as e: print_exception(e)
python
devices = jax.devices()
permuted_devices = [devices[i] for i in [0, 1, 2, 3, 6, 7, 4, 5]]

sharding1 = NamedSharding(Mesh(devices, 'x'), P('x'))
sharding2 = NamedSharding(Mesh(permuted_devices, 'x'), P('x'))

y = jax.device_put(x, sharding1)
z = jax.device_put(x, sharding2)
try: y + z
except ValueError as e: print_exception(e)

We say arrays that have been explicitly placed or sharded with jax.device_put are committed to their device(s), and so won't be automatically moved. See the device placement FAQ for more information.
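Because committed arrays are never moved implicitly, one way to combine them is to choose a placement yourself: reshard one operand onto the other's sharding with jax.device_put before the operation. This sketch splits the available devices into two halves (an assumption so it runs on any device count; with a single device both halves coincide) and uses a small hypothetical array.

```python
import os

# Assumption: CPU-only host; request 8 virtual devices before JAX starts.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()
half = max(len(devices) // 2, 1)
sharding1 = NamedSharding(Mesh(np.array(devices[:half]), ('x',)), P('x'))
sharding2 = NamedSharding(Mesh(np.array(devices[-half:]), ('x',)), P('x'))

x = jnp.arange(8 * half, dtype=jnp.float32)
y = jax.device_put(x, sharding1)
z = jax.device_put(x, sharding2)

# y + z can raise when the device sets differ; an explicit device_put picks one
# placement for both operands, so the addition is unambiguous.
z_on_y = jax.device_put(z, y.sharding)
w = y + z_on_y
print(w.sharding.is_equivalent_to(y.sharding, w.ndim))
```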

When arrays are not explicitly placed or sharded with jax.device_put, they are placed uncommitted on the default device. Unlike committed arrays, uncommitted arrays can be moved and resharded automatically: that is, uncommitted arrays can be arguments to a computation even if other arguments are explicitly placed on different devices.

For example, the outputs of jnp.zeros, jnp.arange, and jnp.array are uncommitted:

python
y = jax.device_put(x, sharding1)
y + jnp.ones_like(y)
y + jnp.arange(y.size).reshape(y.shape)
print('no error!')

Constraining shardings of intermediates in jitted code

While the compiler will attempt to decide how a function's intermediate values and outputs should be sharded, we can also give it hints using jax.lax.with_sharding_constraint. Using jax.lax.with_sharding_constraint is much like jax.device_put, except we use it inside staged-out (i.e. jit-decorated) functions:

python
mesh = jax.make_mesh((4, 2), ('x', 'y'))
python
x = jax.random.normal(jax.random.key(0), (8192, 8192))
x = jax.device_put(x, NamedSharding(mesh, P('x', 'y')))
python
@jax.jit
def f(x):
  x = x + 1
  y = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P('y', 'x')))
  return y
python
jax.debug.visualize_array_sharding(x)
y = f(x)
jax.debug.visualize_array_sharding(y)
python
@jax.jit
def f(x):
  x = x + 1
  y = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P()))
  return y
python
jax.debug.visualize_array_sharding(x)
y = f(x)
jax.debug.visualize_array_sharding(y)

By adding with_sharding_constraint, we've constrained the sharding of the output. In addition to respecting the annotation on a particular intermediate, the compiler will use annotations to decide shardings for other values.

It's often good practice to annotate the outputs of computations, for example based on how the values are ultimately consumed.
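The effect of the P() constraint can be checked without visualization: `Sharding.is_fully_replicated` reports whether every device holds the whole value. A sketch with a small hypothetical input and the same device-count-agnostic mesh, here with axis names 'x' and 'y' as in the text:

```python
import os

# Assumption: CPU-only host; request 8 virtual devices before JAX starts.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
n = devices.size
mesh = Mesh(devices.reshape((n // 2, 2) if n % 2 == 0 else (n, 1)), ('x', 'y'))

x = jax.device_put(jnp.ones((4 * n, 4 * n)), NamedSharding(mesh, P('x', 'y')))

@jax.jit
def f(x):
    x = x + 1
    # Constrain the result to be fully replicated: every device holds all of it.
    return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P()))

y = f(x)
print(x.sharding.is_fully_replicated, y.sharding.is_fully_replicated)
```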

Examples: neural networks

⚠️ WARNING: The following is meant to be a simple demonstration of automatic sharding propagation with jax.Array, but it may not reflect best practices for real examples. For instance, real examples may require more use of with_sharding_constraint.

We can use jax.device_put and jax.jit's computation-follows-sharding features to parallelize computation in neural networks. Here are some simple examples, based on this basic neural network:

python
import jax
import jax.numpy as jnp
python
def predict(params, inputs):
  for W, b in params:
    outputs = jnp.dot(inputs, W) + b
    inputs = jnp.maximum(outputs, 0)
  return outputs

def loss(params, batch):
  inputs, targets = batch
  predictions = predict(params, inputs)
  return jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
python
loss_jit = jax.jit(loss)
gradfun = jax.jit(jax.grad(loss))
python
def init_layer(key, n_in, n_out):
  k1, k2 = jax.random.split(key)
  W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)
  b = jax.random.normal(k2, (n_out,))
  return W, b

def init_model(key, layer_sizes, batch_size):
  key, *keys = jax.random.split(key, len(layer_sizes))
  params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))

  key, *keys = jax.random.split(key, 3)
  inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0]))
  targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1]))

  return params, (inputs, targets)

layer_sizes = [784, 8192, 8192, 8192, 10]
batch_size = 8192

params, batch = init_model(jax.random.key(0), layer_sizes, batch_size)

8-way batch data parallelism

python
mesh = jax.make_mesh((8,), ('batch',))
python
sharding = NamedSharding(mesh, P('batch'))
replicated_sharding = NamedSharding(mesh, P())
python
batch = jax.device_put(batch, sharding)
params = jax.device_put(params, replicated_sharding)
python
loss_jit(params, batch)
python
step_size = 1e-5

for _ in range(30):
  grads = gradfun(params, batch)
  params = [(W - step_size * dW, b - step_size * db)
            for (W, b), (dW, db) in zip(params, grads)]

print(loss_jit(params, batch))
python
%timeit -n 5 -r 5 gradfun(params, batch)[0][0].block_until_ready()
python
batch_single = jax.device_put(batch, jax.devices()[0])
params_single = jax.device_put(params, jax.devices()[0])
python
%timeit -n 5 -r 5 gradfun(params_single, batch_single)[0][0].block_until_ready()
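Beyond timing, it is worth checking that data parallelism leaves the math unchanged: gradients computed from a batch sharded over 'batch' with replicated parameters should match a single-device computation. The sketch below reuses the predict/loss structure from above with a hypothetical tiny model, builds a 1D mesh with `Mesh(...)` over all available devices, and requests `Precision.HIGHEST` so that differences in reduction order stay within a tight tolerance.

```python
import os

# Assumption: CPU-only host; request 8 virtual devices before JAX starts.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def predict(params, inputs):
    for W, b in params:
        outputs = jnp.dot(inputs, W, precision=jax.lax.Precision.HIGHEST) + b
        inputs = jnp.maximum(outputs, 0)
    return outputs

def loss(params, batch):
    inputs, targets = batch
    predictions = predict(params, inputs)
    return jnp.mean(jnp.sum((predictions - targets) ** 2, axis=-1))

gradfun = jax.jit(jax.grad(loss))

n = len(jax.devices())
k_w1, k_w2, k_x, k_y = jax.random.split(jax.random.key(0), 4)
# Hypothetical tiny sizes so the check runs quickly anywhere.
params = [
    (jax.random.normal(k_w1, (16, 32)) / 4.0, jnp.zeros(32)),
    (jax.random.normal(k_w2, (32, 8)) / jnp.sqrt(32.0), jnp.zeros(8)),
]
batch = (jax.random.normal(k_x, (8 * n, 16)), jax.random.normal(k_y, (8 * n, 8)))

mesh = Mesh(np.array(jax.devices()), ('batch',))
# Data parallel: each device sees a slice of the batch; the compiler has to
# combine the per-device contributions to the replicated parameters' gradients.
grads_dp = gradfun(
    jax.device_put(params, NamedSharding(mesh, P())),
    jax.device_put(batch, NamedSharding(mesh, P('batch'))),
)
grads_single = gradfun(
    jax.device_put(params, jax.devices()[0]),
    jax.device_put(batch, jax.devices()[0]),
)
print(grads_dp[0][0].sharding)
```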

4-way batch data parallelism and 2-way model tensor parallelism

python
mesh = jax.make_mesh((4, 2), ('batch', 'model'))
python
batch = jax.device_put(batch, NamedSharding(mesh, P('batch', None)))
jax.debug.visualize_array_sharding(batch[0])
jax.debug.visualize_array_sharding(batch[1])
python
replicated_sharding = NamedSharding(mesh, P())
python
(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params

W1 = jax.device_put(W1, replicated_sharding)
b1 = jax.device_put(b1, replicated_sharding)

W2 = jax.device_put(W2, NamedSharding(mesh, P(None, 'model')))
b2 = jax.device_put(b2, NamedSharding(mesh, P('model')))

W3 = jax.device_put(W3, NamedSharding(mesh, P('model', None)))
b3 = jax.device_put(b3, replicated_sharding)

W4 = jax.device_put(W4, replicated_sharding)
b4 = jax.device_put(b4, replicated_sharding)

params = (W1, b1), (W2, b2), (W3, b3), (W4, b4)
python
jax.debug.visualize_array_sharding(W2)
python
jax.debug.visualize_array_sharding(W3)
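The two visualizations show complementary splits: W2 is cut by columns (its output units) and W3 by rows (its input units), so the hidden units W2 produces on a device are exactly the ones W3 consumes there. Reading the per-device block shapes makes this concrete; the sketch assumes a device-count-agnostic ('batch', 'model') mesh and a hypothetical hidden size.

```python
import os

# Assumption: CPU-only host; request 8 virtual devices before JAX starts.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
n = devices.size
mesh = Mesh(devices.reshape((n // 2, 2) if n % 2 == 0 else (n, 1)), ('batch', 'model'))
m = mesh.shape['model']
d = 8 * m  # hypothetical hidden size, divisible by the 'model' axis

W2 = jax.device_put(jnp.ones((d, d)), NamedSharding(mesh, P(None, 'model')))
W3 = jax.device_put(jnp.ones((d, d)), NamedSharding(mesh, P('model', None)))

print("W2 per-device block:", W2.addressable_shards[0].data.shape)
print("W3 per-device block:", W3.addressable_shards[0].data.shape)
```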
python
print(loss_jit(params, batch))
python
step_size = 1e-5

for _ in range(30):
  grads = gradfun(params, batch)
  params = [(W - step_size * dW, b - step_size * db)
            for (W, b), (dW, db) in zip(params, grads)]
python
print(loss_jit(params, batch))
python
(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params
jax.debug.visualize_array_sharding(W2)
jax.debug.visualize_array_sharding(W3)
python
%timeit -n 10 -r 10 gradfun(params, batch)[0][0].block_until_ready()
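As a sanity check on the tensor-parallel layout, the sketch below runs a two-matmul block with W2 column-sharded and W3 row-sharded, then compares it against the same jitted function on unsharded inputs. It is a hypothetical miniature of the middle layers (no biases), assumes the same device-count-agnostic mesh, and requests `Precision.HIGHEST` so that different reduction orders agree closely.

```python
import os

# Assumption: CPU-only host; request 8 virtual devices before JAX starts.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
n = devices.size
mesh = Mesh(devices.reshape((n // 2, 2) if n % 2 == 0 else (n, 1)), ('batch', 'model'))
nb, m = mesh.shape['batch'], mesh.shape['model']
HIGHEST = jax.lax.Precision.HIGHEST

@jax.jit
def mlp_block(x, W2, W3):
    # With W2 column-sharded, each device computes its own slice of hidden units.
    h = jnp.maximum(jnp.dot(x, W2, precision=HIGHEST), 0)
    # W3 is row-sharded over the same hidden units, so each device holds a
    # partial sum that the compiler has to combine across 'model'.
    return jnp.dot(h, W3, precision=HIGHEST)

k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
x = jax.random.normal(k1, (4 * nb, 8 * m))
W2 = jax.random.normal(k2, (8 * m, 8 * m))
W3 = jax.random.normal(k3, (8 * m, 8 * m))

out_tp = mlp_block(
    jax.device_put(x, NamedSharding(mesh, P('batch', None))),
    jax.device_put(W2, NamedSharding(mesh, P(None, 'model'))),
    jax.device_put(W3, NamedSharding(mesh, P('model', None))),
)
out_ref = mlp_block(x, W2, W3)  # uncommitted inputs: runs on the default device
print(out_tp.sharding)
```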