.agents/skills/paddle-design-distributed/references/distributed-primer.md
分布式训练的基本目标:多卡训练的数学结果与单卡训练完全等价。所有并行策略的设计都围绕这一等价性展开——通过切分计算和通信来还原单卡的语义。
分布式训练依赖 NCCL 提供的集合通信原语:
| 原语 | 语义 | 典型用途 |
|---|---|---|
| Broadcast | 一个进程的数据广播给所有进程 | 参数初始化同步 |
| AllGather | 每个进程收集所有进程的数据 | ZeRO Stage3 前向前收集完整权重 |
| AllReduce | 所有进程归约后广播结果 | DP 梯度同步 |
| Reduce | 所有进程归约到一个进程 | 聚合 loss |
| ReduceScatter | 归约后将结果切分分发 | ZeRO Stage2 梯度同步 |
关键恒等式:AllReduce = ReduceScatter + AllGather。理解这一点对分析 ZeRO 和 Sequence Parallel 的通信量至关重要。
fleet.meta_parallel)开发者显式指定切分方式,手动插入通信算子。灵活但代码量大、易出错。
import paddle.distributed.fleet as fleet
fleet.init(is_collective=True)
strategy = fleet.DistributedStrategy()
strategy.tensor_parallel = True
strategy.tensor_parallel_configs = {"tensor_parallel_degree": 2}
ProcessMesh + shard_tensor)用户标注 Tensor 的切分方式,框架自动推导通信。兼具易用性和灵活性。
import paddle.distributed as dist
mesh = dist.ProcessMesh([0, 1, 2, 3], dim_names=["x"])
x = dist.shard_tensor(x, mesh, [dist.Shard(0)]) # 沿 dim 0 切分
auto_parallel.Engine)基于静态图 IR,框架做全局优化(算子切分、通信插入、调度优化)。适合追求极致性能的场景。
from paddle.distributed.auto_parallel import Engine
engine = Engine(model, loss, optimizer, strategy=strategy)
engine.fit(train_dataset)
思路:每张卡持有完整模型副本,训练数据按 batch 维度切分。
流程:
通信量:每次迭代 AllReduce 全部梯度,通信量 = 2 * model_size(Ring AllReduce 下)。
局限:模型必须完整放入单卡显存。
Group Sharded 是 ZeRO(Zero Redundancy Optimizer)在 Paddle 中的实现,渐进式减少显存冗余。
显存节省:优化器状态占总显存的大头(Adam 为参数量的 2 倍),Stage 1 将其降为 1/N。
额外通信:相比 DP,Stage 3 增加了前向 + 反向各一次 AllGather,通信量约增加 50%,但显存占用可降至接近 1/N。
将单个算子的权重矩阵切分到多卡。以线性层 Y = XW 为例:
将权重 W 按列切分为 [W1, W2],分布在 2 张卡上:
将权重 W 按行切分,输入 X 也相应切分:
Transformer 中的典型组合:MLP 的第一个线性层用 Column Parallel,第二个用 Row Parallel,首尾各一次 AllReduce(或 f/g 共轭算子对消前向/反向各一次 AllReduce)。
将模型按层分为多个 stage,分配到不同卡上。
一次只有一张卡在计算,其余空闲。GPU 利用率 = 1/N,不实用。
将 mini-batch 拆分为多个 micro-batch:
bubble 比例:(num_stages - 1) / num_microbatches。需要同时保存所有 micro-batch 的激活值,显存占用大。
交错执行前向和反向:
优势:稳态阶段只需保存 num_stages 个 micro-batch 的激活值,相比 F-then-B 显存减少约 37.5%(4 stages 时)。bubble 比例与 F-then-B 相同。
Tensor Parallel 的扩展,专门针对 Transformer 架构中 不在 Tensor Parallel 范围内 的算子(LayerNorm、Dropout)。
思路:
通信转换:
收益:在不增加通信量的前提下,减少了 LayerNorm/Dropout 区域的激活值显存占用。