docs/advance/dpo_extension.rst
Last updated: 02/25/2025.
We already implemented the complete training pipeline of the PPO algorithms. To extend to other algorithms, we analyze the high-level principle to use verl and provide a tutorial to implement the DPO algorithm. Users can follow the similar paradigm to extend to other RL algorithms.
.. note:: Key ideas: Single process drives multi-process computation and data communication.
Step 1: Consider what multi-machine multi-GPU computations are needed
for each model, such as generate_sequence , compute_log_prob and
update_policy in the actor_rollout model. Implement distributed
single-process-multiple-data (SPMD) computation and encapsulate them
into APIs
Step 2: Based on different distributed scenarios, including FSDP and 3D parallelism in Megatron-LM, implement single-process control of data interaction among multi-process computations.
Step 3: Utilize the encapsulated APIs to implement the control flow
We use verl to implement a simple online DPO algorithm. The algorithm flow of Online DPO is as follows:
Step 1: What are the multi-machine multi-GPU computations
**Sample Generator**
Implementation details:
.. code:: python
from verl.single_controller.base import Worker
from verl.single_controller.ray import RayWorkerGroup, RayClassWithInitArgs, RayResourcePool
import ray
@ray.remote
class SampleGenerator(Worker):
def __init__(self, config):
super().__init__()
self.config = config
def generate_sequences(self, data):
pass
Here, ``SampleGenerator`` can be viewed as a multi-process pulled up by
``torchrun``, with each process running the same code (SPMD).
``SampleGenerator`` needs to implement a ``generate_sequences`` API for
the control flow to call. The implementation details inside can use any
inference engine including vllm, sglang and huggingface. Users can
largely reuse the code in
verl/verl/workers/rollout/vllm_rollout/vllm_rollout.py and we won't
go into details here.
**ReferencePolicy inference**
API: compute reference log probability
.. code:: python
from verl.single_controller.base import Worker
import ray
@ray.remote
class ReferencePolicy(Worker):
def __init__(self):
super().__init__()
self.model = Model()
def infer(self, data):
return self.model(data)
**Actor update**
API: Update actor model parameters
.. code:: python
from verl.single_controller.base import Worker
import ray
@ray.remote
class DPOActor(Worker):
def __init__(self):
super().__init__()
self.model = Model()
self.model = FSDP(self.model) # or other distributed strategy
self.optimizer = optim.Adam(self.model.parameters(), lr=1e-3)
self.loss_fn = xxx
def update(self, data):
self.optimizer.zero_grad()
logits = self.model(data)
loss = self.loss_fn(logits)
loss.backward()
self.optimizer.step()
**Notes: How to distinguish between control processes and distributed computation processes**
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
- Control processes are generally functions directly decorated with
``@ray.remote``
- Computation processes are all wrapped into a ``RayWorkerGroup``.
Users can reuse most of the distribtued computation logics implemented
in PPO algorithm, including FSDP and Megatron-LM backend in
verl/verl/trainer/ppo.
Step 2: Based on different distributed scenarios, implement single-process control of multi-process data interaction
The core problem to solve here is how a single process sends data to
multiple processes, drives multi-process computation, and how the
control process obtains the results of multi-process computation.
First, we initialize the multi-process WorkerGroup in the control
process.
.. code:: python
@ray.remote(num_cpus=1) def main_task(config): # construct SampleGenerator resource_pool = RayResourcePool(process_on_nodes=[8] * 2) # 16 GPUs ray_cls = RayClassWithInitArgs(SampleGenerator, config=config) # put SampleGenerator onto resource pool worker_group = RayWorkerGroup(resource_pool, ray_cls)
# construct reference policy
As we can see, in the control process, multiple processes are wrapped
into a RayWorkerGroup. Inside this WorkerGroup, there is a
self._workers member, where each worker is a RayActor
(https://docs.ray.io/en/latest/ray-core/actors.html) of SampleGenerator.
ray_trainer.md also provide an implementation of
MegatronRayWorkerGroup.
Assuming the model is distributed using FSDP, and there is a batch of data on the control process, for data parallelism, the underlying calling process is:
.. code:: python
data = xxx data_list = data.chunk(dp_size)
output = [] for d in data_list: # worker_group._workers[i] is a SampleGenerator output.append(worker_group._workers[i].generate_sequences.remote(d))
output = ray.get(output) output = torch.cat(output)
Single process calling multiple processes involves the following 3 steps:
Frequently calling these 3 steps on the controller process greatly hurts code readability. In verl, we have abstracted and encapsulated these 3 steps, so that the worker's method + dispatch + collect can be registered into the worker_group
.. code:: python
from verl.single_controller.base.decorator import register
def dispatch_data(worker_group, data): return data.chunk(worker_group.world_size)
def collect_data(worker_group, data): return torch.cat(data)
dispatch_mode = { 'dispatch_fn': dispatch_data, 'collect_fn': collect_data }
@register(dispatch_mode=dispatch_mode) def generate_sequences(self, data): pass
In this way, we can directly call the method inside the worker through
the worker_group on the control (driver) process (which is a single
process):
.. code:: python
output = worker_group.generate_sequences(data)
This single line includes data splitting, data distribution and computation, and data collection.
Furthermore, the model parallelism size of each model is usually fixed,
including dp, tp, pp. So for these common distributed scenarios, we have
pre-implemented specific dispatch and collect methods,in decorator.py <https://github.com/volcengine/verl/blob/main/verl/single_controller/base/decorator.py>_, which can be directly used to wrap the computations.
.. code:: python
from verl.single_controller.base.decorator import register, Dispatch
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def generate_sequences(self, data: DataProto) -> DataProto: pass
Here it requires the data interface to be DataProto. Definition of
DataProto is in protocol.py <https://github.com/volcengine/verl/blob/main/verl/protocol.py>_.
Step 3: Main training loop
With the above training flows, we can implement the algorithm's control
flow. It is recommended that ``main_task`` is also a ray remote process.
.. code:: python
@ray.remote(num_cpus=1)
def main_task(config):
# construct SampleGenerator
resource_pool = RayResourcePool(process_on_nodes=[8] * 2) # 16 GPUs
ray_cls = RayClassWithInitArgs(SampleGenerator, config=config)
# put SampleGenerator onto resource pool
sample_gen = RayWorkerGroup(resource_pool, ray_cls)
# construct reference policy
ray_cls = RayClassWithInitArgs(ReferencePolicy)
ref_policy = RayWorkerGroup(resource_pool, ray_cls)
# construct actor
ray_cls = RayClassWithInitArgs(DPOActor)
dpo_policy = RayWorkerGroup(resource_pool, ray_cls)
dataloader = DataLoader()
for data in dataloader:
# generate data
data = sample_gen.generate_sequences(data)
# generate scores for each data
data = generate_scores(data)
# generate pairwise data using scores
data = generate_pairwise_data(data)
# generate ref_log_prob
data.batch['ref_log_prob'] = ref_policy.infer(data)
# update using dpo
dpo_policy.update(data)
# logging
Here, different ``WorkerGroups`` can be placed in the same resource pool or
in different resource pools using ``create_colocated_worker_cls``
similar as in `ray_trainer.py <https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/ray_trainer.py>`_.