docs/source/distributed.fsdp.fully_shard.md
# PyTorch FSDP2 (`fully_shard`)

PyTorch FSDP2 (RFC) provides a fully sharded data parallelism (FSDP) implementation targeting performant eager-mode while using per-parameter sharding for improved usability.
- See the Getting Started with FSDP2 tutorial for more information.
- If you are currently using FSDP1, consider migrating to FSDP2 using our migration guide.
The user contract for `fully_shard(model)` is as follows:
- For model initialization, `fully_shard` converts `model.parameters()` from plain `torch.Tensor` to `DTensor` in-place. The parameters are moved to the appropriate device according to the device mesh.
- Before the forward and backward passes, pre-forward/backward hooks all-gather the parameters and convert `model.parameters()` from `DTensor` to plain `torch.Tensor`.
- After the forward and backward passes, post-forward/backward hooks free the unsharded parameters (no communication needed) and convert `model.parameters()` from plain `torch.Tensor` back to `DTensor`.
- The optimizer must be initialized with the `DTensor` `model.parameters()`, and the optimizer step runs on the `DTensor` parameters, as shown in the sketch below.
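A minimal sketch of this contract, assuming a multi-GPU job launched with `torchrun` and a toy `nn.Sequential` model (the model shape, optimizer choice, and backend here are illustrative, not part of the API):

```python
import os

import torch
import torch.nn as nn
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor import DTensor

# Assumes e.g.: torchrun --nproc_per_node=2 example.py
torch.distributed.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16)).cuda()
fully_shard(model)  # model.parameters() become DTensors, sharded on dim-0

assert all(isinstance(p, DTensor) for p in model.parameters())

# The optimizer must be constructed on the DTensor parameters.
optim = torch.optim.Adam(model.parameters(), lr=1e-2)

x = torch.randn(8, 16, device="cuda")
loss = model(x).sum()  # pre-forward hook all-gathers the parameters
loss.backward()        # post-backward hook frees the unsharded parameters
optim.step()           # optimizer step runs on the sharded DTensor parameters

torch.distributed.destroy_process_group()
```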
Call `model(input)` instead of `model.forward(input)` to trigger the pre-forward hooks that all-gather the parameters. To make `model.forward(input)` work, users must either call `model.unshard()` explicitly or use `register_fsdp_forward_method(model, "forward")` to register the `forward` method for hooking.
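For example, continuing the sketch above (`model` and `x` are the illustrative objects from that sketch):

```python
from torch.distributed.fsdp import register_fsdp_forward_method

out = model(x)  # preferred: __call__ runs the FSDP pre-/post-forward hooks

# Option 1: explicitly all-gather the parameters around a direct forward() call.
model.unshard()
out = model.forward(x)
model.reshard()

# Option 2: register the method so the FSDP hooks wrap it automatically.
register_fsdp_forward_method(model, "forward")
out = model.forward(x)
```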
`fully_shard` groups parameters together for a single all-gather. Users should apply `fully_shard` in a bottom-up manner. For example, in a Transformer model, `fully_shard` should be applied to each layer before applying it to the root model. When applied to the root model, `fully_shard` excludes the parameters already grouped by the per-layer calls and groups the remaining parameters (e.g., embeddings, output projection) into a single all-gather group.
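A sketch of this bottom-up application, assuming a Transformer-style module with a `layers` attribute (the attribute name is illustrative):

```python
from torch.distributed.fsdp import fully_shard

for layer in model.layers:
    fully_shard(layer)  # each layer's parameters form their own all-gather group
fully_shard(model)      # remaining parameters (embeddings, output projection, ...)
                        # form the root all-gather group
```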
`type(model)` is "unioned" with `FSDPModule` in-place. For example, if `model` is originally of type `nn.Linear`, then `fully_shard` changes `type(model)` from `nn.Linear` to `FSDPLinear` in-place. `FSDPLinear` is an instance of both `nn.Linear` and `FSDPModule`. It retains all methods of `nn.Linear` while also exposing FSDP2-specific APIs under `FSDPModule`, such as `reshard()` and `unshard()`.
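For instance, a sketch run inside the same distributed setup as above:

```python
import torch.nn as nn
from torch.distributed.fsdp import fully_shard, FSDPModule

linear = nn.Linear(8, 8)
fully_shard(linear)

assert isinstance(linear, nn.Linear)   # all nn.Linear methods are retained
assert isinstance(linear, FSDPModule)  # FSDP2 APIs such as unshard()/reshard() are exposed
print(type(linear).__name__)           # "FSDPLinear"
```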
Fully Qualified Names (FQNs) for parameters remain unchanged. If we call `model.state_dict()`, the FQNs are the same before and after applying `fully_shard`. This is because `fully_shard` does not wrap the module but only registers hooks on the original module.
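A quick check, assuming a freshly constructed (not yet sharded) `model`:

```python
keys_before = set(model.state_dict().keys())
fully_shard(model)
keys_after = set(model.state_dict().keys())
assert keys_before == keys_after  # e.g. "0.weight" stays "0.weight"
```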
Compared to PyTorch FSDP1 (`FullyShardedDataParallel`):

- FSDP2 uses `DTensor`-based dim-0 per-parameter sharding for a simpler sharding representation compared to FSDP1's flat-parameter sharding, while preserving similar throughput performance. More specifically, FSDP2 chunks each parameter on dim-0 across the data parallel workers (using `torch.chunk(dim=0)`), whereas FSDP1 flattens, concatenates, and chunks a group of tensors together, making it complex to reason about what data is present on each worker and to reshard to different parallelisms. Per-parameter sharding provides a more intuitive user experience, relaxes constraints around frozen parameters, and allows for communication-free (sharded) state dicts, which otherwise require all-gathers in FSDP1.
- FSDP2 implements a different memory management approach to handle multi-stream usage that avoids `torch.Tensor.record_stream`. This ensures deterministic and expected memory usage and does not require blocking the CPU like FSDP1's `limit_all_gathers=True`.
- FSDP2 exposes APIs for manual control over prefetching and collective scheduling, allowing power users more customization. See `FSDPModule` below for details.
- FSDP2 does not directly support full state dicts. Instead, users can reshard the sharded state dicts containing `DTensor`s to full state dicts themselves using `DTensor` APIs like `DTensor.full_tensor()` or by using higher-level APIs like PyTorch Distributed Checkpoint's distributed state dict APIs (see the sketch below). Also, some other args have been removed; see here for details.
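A hedged sketch of the communication-free sharded state dict and the `DTensor.full_tensor()` resharding path mentioned above, where `model` is an FSDP2-sharded module:

```python
from torch.distributed.tensor import DTensor

sharded_sd = model.state_dict()  # parameter values are DTensors; no communication needed
full_sd = {
    k: v.full_tensor() if isinstance(v, DTensor) else v  # all-gathers each sharded entry
    for k, v in sharded_sd.items()
}
```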
.. currentmodule:: torch.distributed.fsdp
The frontend API is `fully_shard`, which can be called on a `module`:
.. autofunction:: fully_shard

.. autoclass:: FSDPModule
    :members:
    :member-order: bysource

.. autoclass:: UnshardHandle
    :members:

.. autofunction:: register_fsdp_forward_method

.. autoclass:: MixedPrecisionPolicy
    :members:

.. autoclass:: OffloadPolicy
    :members:

.. autoclass:: CPUOffloadPolicy
    :members:

.. autofunction:: share_comm_ctx