PPO Ray Trainer
===============

Last updated: 02/12/2025.

We implement the RayPPOTrainer, a trainer that runs on the driver process on a single CPU/GPU node (CPU by default).

The RayPPOTrainer includes three core functions: data preparation, WorkerGroup initialization, and the PPO training loop.

Data Preparation
----------------

The RayPPOTrainer, as a single process, is responsible for loading a complete batch of samples (prompts) from the dataset and then dispatching it to the worker groups running on different GPUs.

To generalize data loading, we implement the RLHFDataset class, which loads the preprocessed parquet files, applies chat templates to the prompts, tokenizes them, pads shorter prompts, and truncates prompts that exceed the maximum prompt length.

.. code:: python

    self.train_dataset = RLHFDataset(data_files=self.config.data.train_files, tokenizer=self.tokenizer, config=self.config.data)
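
The preprocessing steps listed above can be illustrated with a small, dependency-free sketch. It is not verl's implementation: the chat template, the whitespace tokenizer and the helper names (apply_toy_chat_template, toy_tokenize, left_pad_or_truncate) are hypothetical, and the sketch assumes that prompts are left-padded and that over-long prompts keep their last tokens. The attention mask marks real tokens, and position ids count only real tokens.

.. code:: python

    from typing import Dict, List

    PAD_ID = 0  # hypothetical padding id; real tokenizers define their own


    def apply_toy_chat_template(messages: List[Dict[str, str]]) -> str:
        """Render chat messages into one prompt string (toy template, not a real model's)."""
        rendered = "".join(f"<{m['role']}> {m['content']} " for m in messages)
        return rendered + "<assistant>"


    def toy_tokenize(text: str, vocab: Dict[str, int]) -> List[int]:
        """Whitespace tokenizer that assigns a new id to each unseen word (illustration only)."""
        ids = []
        for word in text.split():
            if word not in vocab:
                vocab[word] = len(vocab) + 1  # id 0 is reserved for padding
            ids.append(vocab[word])
        return ids


    def left_pad_or_truncate(ids: List[int], max_len: int) -> Dict[str, List[int]]:
        """Keep at most the last max_len tokens, then left-pad to exactly max_len."""
        if len(ids) > max_len:
            ids = ids[len(ids) - max_len:]
        pad = max_len - len(ids)
        return {
            "input_ids": [PAD_ID] * pad + ids,
            "attention_mask": [0] * pad + [1] * len(ids),
            # positions count real tokens only; padded slots get position 0
            "position_ids": [0] * pad + list(range(len(ids))),
        }


    vocab: Dict[str, int] = {}
    prompt = apply_toy_chat_template([{"role": "user", "content": "What is 2 + 3?"}])
    features = left_pad_or_truncate(toy_tokenize(prompt, vocab), max_len=8)
    print(features["attention_mask"])  # [0, 1, 1, 1, 1, 1, 1, 1]

Left padding keeps the last prompt token adjacent to the first generated token, which is why it is a common choice for decoder-only generation.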

Then, the dataloader iterates over the dataset in training batches; each batch is further split into PPO mini-batches when the actor and critic are updated.
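
A minimal sketch of drawing training batches and splitting each one into PPO mini-batches, assuming a hypothetical batch size of 4 and mini-batch size of 2. The functions iterate_batches and split_mini_batches are illustrative only and are not part of verl's API.

.. code:: python

    from typing import Iterator, List, Sequence, TypeVar

    T = TypeVar("T")


    def iterate_batches(samples: Sequence[T], batch_size: int, drop_last: bool = True) -> Iterator[List[T]]:
        """Yield consecutive batches of samples; each batch drives one training step."""
        for start in range(0, len(samples), batch_size):
            batch = list(samples[start:start + batch_size])
            if drop_last and len(batch) < batch_size:
                break  # drop an incomplete final batch
            yield batch


    def split_mini_batches(batch: List[T], mini_batch_size: int) -> List[List[T]]:
        """Split one training batch into equal PPO mini-batches (one optimizer step each)."""
        if len(batch) % mini_batch_size != 0:
            raise ValueError("batch size must be divisible by the mini-batch size")
        return [batch[i:i + mini_batch_size] for i in range(0, len(batch), mini_batch_size)]


    prompts = [f"prompt-{i}" for i in range(10)]
    for step, batch in enumerate(iterate_batches(prompts, batch_size=4)):
        print(step, [len(mb) for mb in split_mini_batches(batch, mini_batch_size=2)])
    # 0 [2, 2]
    # 1 [2, 2]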

WorkerGroup Initialization
--------------------------

We first introduce a basic implementation of initializing the WorkerGroup of the actor model on a given set of GPUs.

.. code:: python

    # max_colocate_count is the maximum number of WorkerGroups (i.e., processes) colocated in each RayResourcePool
    # For the FSDP backend, we recommend max_colocate_count=1, which merges all WorkerGroups into one.
    # For the Megatron backend, we recommend max_colocate_count>1, which allows a different WorkerGroup for each model
    resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes,
                                    use_gpu=True,
                                    max_colocate_count=1)

    # define the actor rollout class to be initialized on remote workers
    actor_rollout_cls = RayClassWithInitArgs(cls=ActorRolloutWorker)

    # define the actor_rollout worker group
    actor_rollout_worker_group = MegatronRayWorkerGroup(resource_pool=resource_pool,
                                                        ray_cls_with_init=actor_rollout_cls,
                                                        default_megatron_kwargs=config.actor_rollout.megatron)
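
To make process_on_nodes and max_colocate_count more concrete, the toy class below models a resource pool in plain Python. ToyResourcePool is hypothetical, and the assumption that each worker requests 1 / max_colocate_count of a GPU reflects our reading of how colocation is scheduled on Ray; check it against the verl source before relying on it.

.. code:: python

    from dataclasses import dataclass
    from typing import List, Tuple


    @dataclass
    class ToyResourcePool:
        """Hypothetical model of a resource pool: process_on_nodes[i] is the number of
        worker processes (typically one per GPU) placed on node i."""
        process_on_nodes: List[int]
        max_colocate_count: int = 1

        @property
        def world_size(self) -> int:
            # total number of worker processes in one WorkerGroup spanning this pool
            return sum(self.process_on_nodes)

        def gpu_fraction_per_worker(self) -> float:
            # assumption: each worker requests 1 / max_colocate_count of a GPU, so up to
            # max_colocate_count WorkerGroups can be scheduled on the same GPUs
            return 1.0 / self.max_colocate_count

        def placements(self) -> List[Tuple[int, int, int]]:
            """Return (global_rank, node_id, local_rank) for every worker in the pool."""
            out = []
            rank = 0
            for node_id, n_procs in enumerate(self.process_on_nodes):
                for local_rank in range(n_procs):
                    out.append((rank, node_id, local_rank))
                    rank += 1
            return out


    n_gpus_per_node, nnodes = 4, 2
    pool = ToyResourcePool(process_on_nodes=[n_gpus_per_node] * nnodes)
    print(pool.world_size)       # 8
    print(pool.placements()[5])  # (5, 1, 1)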

In the above implementation, different WorkerGroups, like actor_rollout_worker_group, critic_worker_group and ref_worker_group, lie in separate processes.

The driver process can then call the distributed compute functions of actor_rollout_worker_group and of the other roles' worker groups to construct the RL training loop.

For models colocated on the same set of GPUs, we further provide a fine-grained optimization that merges the worker groups of different roles into the same process. This avoids creating redundant CUDA/distributed contexts in separate processes.

.. code:: python

    # initialize WorkerGroup
    # NOTE: if you want to use a different resource pool for each role, which can support different parallel sizes,
    # you should not use create_colocated_worker_cls. Instead, directly pass different resource pools to different worker groups.
    # See TODO(url) for more information.
    all_wg = {}
    for resource_pool, class_dict in self.resource_pool_to_cls.items():
        worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
        wg_dict = self.ray_worker_group_cls(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls)
        spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
        all_wg.update(spawn_wg)

    if self.use_critic:
        self.critic_wg = all_wg['critic']
        self.critic_wg.init_model()

    if self.use_reference_policy:
        self.ref_policy_wg = all_wg['ref']
        self.ref_policy_wg.init_model()

    if self.use_rm:
        self.rm_wg = all_wg['rm']
        self.rm_wg.init_model()

    # create the rollout worker last so that vLLM can better estimate the KV cache memory
    self.actor_rollout_wg = all_wg['actor_rollout']
    self.actor_rollout_wg.init_model()
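
The colocation idea, one process hosting several roles with per-role handles returned by spawn, can be modeled in a single Python process. The sketch below is a toy under stated assumptions: the names create_toy_colocated_cls, RoleView, toy_spawn, CriticWorker and RefWorker are hypothetical, it does not use Ray, and it is not how create_colocated_worker_cls is actually implemented.

.. code:: python

    from typing import Any, Callable, Dict, Iterable, Type


    class CriticWorker:
        def init_model(self) -> str:
            return "critic ready"


    class RefWorker:
        def init_model(self) -> str:
            return "ref ready"


    def create_toy_colocated_cls(class_dict: Dict[str, Type]) -> Type:
        """Build one class whose instances hold an object for every role (one process, many roles)."""

        class ColocatedWorker:
            def __init__(self) -> None:
                self.roles = {name: cls() for name, cls in class_dict.items()}

            def call(self, role: str, method: str, *args: Any) -> Any:
                return getattr(self.roles[role], method)(*args)

        return ColocatedWorker


    class RoleView:
        """Per-role handle that forwards method calls to the shared worker object."""

        def __init__(self, worker: Any, role: str) -> None:
            self._worker = worker
            self._role = role

        def __getattr__(self, method: str) -> Callable[..., Any]:
            # only called for attributes not set in __init__, i.e. role methods
            return lambda *args: self._worker.call(self._role, method, *args)


    def toy_spawn(worker: Any, prefix_set: Iterable[str]) -> Dict[str, RoleView]:
        return {role: RoleView(worker, role) for role in prefix_set}


    class_dict = {"critic": CriticWorker, "ref": RefWorker}
    shared = create_toy_colocated_cls(class_dict)()  # one object stands in for one process
    all_wg = toy_spawn(shared, prefix_set=class_dict.keys())
    print(all_wg["critic"].init_model())  # critic ready

Because every handle forwards to the same object, the roles share one process-level context, which is the saving the colocation optimization targets.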

.. note:: For the Megatron backend, if we merge the worker groups into the same processes, all roles will use the same 3D parallel size. To optimize this, we may need to maintain several 3D process groups for each role in the same distributed context. If you want to use different 3D parallel sizes for different roles, follow the architecture of the first code block in this section to initialize each role's worker group separately.

PPO Training Loop
-----------------

We implement the PPO training loop by calling the functions in the worker group of each role. The input and output data of each function is a DataProto object implemented in `protocol.py <https://github.com/volcengine/verl/blob/main/verl/protocol.py>`_. In the training loop, the trainer dispatches data to and collects data from different GPUs following the transfer protocols wrapped in the workers' functions. PPO micro-batches are processed inside the update_actor and update_critic functions.
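
The dispatch/collect pattern can be sketched with a simplified container. ToyDataProto below is hypothetical and only mimics the spirit of DataProto's chunk, concat and union operations on Python lists; the real class stores tensors, and its exact semantics (for example, how union treats keys present on both sides) may differ.

.. code:: python

    from dataclasses import dataclass, field
    from typing import Any, Dict, List


    @dataclass
    class ToyDataProto:
        """Hypothetical, simplified stand-in for DataProto: equal-length columns plus meta info."""
        batch: Dict[str, List[Any]]
        meta_info: Dict[str, Any] = field(default_factory=dict)

        def __len__(self) -> int:
            return len(next(iter(self.batch.values()))) if self.batch else 0

        def chunk(self, n: int) -> List["ToyDataProto"]:
            """Dispatch: split row-wise into n equal shards, one per worker."""
            if len(self) % n != 0:
                raise ValueError("batch must divide evenly across workers")
            size = len(self) // n
            return [
                ToyDataProto({k: v[i * size:(i + 1) * size] for k, v in self.batch.items()}, dict(self.meta_info))
                for i in range(n)
            ]

        @staticmethod
        def concat(parts: List["ToyDataProto"]) -> "ToyDataProto":
            """Collect: concatenate shards back together in worker order."""
            keys = parts[0].batch.keys()
            return ToyDataProto({k: [row for p in parts for row in p.batch[k]] for k in keys}, dict(parts[0].meta_info))

        def union(self, other: "ToyDataProto") -> "ToyDataProto":
            """Merge the columns of two protos that have the same number of rows."""
            if len(self) != len(other):
                raise ValueError("row counts differ")
            return ToyDataProto({**self.batch, **other.batch}, {**self.meta_info, **other.meta_info})


    def fake_generate(shard: ToyDataProto) -> ToyDataProto:
        # stands in for a remote generate_sequences call running on one worker
        return ToyDataProto({"responses": [p + " -> answer" for p in shard.batch["prompts"]]})


    batch = ToyDataProto({"prompts": ["a", "b", "c", "d"]})
    outputs = [fake_generate(shard) for shard in batch.chunk(2)]  # dispatch to 2 workers
    batch = batch.union(ToyDataProto.concat(outputs))             # collect, then merge columns
    print(batch.batch["responses"])  # ['a -> answer', 'b -> answer', 'c -> answer', 'd -> answer']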

To extend to other RLHF algorithms, such as DPO and GRPO, please refer to :doc:`../advance/dpo_extension`.

.. code:: python

    def fit(self):
        """
        The training loop of PPO.
        The driver process only needs to call the compute functions of the worker groups through RPC
        to construct the PPO dataflow. The lightweight advantage computation is done on the driver process.
        """
        import os
        from pprint import pprint

        from codetiming import Timer
        from omegaconf import OmegaConf

        from verl.protocol import DataProto
        from verl.utils.tracking import Tracking

        # apply_kl_penalty, compute_advantage, reduce_metrics and compute_data_metrics are
        # helpers provided by verl (not shown here); their import locations differ between versions.

        logger = Tracking(project_name=self.config.trainer.project_name,
                          experiment_name=self.config.trainer.experiment_name,
                          default_backend=self.config.trainer.logger,
                          config=OmegaConf.to_container(self.config, resolve=True))

        global_steps = 0

        # perform validation before training
        # currently, we only support validation using the reward_function.
        if self.val_reward_fn is not None:
            val_metrics = self._validate()
            pprint(f'Initial validation metrics: {val_metrics}')

        for epoch in range(self.config.trainer.total_epochs):
            for batch_dict in self.train_dataloader:
                metrics = {}

                batch: DataProto = DataProto.from_single_dict(batch_dict)
                # batch = batch.to('cuda')

                # pop those keys for generation
                gen_batch = batch.pop(batch_keys=['input_ids', 'attention_mask', 'position_ids'])

                # generate a batch
                with Timer(name='gen', logger=None) as timer:
                    gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
                metrics['timing/gen'] = timer.last

                batch = batch.union(gen_batch_output)

                if self.use_reference_policy:
                    # compute reference log_prob
                    with Timer(name='ref', logger=None) as timer:
                        ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
                        batch = batch.union(ref_log_prob)
                    metrics['timing/ref'] = timer.last

                # compute values; skipped when no critic is used (e.g., GRPO)
                if self.use_critic:
                    with Timer(name='values', logger=None) as timer:
                        values = self.critic_wg.compute_values(batch)
                        batch = batch.union(values)
                    metrics['timing/values'] = timer.last

                with Timer(name='adv', logger=None) as timer:
                    # compute scores. Support both model and function-based.
                    # We first compute the scores using reward model. Then, we call reward_fn to combine
                    # the results from reward model and rule-based results.
                    if self.use_rm:
                        # we first compute reward model score
                        reward_tensor = self.rm_wg.compute_rm_score(batch)
                        batch = batch.union(reward_tensor)

                    # we combine with rule-based rm
                    reward_tensor = self.reward_fn(batch)
                    batch.batch['token_level_scores'] = reward_tensor

                    # compute rewards. apply_kl_penalty if available
                    batch, kl_metrics = apply_kl_penalty(batch,
                                                         kl_ctrl=self.kl_ctrl_in_reward,
                                                         kl_penalty=self.config.algorithm.kl_penalty)
                    metrics.update(kl_metrics)

                    # compute advantages, executed on the driver process
                    batch = compute_advantage(batch,
                                              self.config.algorithm.gamma,
                                              self.config.algorithm.lam,
                                              adv_estimator=self.config.algorithm.adv_estimator)
                metrics['timing/adv'] = timer.last

                # update critic
                if self.use_critic:
                    with Timer(name='update_critic', logger=None) as timer:
                        critic_output = self.critic_wg.update_critic(batch)
                    metrics['timing/update_critic'] = timer.last
                    critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics'])
                    metrics.update(critic_output_metrics)

                # implement critic warmup
                if self.config.trainer.critic_warmup <= global_steps:
                    # update actor
                    with Timer(name='update_actor', logger=None) as timer:
                        actor_output = self.actor_rollout_wg.update_actor(batch)
                    metrics['timing/update_actor'] = timer.last
                    actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics'])
                    metrics.update(actor_output_metrics)

                # validate
                if self.val_reward_fn is not None and (global_steps + 1) % self.config.trainer.test_freq == 0:
                    with Timer(name='testing', logger=None) as timer:
                        val_metrics: dict = self._validate()
                        val_metrics = {f'val/{key}': val for key, val in val_metrics.items()}
                    metrics['timing/testing'] = timer.last
                    metrics.update(val_metrics)

                # collect metrics
                data_metrics = compute_data_metrics(batch=batch)
                metrics.update(data_metrics)

                # TODO: make a canonical logger that supports various backends
                logger.log(data=metrics, step=global_steps)

                if self.config.trainer.save_freq > 0 and (global_steps + 1) % self.config.trainer.save_freq == 0:
                    actor_local_path = os.path.join(self.config.trainer.default_local_dir, 'actor',
                                                    f'global_step_{global_steps}')
                    actor_remote_path = os.path.join(self.config.trainer.default_hdfs_dir, 'actor')
                    self.actor_rollout_wg.save_checkpoint(actor_local_path, actor_remote_path)

                    if self.use_critic:
                        critic_local_path = os.path.join(self.config.trainer.default_local_dir, 'critic',
                                                         f'global_step_{global_steps}')
                        critic_remote_path = os.path.join(self.config.trainer.default_hdfs_dir, 'critic')
                        self.critic_wg.save_checkpoint(critic_local_path, critic_remote_path)

                global_steps += 1

        # perform validation after training
        if self.val_reward_fn is not None:
            val_metrics = self._validate()
            pprint(f'Final validation metrics: {val_metrics}')
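
The driver-side reward shaping and advantage computation in the loop above can be sketched for a single response in plain Python. This assumes the KL term is the per-token difference between the policy's and the reference policy's log-probabilities (one of several possible penalty forms), that the outcome score sits on the last response token, and that the value after the final token is zero. kl_shaped_rewards and gae are illustrative names, not verl functions, and the advantage whitening that implementations commonly apply is omitted.

.. code:: python

    from typing import List, Tuple


    def kl_shaped_rewards(scores: List[float], log_probs: List[float],
                          ref_log_probs: List[float], beta: float) -> List[float]:
        """Per-token reward: score minus beta times (log pi - log pi_ref) for that token."""
        return [s - beta * (lp - rlp) for s, lp, rlp in zip(scores, log_probs, ref_log_probs)]


    def gae(rewards: List[float], values: List[float], gamma: float, lam: float) -> Tuple[List[float], List[float]]:
        """Generalized Advantage Estimation over one response; the value after the last token is 0."""
        advantages = [0.0] * len(rewards)
        running = 0.0
        for t in reversed(range(len(rewards))):
            next_value = values[t + 1] if t + 1 < len(rewards) else 0.0
            delta = rewards[t] + gamma * next_value - values[t]  # TD error at step t
            running = delta + gamma * lam * running
            advantages[t] = running
        returns = [a + v for a, v in zip(advantages, values)]  # regression targets for the critic
        return advantages, returns


    scores = [0.0, 0.0, 1.0]  # outcome score placed on the last response token
    rewards = kl_shaped_rewards(scores, [-1.0, -0.5, -0.2], [-1.2, -0.5, -0.4], beta=0.1)
    advantages, returns = gae(rewards, values=[0.5, 0.5, 0.5], gamma=1.0, lam=1.0)
    print([round(r, 2) for r in returns])  # [0.96, 0.98, 0.98]

With gamma = lam = 1, the returns reduce to the reward-to-go (here -0.02 + 0.0 + 0.98 = 0.96 for the first token), which gives a quick sanity check of the recursion.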