docs/workers/fsdp_workers.rst
Last updated: 12/01/2025.

We support the PyTorch FSDP backend by implementing workers for the actor, critic, reference, rollout and reward models.

Pros

- Readily supports various models.

  - ``dtensor_weight_loader`` is used for weight synchronization between FSDP
    and vLLM. With ``hf_weight_loader``, users can directly use
    any model supported by both HF and vLLM without any code change.

- Easy to organize the forward and backward computation for each model.

Cons

Because of its simplicity, we recommend the FSDP backend for algorithm research and prototyping.

ActorRolloutRefWorker
^^^^^^^^^^^^^^^^^^^^^

Actor/Rollout HybridEngine
''''''''''''''''''''''''''


.. code:: python

   @register(dispatch_mode=Dispatch.ONE_TO_ALL)
   def init_model(self):

``ONE_TO_ALL``: when the driver process calls the ``init_model`` function,
each worker (on a GPU) executes the following model
initialization process.
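
The ``ONE_TO_ALL`` behaviour described above can be pictured with a small, single-process sketch. It is only an illustration under simplifying assumptions: ``ToyWorker`` and ``ToyWorkerGroup`` are hypothetical stand-ins, not verl classes, and no real GPU or remote process is involved.

.. code:: python

   # Toy illustration of ONE_TO_ALL dispatch: one driver call fans out to
   # every worker with the same arguments.
   # ToyWorker and ToyWorkerGroup are hypothetical names, not verl APIs.

   class ToyWorker:
       def __init__(self, rank):
           self.rank = rank
           self.model = None

       def init_model(self):
           # A real worker would build its FSDP-wrapped model here.
           self.model = f"model-on-rank-{self.rank}"
           return self.rank


   class ToyWorkerGroup:
       def __init__(self, world_size):
           self.workers = [ToyWorker(rank) for rank in range(world_size)]

       def call_one_to_all(self, method_name, *args, **kwargs):
           # Every worker receives identical arguments; results are
           # returned as a list ordered by rank.
           return [getattr(w, method_name)(*args, **kwargs) for w in self.workers]


   group = ToyWorkerGroup(world_size=4)
   initialized_ranks = group.call_one_to_all("init_model")

After the call, every toy worker holds its own model object, which mirrors the idea that each worker runs the same initialization on its own GPU.
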

The initialization details of HybridEngine, Actor and Rollout are highlighted below:

1. ``DataParallelPPOActor`` implements the basic PPO computation logic
   when the model is built with FSDP, including computing log probs and
   updating the model.
2. ``vLLMRollout`` supports generation with vLLM. We modify the vLLM
   Engine so that it runs under SPMD to fit into our
   ``WorkerGroup`` design.

See the `source code <https://github.com/volcengine/verl/blob/main/verl/workers/fsdp_workers.py>`_ for more information.

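
To make "compute log prob" concrete, the sketch below computes the per-token log-probability of a response from raw logits using a numerically stable log-softmax. It is a pure-Python illustration, not the code used by ``DataParallelPPOActor``; ``token_log_probs`` is a hypothetical helper and the logits are made-up values.

.. code:: python

   import math


   def token_log_probs(logits, token_ids):
       """Return log p(token) for each position via a stable log-softmax.

       logits:    list of per-position score lists (one score per vocab entry)
       token_ids: list of chosen token indices, one per position
       """
       out = []
       for row, tok in zip(logits, token_ids):
           m = max(row)  # subtract the max to avoid overflow in exp
           log_z = m + math.log(sum(math.exp(x - m) for x in row))
           out.append(row[tok] - log_z)
       return out


   # Two positions, vocabulary of size 3 (illustrative numbers only).
   logits = [[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]]
   response = [0, 2]
   log_probs = token_log_probs(logits, response)

With uniform logits, as in the second position, each token has probability 1/3, so its log-probability is ``log(1/3)``.
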

.. code:: python

   @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
   def generate_sequences(self, prompts: DataProto):

- ``Dispatch.DP_COMPUTE_PROTO``: the data is dispatched and
  collected along the DP dimension.
- In this function, the rollout model performs auto-regressive generation and the actor model recomputes the old log prob for the generated response.

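
Dispatching and collecting along the DP dimension can be sketched as a split followed by a concatenation. The example assumes equal, contiguous chunks and a batch size divisible by the DP size; ``dispatch_dp``, ``collect_dp`` and ``toy_generate`` are hypothetical helpers for illustration, not verl functions.

.. code:: python

   def dispatch_dp(batch, dp_size):
       """Split a batch into dp_size equal, contiguous chunks (one per DP rank)."""
       if len(batch) % dp_size != 0:
           raise ValueError("batch size must be divisible by dp_size in this sketch")
       chunk = len(batch) // dp_size
       return [batch[i * chunk:(i + 1) * chunk] for i in range(dp_size)]


   def collect_dp(outputs):
       """Concatenate per-rank outputs back into one batch, in rank order."""
       return [item for rank_output in outputs for item in rank_output]


   def toy_generate(prompts):
       # Stand-in for rollout generation on a single DP rank.
       return [p + " <response>" for p in prompts]


   prompts = ["p0", "p1", "p2", "p3"]
   shards = dispatch_dp(prompts, dp_size=2)
   results = collect_dp([toy_generate(s) for s in shards])

Each rank only sees its own shard, and the driver gets back one result per input prompt in the original order.
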

.. code:: python

   @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
   def update_actor(self, data: DataProto):

ReferenceModel
''''''''''''''

The reference model is initialized using the same function as the actor
model, but without initializing the HybridEngine and Optimizer. The
reference model is then also wrapped by ``DataParallelPPOActor``.

.. code:: python

   @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
   def compute_ref_log_prob(self, data: DataProto):

This function uses ``DataParallelPPOActor`` to compute the reference log
prob.

CriticWorker and RewardWorker
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

These workers are quite similar to the reference model. The ``CriticWorker`` additionally initializes the Optimizer.

.. code:: python

   @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
   def compute_values(self, data: DataProto):

.. code:: python

   @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
   def update_critic(self, data: DataProto):

.. code:: python

   @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
   def compute_rm_score(self, data: DataProto):


We do not yet support FSDP HybridShard. To support it, we may need to
construct a 2D device mesh and test the corresponding
``dtensor_weight_loader`` and ``hf_weight_loader`` for each model.
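
As a rough picture of what a 2D device mesh looks like, the sketch below lays ranks out in a (replicate, shard) grid in plain Python. It assumes a row-major layout in which each row is one sharding group and each column holds replicas of the same shard; ``build_2d_mesh`` is a hypothetical helper, and a real implementation would build the mesh with PyTorch's distributed utilities rather than lists.

.. code:: python

   def build_2d_mesh(world_size, shard_size):
       """Arrange ranks into a (replicate, shard) grid, row-major."""
       if world_size % shard_size != 0:
           raise ValueError("world_size must be divisible by shard_size")
       replicate_size = world_size // shard_size
       return [
           [r * shard_size + s for s in range(shard_size)]
           for r in range(replicate_size)
       ]


   # 8 ranks, parameters sharded across groups of 4 (illustrative sizes).
   mesh = build_2d_mesh(world_size=8, shard_size=4)

   # Each row: ranks that shard the parameters among themselves.
   shard_groups = mesh
   # Each column: ranks that hold replicas of the same shard.
   replicate_groups = [list(col) for col in zip(*mesh)]

With 8 ranks and a shard size of 4, the mesh has two rows, so every shard exists in two replicas.
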