# Megatron-FSDP
Megatron-FSDP is an NVIDIA-developed PyTorch extension that provides a high-performance implementation of Fully Sharded Data Parallelism (FSDP). It offers seamless cross-compatibility with major deep learning frameworks and parallelism libraries, making it easy to scale your PyTorch models across multiple GPUs and nodes.
Megatron-FSDP can provide up to a 25% speedup and 23% memory savings compared to FSDP2.
## Highlights

- `fully_shard` function for quick model parallelization
- `ParamAndGradBuffer` class for managing parameter and gradient memory
- SM-efficient communication: All-Gather (AG) and Reduce-Scatter (RS) collectives are designed to overlap with compute kernels. However, standard NCCL communication kernels can consume a significant number of GPU SMs (e.g., 16-32 SMs), "stealing" resources from compute (GEMM) kernels and reducing overall TFLOPS.

## Installation

```bash
pip install megatron-fsdp
```
## Getting Started

Transform your PyTorch model to use fully sharded data parallelism with just a few lines:
```python
import torch
from megatron_fsdp import (
    fully_shard_model,
    fully_shard_optimizer,
)

"""
Enable FSDP with Megatron-FSDP via the `fully_shard_*` API.
"""

# Shard your model.
model = fully_shard_model(
    model,
    fsdp_unit_modules=[
        YourModelLayerClass,
        "import.path.to.model.class.YourModelLayerClass",
    ],
    ...
)

# Shard your optimizer.
optimizer = fully_shard_optimizer(
    torch.optim.Adam(model.parameters(), lr=1e-3)
)

# Your model is now ready for distributed training!
```
`fully_shard` / `fully_shard_model` / `fully_shard_optimizer` are simple entrypoints into `MegatronFSDP`:

- To call `fully_shard` on all of the sub-modules, just pass your sub-module classes or import paths to `fully_shard`!
- `fully_shard_*` is a two-line change when sharding the model and optimizer separately.
- `fully_shard` is a one-line change for previously-initialized models and optimizers.

Compare this with FSDP2:
```python
import torch
from torch.distributed.fsdp import fully_shard

# Your existing model and optimizer.
model = YourModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Enable FSDP with FSDP2.
for module in model.modules():
    # Sub-modules to shard.
    if isinstance(module, YourModelLayerClass):
        fully_shard(module)
fully_shard(model)

# Your model is now ready for distributed training!
```
## `torch.compile` Compatibility

Megatron-FSDP is compatible with `torch.compile`, but this feature is still experimental and may introduce performance regressions in some workloads.
## Advanced Usage

Import the required entrypoints from `megatron_fsdp`:

```python
import torch
from megatron_fsdp import (
    fully_shard_model,
    fully_shard_optimizer,
    MixedPrecisionPolicy,
)
```
### Initialize `DeviceMesh`

`DeviceMesh` simplifies the construction of complex arrangements of devices to support various parallelisms.
```python
from torch.distributed.device_mesh import init_device_mesh

# Initialize DeviceMesh.
device_mesh = init_device_mesh(
    "cuda",
    mesh_shape=(dp_outer_size, dp_shard_size, cp_size, tp_size),
    mesh_dim_names=("dp_outer", "dp_shard", "cp", "tp"),
)

# Only relevant when using HSDP, where we also need the full DP group for data parallelism.
# This sub-mesh can be provided to distributed samplers or dataloaders.
device_mesh[("dp_outer", "dp_shard")]._flatten("dp")

# Only required if using CP. Otherwise, just pass dp_shard to FSDP.
device_mesh[("dp_shard", "cp")]._flatten("dp_shard_cp")

# Only required if using HSDP. Otherwise, don't pass hybrid_fsdp_group.
device_mesh[("dp_outer", "dp_shard", "cp")]._flatten("hsdp")
hsdp_group = device_mesh["hsdp"].get_group()

# Initialize DeviceMesh for expert parallel (EP) modules when using FSDP + EP.
expert_device_mesh = init_device_mesh(
    "cuda",
    mesh_shape=(dp_outer_size, expt_dp_shard_size, expt_tp_size),
    mesh_dim_names=("dp_outer", "dp_shard_cp", "tp"),
)
expert_device_mesh[("dp_outer", "dp_shard_cp")]._flatten("hsdp")
hsdp_expt_group = expert_device_mesh["hsdp"].get_group()
```
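The product of the mesh dimensions must equal the total number of ranks, so once the sharding, context-parallel, and tensor-parallel degrees are fixed, the remaining "outer" DP degree is determined. A small sketch (the helper name is ours, not a Megatron-FSDP API):

```python
# Hypothetical helper: derive the dp_outer mesh dimension from the world size,
# given the other parallelism degrees. Mesh dims must multiply to world_size.
def infer_dp_outer(world_size: int, dp_shard: int, cp: int, tp: int) -> int:
    denom = dp_shard * cp * tp
    if world_size % denom != 0:
        raise ValueError(f"world_size {world_size} is not divisible by {denom}")
    return world_size // denom

# e.g. 64 GPUs with dp_shard=4, cp=2, tp=4 leaves dp_outer=2.
print(infer_dp_outer(64, dp_shard=4, cp=2, tp=4))
```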
### Shard the Model with `fully_shard_model`

`fully_shard_model` wraps the model in a `MegatronFSDP` class that schedules the sharding lifecycle of the model parameters and gradients during training and inference.
```python
model = fully_shard_model(
    # PyTorch (Root) Module
    model,
    # Sharded Modules
    fsdp_unit_modules=[...],
    # Device Mesh
    device_mesh=device_mesh,
    # Always required for FSDP or HSDP.
    dp_shard_dim="dp_shard_cp",
    # Set this required argument to use HSDP instead of FSDP. Otherwise, set this to None.
    dp_outer_dim="dp_outer",
    # Only required for TP-sensitive models (i.e. Megatron-LM / TransformerEngine)
    # or when using DTensor-based TP. Otherwise, set this to None.
    tp_dim="tp",
    # Only required when using HSDP. Otherwise, set this to None.
    hybrid_fsdp_group=hsdp_group,
    # Only required when using HSDP + EP. Otherwise, set this to None.
    hybrid_fsdp_expt_group=hsdp_expt_group,
    # Only required for FSDP + EP. Otherwise, set this to None.
    expt_device_mesh=expert_device_mesh,
    # FSDP Sharding Strategy: no_shard (0) / optim (1) / optim_grads (2) / optim_grads_params (3)
    zero_dp_strategy=3,
    outer_dp_sharding_strategy=1,
    # Initialize the model on devices in shards to avoid OOM. Requires device("meta")-init for the model.
    init_model_with_meta_device=True,
    # Mixed-precision policy for controlling compute and communication precision in Megatron-FSDP.
    mixed_precision_policy=MixedPrecisionPolicy(),
    # Sync parameters and gradients each step. Allows for gradient transformations after the backward pass,
    # and synchronizes parameters and gradients across HSDP groups, but deactivates compute-communication
    # overlap going into the subsequent training step.
    sync_model_each_microbatch=True,
    # Preprocess state dict for DCP checkpointing. Required for Torch Distributed Checkpoint.
    preproc_state_dict_for_dcp_ckpt=True,
)
```
The original `torch.nn.Module` can be accessed via `MegatronFSDP.module`.
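Since `init_model_with_meta_device=True` expects the model to be constructed on the meta device, here is a minimal plain-PyTorch sketch of that initialization step (no Megatron-FSDP required):

```python
import torch

# Parameters created under the meta device record shapes and dtypes only and
# allocate no real memory; Megatron-FSDP can then materialize each shard on
# its owning rank when init_model_with_meta_device=True.
with torch.device("meta"):
    layer = torch.nn.Linear(1024, 1024)

assert layer.weight.is_meta  # no storage is allocated behind this tensor
```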
### Shard the Optimizer with `fully_shard_optimizer`

Initialize your optimizer on the Megatron-FSDP model's distributed `Parameter`(s). If your optimizer has already been initialized, either use the `fully_shard` entrypoint, or call `optimizer.add_param_group({"params": model.parameters()})` after resetting your optimizer state via `optimizer.param_groups.clear()` and `optimizer.state.clear()`.
```python
optimizer = torch.optim.Optimizer(model.parameters())
```
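The reset-and-rebind pattern described above can be sketched with plain PyTorch (no Megatron-FSDP required): once sharding replaces the model's parameters, the stale parameter groups and per-parameter state are cleared and the new parameters re-added.

```python
import torch

# Plain-PyTorch sketch of rebinding a previously-initialized optimizer.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Populate the optimizer state with one step.
model(torch.randn(2, 4)).sum().backward()
optimizer.step()

# Clear stale parameter groups and per-parameter state, then re-add the
# (in real usage, now-sharded) parameters.
optimizer.param_groups.clear()
optimizer.state.clear()
optimizer.add_param_group({"params": model.parameters()})
```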
`fully_shard_optimizer` modifies your `optimizer.step()`, `optimizer.zero_grad()`, and distributed optimizer parameters to punctually trigger scheduled FSDP operations for Megatron-FSDP.
```python
fully_shard_optimizer(
    # PyTorch Optimizer
    optimizer,
    # Preprocess state dict for DCP checkpointing.
    # Required for Torch Distributed Checkpoint.
    preproc_state_dict_for_dcp_ckpt=True,
)
```
Extended arguments to `step()` and `zero_grad()` control these FSDP operations:
```python
optimizer.step(
    ...,
    # Sync all gradients before the optimizer step. Alternatively enabled using
    # `sync_model_each_microbatch=True` in MegatronFSDP.
    sync_grad_before_optimizer_step=True,
    # After `optimizer.step()`, install optimized weights into MegatronFSDP's buffers.
    install_optimized_model_weights=True,
)

optimizer.zero_grad(
    ...,
    # Also zero out MegatronFSDP's gradient accumulation buffers.
    zero_grad_buffer=True,
)
```
### Distributed Checkpointing with `MegatronFSDP`

Distributed checkpoints can be saved and loaded using Torch DCP (`torch.distributed.checkpoint`). Alternatively, you can load non-distributed checkpoints before fully sharding your model with any existing checkpoint utility compatible with PyTorch modules.
```python
# Save model and optimizer state.
torch.distributed.checkpoint.save(
    {"model": model.state_dict(), "optimizer": optimizer.state_dict()},
    checkpoint_id=str(CKPT_DIR),
)

# Load model and optimizer state.
ckpt_state_dict = {"model": model.state_dict(), "optimizer": optimizer.state_dict()}
torch.distributed.checkpoint.load(state_dict=ckpt_state_dict, checkpoint_id=str(CKPT_DIR))

# `model.load_state_dict(strict=False)` is only necessary to ignore TE FP8 extra state
# that is missing from the DCP checkpoint but present in TEBaseModule.
# Megatron-FSDP does not support TE FP8 extra state checkpointing with DCP.
model.load_state_dict(ckpt_state_dict["model"], strict=False)
optimizer.load_state_dict(ckpt_state_dict["optimizer"])
```
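For the non-distributed path mentioned above, loading a standard checkpoint into the unsharded model before sharding is plain PyTorch. A minimal sketch (no Megatron-FSDP required; the final sharding call is shown commented out):

```python
import os
import tempfile

import torch

# Save and reload a plain (non-distributed) checkpoint with standard PyTorch,
# before any sharding is applied.
model = torch.nn.Linear(8, 8)
ckpt_path = os.path.join(tempfile.mkdtemp(), "model.pt")
torch.save(model.state_dict(), ckpt_path)

restored = torch.nn.Linear(8, 8)
restored.load_state_dict(torch.load(ckpt_path))
# The restored model would then be sharded, e.g.:
# restored = fully_shard_model(restored, ...)
```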
## `fully_shard` / `MegatronFSDP` API: Advanced Features

Megatron-FSDP's `fully_shard_*` API has a comprehensive set of arguments for fine-tuning your model's performance.
- `fsdp_unit_modules`: A list of sub-module classes or `str` import paths identifying the modules that you want `MegatronFSDP` to fully shard. Required if `1`, `2`, or `3` is specified as the sharding strategy. Defaults to `None`, in which case Megatron-FSDP will replicate the parameters similar to DDP.
- `zero_dp_strategy` (and `outer_dp_sharding_strategy`): Configure different degrees of zero-redundancy data parallelism as described in ZeRO (Zero Redundancy Optimizer). It reduces CUDA memory utilization during model training by distributing model parameters, gradients, and optimizer states across multiple devices in the DP `ProcessGroup`, and collectively communicating subsets of parameters and gradients to specific devices when needed for computation or differentiation. More aggressive sharding strategies entail more communication overhead, with `no_shard` being the least memory-efficient but most communication-efficient, and `optim_grads_params` being the most memory-efficient but least communication-efficient. Additionally, `outer_dp_sharding_strategy` supports `no_shard` (Hybrid Sharded Data Parallelism, HSDP) and `optim` (HFSDP = fully-sharded optimizer state + HSDP, requires `zero_dp_strategy='optim_grads_params'`), after specifying the "outer" DP group (`dp_outer_dim` / `hybrid_fsdp_group`). Defaults to `optim_grads_params` / `3` for `zero_dp_strategy` and `no_shard` / `0` for `outer_dp_sharding_strategy`.
  - `0` / `no_shard`: Your model is not sharded. Similar memory usage to DDP.
  - `1` / `optim`: Your optimizer state is sharded for distributed optimization. Similar to optimizer state sharding in ZeRO-DP.
  - `2` / `optim_grads`: Your optimizer state and gradients are sharded. Similar to ZeRO-2.
  - `3` / `optim_grads_params`: Your optimizer state, gradients, and training parameters are sharded. Similar to ZeRO-3.
- `device_mesh`: A `torch.distributed.DeviceMesh` that informs `MegatronFSDP` of your distributed environment for sharding, in conjunction with your hardware configuration and other parallelisms. If not provided, `megatron_fsdp.fully_shard(_model)` will build an FSDP `DeviceMesh` for you automatically.
- `dp_shard_dim`: The name of the sub-mesh required for FSDP sharding, commonly the flattened combination of the data parallel (DP) and context parallel (CP) sub-meshes.
- `dp_outer_dim`: The name of the sub-mesh corresponding to the "outer" DP group, which is required for replication or sharding in HSDP. `fully_shard` will perform HSDP if `dp_outer_dim` is specified.
- `tp_dim`: The name of the sub-mesh used for tensor parallelism (TP), which is required for (FSDP, TP)-strided sharding when using Megatron-LM or Torch-native DTensor TP.
- `hybrid_fsdp_group`: The `ProcessGroup` containing all ranks in the flattened `dp_shard_dim` and `dp_outer_dim` sub-meshes, utilized to specify the (DP-Outer, DP-Shard) sharded mesh coordinates for the weight and gradient buffers. Required for HSDP.
- `hybrid_fsdp_expt_group`: Defines the data-parallel communication group for expert parameters. Required for HSDP.
- `expt_device_mesh`: Another `torch.distributed.DeviceMesh` tailored for the expert parallel (EP) modules in `MegatronFSDP`.
  - `dp_shard_dim`: The name of the sub-mesh required for FSDP sharding of the EP modules, enabling expert data parallelism (EDP).
  - `tp_dim`: The name of the sub-mesh used for expert tensor parallelism (ETP), which is required for (FSDP, ETP)-strided sharding when using Megatron-LM or Torch-native DTensor ETP.
- `init_model_with_meta_device`: Has `MegatronFSDP` initialize your meta-device model in shards on every CUDA device to avoid OOM when initializing extremely large models that cannot fit on a single device. Users can initialize their model on a meta device (`with torch.device('meta'): ...`), and `MegatronFSDP` will further shard and initialize the model parameters layer by layer, adhering to the customizable `module.reset_parameters` method, which prevents the entire model from being allocated in memory at any point during runtime. Defaults to `False`. Note that any `device` argument that installs your model on a specific device or rank is deactivated when `init_model_with_meta_device=True`.
- `mixed_precision_policy`: Takes a `megatron_fsdp.MixedPrecisionPolicy` that configures mixed-precision compute and communication for Megatron-FSDP. Configuration options include:
  - `main_params_dtype`: Controls the data type for parameters used in distributed optimization or quantization. Defaults to `torch.float32`. If `None`, the native model compute parameter data type is utilized. Must be `torch.float32` (not `None`) when using FP8 parameters with Megatron-FSDP.
  - `main_grads_dtype`: Controls the data type for gradients used in distributed optimization. Defaults to `None`, in which case the model's native gradient data type is utilized. `torch.float32` (or higher) is recommended for accuracy at scale, as `main_grads_dtype` controls the data type for gradient accumulation. `None` is more flexible and uses pre-determined parameter gradient logic in mixed-precision scenarios, such as BF16 for FP8/FP4 parameters quantized via TransformerEngine.
  - `grad_comm_dtype`: Controls the data type for gradient communications (RS / AR) when reducing gradients. A lower-precision `grad_comm_dtype` improves (communication) performance, but may increase memory utilization or sacrifice gradient precision in certain cases. Defaults to `None`, in which case the `main_grads_dtype` data type is utilized, and no additional memory is allocated when `grad_comm_dtype == main_grads_dtype`. Unless using `no_shard` or `optim` (for `outer_dp_sharding_strategy`) or a `FixedPoolAllocator` (`fsdp_double_buffer`), allocating dtype-custom gradient communication buffers (per FSDP group) adds memory overhead of up to 10% or more, and users should consider the performance-memory trade-off when using this feature. When using NCCL user buffers (`nccl_ub=True`), gradient reduction may be performed in high precision depending on the network domain (NVLink or IB), which can enable mixed-precision communication and accumulation, e.g. setting `grad_comm_dtype` to BF16 can support FP32 reduction even though we have BF16 input and output communication buffers. Otherwise, gradients are reduced in `grad_comm_dtype` (and accumulated in `main_grads_dtype`) as usual.
- `overlap_grad_reduce` / `overlap_param_gather`: Overlap gradient reduce-scatter and parameter all-gather group communications with backward and forward compute via asynchronous calls and pre-fetching. (In the case of `no_shard`, parameters are not gathered, but gradient all-reduce is overlapped.) Defaults to `True`.
- `sync_model_each_microbatch`: Triggers a wait (`MegatronFSDP.finish_grad_sync()`) on gradient reduction, parameter de-allocation, and optimizer parameter / gradient installation (in preparation for `optimizer.step()`) after every forward-backward pass. When using HSDP, parameters and gradients are all-gathered and reduced respectively on the "outer" DP group each training step instead of each optimization cycle. This behavior is desirable for a transparent and user-friendly sharded training loop where post-backward transformations on the gradient and a clean compute / memory state are necessary within and between training iterations, but damages performance in situations where optimization is delayed (e.g. gradient accumulation), when the communications of the previous training iteration could be overlapped with the compute of the next training iteration. Will also override `is_last_microbatch` / `microbatch_count` logic in `MegatronFSDP`. Defaults to `True` for `fully_shard`, but defaults to `False` when using the `MegatronFSDP` class directly. Can also be controlled via the `MegatronFSDP.sync()` context manager, or by invoking `MegatronFSDP.set_model_auto_sync(bool)`. When using the `no_shard` / `0` or `optim` / `1` sharding strategies, the user is responsible for calling `MegatronFSDP.zero_grad_buffer()` or `optimizer.zero_grad()` after the subsequent forward-backward pass. This is because un-sharded gradients are all-reduced directly into the gradient accumulation buffer, and this buffer should not be all-reduced more than once per optimization cycle! This is analogous to the justification for the `no_sync()` API in PyTorch `DistributedDataParallel`.
- `enable_fine_grained_param_gather`: Modifies FSDP to all-gather parameters with per-`Module` granularity instead of collectively unsharding all sub-modules of a unit module in Megatron-FSDP. Defaults to `False`.
- `keep_fp8_transpose_cache`: Keeps the FP8 transpose cache when using `MegatronFSDP`. This option incurs (number of parameters × 1 byte) of memory overhead, but can skip the weight transpose operation in backward propagation. This feature provides no benefit on the Blackwell architecture. Defaults to `False`.
- `nccl_ub`: Allocates and registers the NCCL user buffer for the param and grad buffers. This option enables an SM-efficient NCCL algorithm that can improve the performance of overlapped computations. This flag is much more effective when used together with SHARP if the FSDP communication includes both NVL and IB domains. Enabling this option causes additional memory overhead due to the requirement to enable the `fsdp_double_buffer` option. Defaults to `False`.
- `fsdp_manual_registration`: Manually registers the FSDP communication buffers with the NCCL user buffer. For symmetric registration with large models, the registration itself can take a significant amount of time; this option minimizes the number of registration calls to reduce the registration time. However, with this option enabled, you need to manually call the `ParamAndGradBuffer.manual_buffer_registration()` function after the first iteration. This is already implemented in the Megatron-LM training loop; in other use cases, users are expected to call this function themselves. Only effective when `nccl_ub` is enabled. Defaults to `False`.
- `disable_symmetric_registration`: Disables NCCL window (i.e. symmetric) registration when using `nccl_ub`. Defaults to `False`.
- `fsdp_double_buffer`: Uses persistently-allocated double buffers for the temporarily-defined memory needed in `MegatronFSDP` communications. Persistent double buffers may increase peak VRAM utilization, but are required to register NCCL user buffers (`nccl_ub=True`) for `MegatronFSDP`. Currently, this is only supported for simple repetitive model structures such as GPT. Defaults to `False`. Automatically overridden to `True` when `nccl_ub` is enabled.
- `preproc_state_dict_for_dcp_ckpt`: Adds `model.state_dict()` and `optimizer.state_dict()` post-hooks that modify the model and optimizer state in preparation for `torch.distributed.checkpoint.{save,load}` (Torch DCP) checkpointing. Specifically, it adds `__create_write_items__` and `__create_chunk_list__` methods to tensors utilized by Torch DCP to redistribute parameters when saving and loading model and optimizer checkpoints. Can be deactivated should the user need a custom distributed checkpointing strategy. Defaults to `True`.

## TransformerEngine

Megatron-FSDP natively supports mixed-precision activations and parameter sharding in conjunction with TransformerEngine.
- Within the `transformer_engine.pytorch.autocast(recipe: transformer_engine.common.recipe.Recipe)` context, model activations are converted based on the recipe.
- Within the `transformer_engine.pytorch.quantized_model_init(recipe: transformer_engine.common.recipe.Recipe)` context, TransformerEngine native modules (e.g. `transformer_engine.pytorch.TransformerLayer`) have their parameters converted based on the recipe.

Evaluation and differentiation should be performed within `transformer_engine.pytorch.autocast`.

```python
# FP8 Recipe
fp8_recipe = transformer_engine.common.recipe.MXFP8BlockScaling(
    fp8_format=transformer_engine.common.recipe.Format.HYBRID,
)

# Construct TransformerEngine model with FP8 parameters.
with transformer_engine.pytorch.quantized_model_init(
    recipe=fp8_recipe,
    # Needed for FP8 parameters with Megatron-FSDP.
    preserve_high_precision_init_val=True,
):
    te_model = transformer_engine.pytorch.TransformerLayer(...)

# Fully shard the model.
mfsdp_model = fully_shard_model(
    module=te_model,
    fsdp_unit_modules=[transformer_engine.pytorch.TransformerLayer],
    # Only FSDP / ZeRO-3 supports FP8 parameters.
    zero_dp_strategy=3,
    # FP32 main weights are needed for FP8 parameters.
    mixed_precision_policy=MixedPrecisionPolicy(
        main_params_dtype=torch.float32,
    ),
    # Needed for select FP8 recipes.
    keep_fp8_transpose_cache=True,
)

# Evaluate and differentiate the model with FP8 activations.
with transformer_engine.pytorch.autocast(recipe=fp8_recipe):
    mfsdp_model(x).sum().backward()
```
ℹ️ TransformerEngine kernels have a number of configuration constraints when using FP8-quantized parameters, such as using fused QKV parameters or defining activations and parameters with shapes compatible with FP8 cuBLAS kernels on supported NVIDIA hardware. To properly initialize `TransformerLayer`, you can refer to the toy model used in our FP8 unit tests: `Megatron-LM/tests/unit_tests/distributed/fsdp/test_mfsdp_fully_shard.py::TestMegatronFsdpFullyShard::test_fully_shard_te_quantized`.