The intention of this chapter is not to show code examples and explain APIs, for which there are many tutorials, but to provide excellent visuals that explain how the various types of communication patterns work.
Point-to-point communication is the simplest type of communication: there is always a single sender and a single receiver.
For example, Pipeline Parallelism performs a point-to-point communication in which the activations from the current vertical stage are sent to the next stage. The GPU holding the current stage performs send, and the GPU holding the next stage performs recv.
PyTorch has send and recv for blocking point-to-point communication, and isend and irecv for non-blocking point-to-point communication.
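To make the blocking send/recv pairing concrete, here is a minimal single-machine sketch of two pipeline stages. It is not torch.distributed: two threads stand in for the two GPUs, and a queue.Queue stands in for the link, with put() playing the role of send and get() the role of a blocking recv. The stage functions and their toy arithmetic are hypothetical.

```python
import queue
import threading


def stage0(link: queue.Queue, batch: list) -> None:
    # Current pipeline stage: compute toy activations and hand them to the next stage.
    activations = [x * 2.0 for x in batch]
    link.put(activations)  # stands in for send()


def stage1(link: queue.Queue, results: list) -> None:
    # Next pipeline stage: wait until the activations arrive, then continue the forward pass.
    activations = link.get()  # stands in for a blocking recv()
    results.append([a + 1.0 for a in activations])


def run_two_stage_pipeline(batch: list) -> list:
    link: queue.Queue = queue.Queue()  # the single-sender, single-receiver channel
    results: list = []
    receiver = threading.Thread(target=stage1, args=(link, results))
    receiver.start()  # the receiver may start first; it simply blocks in recv
    stage0(link, batch)
    receiver.join()
    return results[0]


print(run_two_stage_pipeline([1.0, 2.0]))  # [3.0, 5.0]
```

With isend and irecv the calls would instead return immediately with a request object to wait on later; this sketch only models the blocking pair.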
Collective communications involve either multiple senders and a single receiver, a single sender and multiple receivers, or multiple senders and multiple receivers.
In PyTorch, each process is typically tied to a single accelerator, so accelerators perform collective communications via process groups. The same process may belong to multiple process groups.
PyTorch API example:
dist.broadcast(tensor, src, group): Copies tensor from src to all other processes.
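The effect of broadcast can be sketched without any distributed setup. In this hypothetical simulation, a Python list indexed by rank holds each process's tensor, and the collective is a plain function over that list. It models only the data movement, not the torch.distributed API or any real transport.

```python
def broadcast(ranks: list, src: int) -> None:
    """Simulated broadcast (hypothetical helper): ranks[r] is the tensor on process r.

    Every non-src rank ends up with its own copy of ranks[src], mirroring how
    dist.broadcast overwrites `tensor` on the receiving processes.
    """
    data = list(ranks[src])
    for r in range(len(ranks)):
        if r != src:
            ranks[r] = list(data)  # each receiver gets an independent copy


ranks = [[1.0, 2.0], [0.0, 0.0], [0.0, 0.0]]
broadcast(ranks, src=0)
print(ranks)  # [[1.0, 2.0], [1.0, 2.0], [1.0, 2.0]]
```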
PyTorch API example:
dist.gather(tensor, gather_list, dst, group): Copies tensor from all processes into gather_list on dst.
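A minimal simulation of gather, assuming each process's tensor is modeled as an entry of a Python list indexed by rank. The function name is hypothetical; as in PyTorch, only the destination rank ends up with a populated gather_list.

```python
def gather(ranks: list, dst: int) -> list:
    """Simulated gather (hypothetical helper): ranks[r] is the tensor on process r.

    Returns each rank's gather_list: on dst it holds a copy of every rank's
    tensor in rank order; on all other ranks it is None, since only dst receives data.
    """
    collected = [list(t) for t in ranks]
    return [collected if r == dst else None for r in range(len(ranks))]


gather_lists = gather([[1], [2], [3]], dst=1)
print(gather_lists)  # [None, [[1], [2], [3]], None]
```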
The all-gather collective is used, for example, in ZeRO (Deepspeed and FSDP) to gather the sharded model weights before the forward and backward calls.
PyTorch API example:
dist.all_gather(tensor_list, tensor, group): Copies tensor from all processes to tensor_list, on all processes.
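The ZeRO use case can be illustrated with a toy all-gather simulation. It assumes each rank owns one contiguous shard of a flat weight vector and that ranks are modeled as a Python list; this shard layout is a simplification for illustration, not how Deepspeed or FSDP actually lay out their shards.

```python
def all_gather(ranks: list) -> list:
    """Simulated all-gather (hypothetical helper): ranks[r] is the tensor on process r.

    Returns one tensor_list per rank, each holding a copy of every rank's
    tensor in rank order.
    """
    return [[list(t) for t in ranks] for _ in ranks]


# Toy sharding assumption: each of 3 ranks owns one shard of a 6-element weight vector.
shards = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
tensor_lists = all_gather(shards)
# Any rank can now rebuild the full weights by concatenating the shards in rank order.
full_weights = [w for shard in tensor_lists[0] for w in shard]
print(full_weights)  # [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
```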
PyTorch API example:
dist.reduce(tensor, dst, op, group): Applies op element-wise across the tensors of all processes and stores the result in tensor on dst.
PyTorch supports multiple reduction operations, such as avg, sum, product, min, max, band, bor, and bxor; the full list is in the torch.distributed.ReduceOp documentation.
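The following sketch simulates reduce with a few of the operations listed above. Ranks are modeled as a Python list of equally sized lists, and the OPS table and function are hypothetical; in this simulation the non-destination tensors are simply left untouched.

```python
import operator
from functools import reduce as fold

# Hypothetical table mapping some reduction op names to element-wise Python callables.
OPS = {
    "sum": operator.add,
    "product": operator.mul,
    "min": min,
    "max": max,
    "band": operator.and_,
    "bor": operator.or_,
    "bxor": operator.xor,
}


def reduce(ranks: list, dst: int, op: str = "sum") -> None:
    """Simulated reduce (hypothetical helper): ranks[r] is the tensor on process r.

    Combines the tensors element-wise with `op` and stores the result on dst only.
    """
    columns = list(zip(*ranks))  # the i-th element from every rank
    if op == "avg":
        result = [sum(col) / len(ranks) for col in columns]
    else:
        result = [fold(OPS[op], col) for col in columns]
    ranks[dst] = result


ranks = [[1, 6], [3, 4], [5, 2]]
reduce(ranks, dst=0, op="max")
print(ranks)  # [[5, 6], [3, 4], [5, 2]]
```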
The all-reduce collective is used, for example, in DDP to reduce gradients across all participating ranks.
PyTorch API example:
dist.all_reduce(tensor, op, group): Same as reduce, but the result is stored in all processes.
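A toy simulation of the DDP use case, assuming two ranks with different local gradients, modeled as a Python list indexed by rank. Summing with all-reduce and then dividing by the world size is shown here as one way to obtain averaged gradients; it is an illustration, not a description of DDP's internals.

```python
def all_reduce_sum(ranks: list) -> None:
    """Simulated sum all-reduce (hypothetical helper): ranks[r] is the tensor on process r.

    Every rank ends up with its own copy of the element-wise sum.
    """
    total = [sum(col) for col in zip(*ranks)]
    for r in range(len(ranks)):
        ranks[r] = list(total)


# Toy assumption: two ranks computed different local gradients for the same 2 parameters.
grads = [[0.5, 1.0], [1.5, 2.0]]
all_reduce_sum(grads)
world_size = len(grads)
# Dividing the summed gradients by the world size gives every rank the same averaged gradient.
avg_grads = [[g / world_size for g in local] for local in grads]
print(avg_grads)  # [[1.0, 1.5], [1.0, 1.5]]
```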
PyTorch API example:
dist.scatter(tensor, scatter_list, src, group): Copies the i-th tensor scatter_list[i] on src to the i-th process.
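Scatter can be sketched as a single-process simulation, assuming the src rank holds scatter_list and each rank's received tensor is modeled as an entry of a Python list. The batch-splitting usage is a hypothetical example, not a claim about any specific framework.

```python
def scatter(scatter_list: list, world_size: int) -> list:
    """Simulated scatter (hypothetical helper): scatter_list lives on the src rank.

    Returns the tensor each rank receives: rank i gets a copy of scatter_list[i].
    """
    if len(scatter_list) != world_size:
        raise ValueError("scatter_list needs exactly one tensor per rank")
    return [list(t) for t in scatter_list]


# Toy usage: the src rank splits a batch of 6 samples into 3 equal chunks, one per rank.
batch = [0, 1, 2, 3, 4, 5]
chunks = [batch[i * 2:(i + 1) * 2] for i in range(3)]
received = scatter(chunks, world_size=3)
print(received)  # [[0, 1], [2, 3], [4, 5]]
```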
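The reduce-scatter collective is used, for example, in ZeRO (Deepspeed and FSDP) to efficiently reduce gradients across all participating ranks: each rank receives only the reduced shard it owns. An all-reduce is equivalent to a reduce-scatter followed by an all-gather, so a reduce-scatter moves roughly half the data and is about 2x more efficient than an all-reduce.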
PyTorch API example:
dist.reduce_scatter(output, input_list, op, group, async_op): Reduces, then scatters a list of tensors to all processes in a group.
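A toy simulation of a sum reduce-scatter, assuming each rank's gradient is already split into one chunk per rank and that ranks are modeled as a Python list. It also checks that concatenating the reduced shards, which is what an all-gather would deliver, reproduces the full sum all-reduce. The helper name and the numbers are hypothetical.

```python
def reduce_scatter_sum(input_lists: list) -> list:
    """Simulated sum reduce-scatter (hypothetical helper).

    input_lists[r][i] is the i-th chunk held by rank r. Rank i receives the
    element-wise sum over all ranks r of input_lists[r][i].
    """
    world_size = len(input_lists)
    outputs = []
    for i in range(world_size):
        chunk_i_from_every_rank = [input_lists[r][i] for r in range(world_size)]
        outputs.append([sum(vals) for vals in zip(*chunk_i_from_every_rank)])
    return outputs


# Toy gradients on 2 ranks, each split into 2 chunks (one chunk per rank).
grads = [
    [[1.0, 2.0], [3.0, 4.0]],      # rank 0
    [[10.0, 20.0], [30.0, 40.0]],  # rank 1
]
shards = reduce_scatter_sum(grads)
print(shards)  # [[11.0, 22.0], [33.0, 44.0]]

# Concatenating the reduced shards (i.e. an all-gather) reproduces a full sum all-reduce.
full = [x for shard in shards for x in shard]
print(full)  # [11.0, 22.0, 33.0, 44.0]
```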
The all-to-all collective is used, for example, in Deepspeed Sequence Parallelism for attention computation and in MoE Expert Parallelism.
PyTorch API example:
dist.all_to_all(output_tensor_list, input_tensor_list, group): Scatters a list of input tensors to all processes in a group and returns the gathered list of tensors in the output list.
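All-to-all is essentially a transpose of a "who sends what to whom" matrix. This single-process sketch assumes input_lists[r][j] is what rank r sends to rank j; the MoE-flavored usage assumes one expert per rank and already-routed tokens, which is a simplification for illustration only.

```python
def all_to_all(input_lists: list) -> list:
    """Simulated all-to-all (hypothetical helper).

    input_lists[r][j] is the tensor rank r sends to rank j. Returns output
    lists where output_lists[j][r] is what rank j received from rank r,
    i.e. a transpose of the send matrix.
    """
    world_size = len(input_lists)
    return [
        [list(input_lists[r][j]) for r in range(world_size)]
        for j in range(world_size)
    ]


# MoE-like toy: expert 0 lives on rank 0 and expert 1 on rank 1; each rank has
# already routed its tokens, and all-to-all delivers every token to its expert's rank.
routed = [
    [["r0-t0"], ["r0-t1", "r0-t2"]],  # rank 0: tokens for expert 0, tokens for expert 1
    [["r1-t0", "r1-t1"], []],         # rank 1: tokens for expert 0, none for expert 1
]
received = all_to_all(routed)
print(received)  # [[['r0-t0'], ['r1-t0', 'r1-t1']], [['r0-t1', 'r0-t2'], []]]
```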
Collective communications can have a variety of implementations, and communication libraries like NCCL may switch between different algorithms (e.g., ring or tree) depending on internal heuristics, unless overridden by the user.
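Given N bytes to broadcast, a bandwidth of B per link, and k GPUs, a naive broadcast sends the full message to one GPU at each step, which takes N/B. The total time to broadcast to k GPUs is therefore (k-1)*N/B.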
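The naive cost formula is easy to evaluate directly. This sketch assumes the naive scheme means the source pushes one full copy of the message per step; the function name and the example numbers are hypothetical.

```python
def naive_broadcast_time(n_bytes: float, bandwidth: float, k: int) -> float:
    """(k-1)*N/B: the source sends the full N-byte message to the other k-1 GPUs
    one after another (assumed naive scheme), and each transfer takes N/B."""
    return (k - 1) * n_bytes / bandwidth


# Hypothetical numbers: a 1 GiB message, 1 GiB/s links, 8 GPUs.
GIB = 2**30
print(naive_broadcast_time(GIB, GIB, 8))  # 7.0 (seconds)
```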
A ring-based broadcast is performed as follows:
This algorithm splits the N bytes into S messages.
At each step, each link carries one message of N/S bytes, which takes N/(S*B), S times less than the naive algorithm needs per step.
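The source needs S steps to push out all S messages, and the last message then needs k-2 more hops to reach the last GPU, so the total time to broadcast N bytes to k GPUs is:
S*N/(S*B) + (k-2)*N/(S*B) = N*(S+k-2)/(S*B)
If the split messages are very small, so that S >> k, then S+k-2 is approximately S, and the total time is about N/B.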
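The S+k-2 step count can be checked with a small step-by-step simulation. It assumes a unidirectional ring where GPU 0 is the source, messages travel in order, and in every step each GPU can forward one message to its right neighbor. The function names and the numbers are hypothetical.

```python
def ring_broadcast_steps(k: int, s: int) -> int:
    """Count the steps a pipelined ring broadcast of s messages needs on k GPUs.

    have[i] is how many messages GPU i holds (they arrive in order). In each
    step, GPU i forwards its next message to GPU i+1 if it has one GPU i+1 lacks.
    """
    have = [s] + [0] * (k - 1)  # GPU 0 is the source and starts with everything
    steps = 0
    while have[-1] < s:
        new_have = list(have)  # all sends in a step use the state from the step's start
        for i in range(k - 1):
            if have[i] > have[i + 1]:
                new_have[i + 1] += 1
        have = new_have
        steps += 1
    return steps


def ring_broadcast_time(n_bytes: float, bandwidth: float, k: int, s: int) -> float:
    # Each step moves one N/S-byte message per link, which takes N/(S*B).
    return ring_broadcast_steps(k, s) * n_bytes / (s * bandwidth)


GIB = 2**30
print(ring_broadcast_steps(8, 1024))           # 1030, i.e. S + k - 2
print(ring_broadcast_time(GIB, GIB, 8, 1024))  # 1.005859375, close to N/B = 1.0
```

With the same hypothetical 1 GiB message, 1 GiB/s links and 8 GPUs, the naive formula gives 7 seconds, while the pipelined ring approaches the single-link time N/B.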
Ring-based all-reduce works similarly to broadcast. The message is split into many small messages, and each GPU sends a small message to the next GPU in parallel with the other GPUs. All-reduce has to perform twice as many steps as broadcast because it also performs a reduction: a reduce-scatter pass followed by an all-gather pass, so the message effectively has to be sent over the wire twice.
Moreover, the whole message can first be split into chunks to make the process even more efficient. Here is the reduction of the first chunk:
Then the next chunk is processed, and so on, until all the smaller messages are reduced:
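For a hands-on view, here is a compact single-process simulation of the widely used ring all-reduce formulation, in which each rank's vector is split into k chunks and the ring runs a reduce-scatter phase followed by an all-gather phase. It is a sketch under simplifying assumptions (sum reduction, vector length divisible by k, hypothetical function name); the chunk scheduling may differ in detail from the figures and from what NCCL actually does.

```python
def ring_all_reduce(data: list) -> tuple:
    """Simulated sum all-reduce over a unidirectional ring (hypothetical helper).

    data[r] is rank r's vector. Returns (per-rank results, number of ring steps).
    """
    k = len(data)
    n = len(data[0])
    if n % k:
        raise ValueError("vector length must be divisible by the number of ranks")
    size = n // k
    # chunks[r][c] is rank r's current copy of chunk c
    chunks = [[list(v[c * size:(c + 1) * size]) for c in range(k)] for v in data]
    steps = 0
    # Phase 1, reduce-scatter: in step t, rank r sends chunk (r - t) % k to rank r+1,
    # which adds it to its own copy.
    for t in range(k - 1):
        sends = [(r, (r - t) % k, list(chunks[r][(r - t) % k])) for r in range(k)]
        for r, c, payload in sends:
            dst = (r + 1) % k
            chunks[dst][c] = [a + b for a, b in zip(chunks[dst][c], payload)]
        steps += 1
    # Now rank r holds the fully reduced chunk (r + 1) % k.
    # Phase 2, all-gather: in step t, rank r forwards chunk (r + 1 - t) % k to rank r+1.
    for t in range(k - 1):
        sends = [(r, (r + 1 - t) % k, list(chunks[r][(r + 1 - t) % k])) for r in range(k)]
        for r, c, payload in sends:
            chunks[(r + 1) % k][c] = payload
        steps += 1
    results = [[x for chunk in rank_chunks for x in chunk] for rank_chunks in chunks]
    return results, steps


data = [
    [1, 2, 3, 4, 5, 6],
    [10, 20, 30, 40, 50, 60],
    [100, 200, 300, 400, 500, 600],
]
results, steps = ring_all_reduce(data)
print(results[0])  # [111, 222, 333, 444, 555, 666]
print(steps)       # 4, i.e. 2 * (k - 1)
```

In this formulation each rank sends 2*(k-1) chunks of N/k bytes, which approaches 2*N bytes per rank as k grows, consistent with the message being sent over the wire twice.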