docs/src/examples/data_parallelism.rst
.. _data_parallelism:
MLX enables efficient data parallel distributed training through its distributed communication primitives.
.. _training_example:
In this section we will adapt an MLX training loop to support data parallel distributed training. Namely, we will average the gradients across a set of hosts before applying them to the model.
Our training loop looks like the following code snippet if we omit the model, dataset, and optimizer initialization.
.. code:: python
model = ...
optimizer = ...
dataset = ...
def step(model, x, y):
loss, grads = loss_grad_fn(model, x, y)
optimizer.update(model, grads)
return loss
for x, y in dataset:
loss = step(model, x, y)
mx.eval(loss, model.parameters())
All we have to do to average the gradients across machines is perform an
:func:all_sum and divide by the size of the :class:Group. Namely we
have to :func:mlx.utils.tree_map the gradients with following function.
.. code:: python
def all_avg(x):
return mx.distributed.all_sum(x) / mx.distributed.init().size()
Putting everything together our training loop step looks as follows with everything else remaining the same.
.. code:: python
from mlx.utils import tree_map
def all_reduce_grads(grads):
N = mx.distributed.init().size()
if N == 1:
return grads
return tree_map(
lambda x: mx.distributed.all_sum(x) / N,
grads
)
def step(model, x, y):
loss, grads = loss_grad_fn(model, x, y)
grads = all_reduce_grads(grads) # <--- This line was added
optimizer.update(model, grads)
return loss
Using nn.average_gradients
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Although the code example above works correctly; it performs one communication per gradient. It is significantly more efficient to aggregate several gradients together and perform fewer communication steps.
This is the purpose of :func:mlx.nn.average_gradients. The final code looks
almost identical to the example above:
.. code:: python
model = ...
optimizer = ...
dataset = ...
def step(model, x, y):
loss, grads = loss_grad_fn(model, x, y)
grads = mx.nn.average_gradients(grads) # <---- This line was added
optimizer.update(model, grads)
return loss
for x, y in dataset:
loss = step(model, x, y)
mx.eval(loss, model.parameters())