docs/source/async_grpo_trainer.md
> [!IMPORTANT]
> This trainer requires `vllm>=0.17.1` and `transformers>=5.2.0`. For distributed training, only FSDP2 is supported (DeepSpeed ZeRO is not).
>
> Currently, `vllm` and `transformers` have conflicting dependency constraints. To work around this, install vLLM first and then force-install transformers:
>
> ```bash
> pip install 'vllm>=0.17.1'
> pip install 'transformers>=5.2.0' --no-deps
> ```
[`AsyncGRPOTrainer`] implements the same GRPO algorithm as [`GRPOTrainer`] but decouples rollout generation from training. A background worker continuously streams completions from a vLLM server while the training loop consumes them, so generation and gradient updates overlap instead of alternating. The API mirrors [`GRPOTrainer`]; for full details on the GRPO method itself (advantage computation, KL estimation, loss formulation, reward functions, etc.), see the GRPO Trainer documentation. Not all features from [`GRPOTrainer`] are available; refer to [`AsyncGRPOConfig`] for the supported parameters.
This trainer was contributed by Quentin Gallouédec and Amine Dirhoussi.
In the standard [`GRPOTrainer`], generation and training are sequential: generate a batch, compute the loss, update weights, repeat. Even in vLLM colocate mode, where generation runs on the same GPUs, one phase must finish before the other begins.
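[`AsyncGRPOTrainer`] separates these two concerns: a rollout worker keeps generating completions against the vLLM server in the background, while the training loop consumes them and applies gradient updates.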
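The overlap between generation and training can be pictured as a producer/consumer pipeline. The toy sketch below is not the trainer's implementation: it uses a thread and an in-memory queue purely for illustration, and the names `rollout_worker` and `training_loop` are hypothetical.

```python
import queue
import threading
import time


def rollout_worker(out_queue: queue.Queue, num_samples: int) -> None:
    """Stand-in for the background worker: produces fake completions."""
    for i in range(num_samples):
        time.sleep(0.01)  # pretend this is a generation request to the server
        out_queue.put(f"completion-{i}")
    out_queue.put(None)  # sentinel: no more samples


def training_loop(in_queue: queue.Queue, batch_size: int) -> list[list[str]]:
    """Stand-in for the trainer: consumes completions in fixed-size batches."""
    batches, batch = [], []
    while True:
        sample = in_queue.get()
        if sample is None:
            break
        batch.append(sample)
        if len(batch) == batch_size:
            batches.append(batch)  # a real loop would compute the loss here
            batch = []
    return batches


samples = queue.Queue(maxsize=8)  # bounded, so the producer cannot run far ahead
worker = threading.Thread(target=rollout_worker, args=(samples, 12), daemon=True)
worker.start()
batches = training_loop(samples, batch_size=4)
worker.join()
print(len(batches))  # 3
```

The bounded queue plays a role loosely analogous to limiting in-flight requests: the producer blocks once it is too far ahead of the consumer.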
The rollout worker runs in a separate process spawned from the trainer, so reward computation never contends with the training loop for the GIL. This has two consequences for what you can pass as `reward_funcs`, `tools`, and `environment_factory`:
> [!WARNING]
> Because we run the rollout worker in a separate process, everything passed to it is pickled. Each reward function, tool, and `environment_factory` (and anything they close over) must therefore be picklable: use a module-level function, `functools.partial`, or a callable class instance. Lambdas and closures will raise a `TypeError` at `trainer.train()`. This is a difference from [`GRPOTrainer`], where reward functions are called in-process and closures work.
>
> The rollout process also runs with `CUDA_VISIBLE_DEVICES=""`, so it cannot use the GPU. A GPU-backed reward model (e.g. an `AutoModelForSequenceClassification` scorer) still loads without error but silently falls back to CPU (note that in [`GRPOTrainer`], such a reward model shares the trainer's GPUs). Keep reward functions CPU-side and lightweight (verifiers like `accuracy_reward`, format/length checks).
>
> If you do need a GPU reward model, the recommended approach is to serve it behind its own inference engine (vLLM, TGI, …) on separate GPUs and have a lightweight, picklable reward function call it over HTTP. This keeps the reward model on its own device while the rollout process stays CPU-only, and it scales independently of the trainer.
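
To make the picklability requirement concrete, here is a minimal sketch of the three supported shapes next to a lambda. The reward functions themselves (`length_reward`, `KeywordReward`) are hypothetical, and the sketch assumes completions arrive as plain strings; `is_picklable` is a small helper written for this example, not part of TRL.

```python
import functools
import pickle


def length_reward(completions, target_length=100, **kwargs):
    """Module-level function: penalizes distance from target_length characters."""
    return [-abs(len(c) - target_length) / target_length for c in completions]


class KeywordReward:
    """Callable class: state lives in attributes, which pickle can serialize."""

    def __init__(self, keyword):
        self.keyword = keyword

    def __call__(self, completions, **kwargs):
        return [1.0 if self.keyword in c else 0.0 for c in completions]


def is_picklable(obj):
    """Return True if pickle can serialize obj (helper for this example)."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False


short_reward = functools.partial(length_reward, target_length=20)
keyword_reward = KeywordReward("</think>")
lambda_reward = lambda completions, **kwargs: [0.0] * len(completions)

# Results below assume the file is run as a script.
print(is_picklable(length_reward))   # True
print(is_picklable(short_reward))    # True
print(is_picklable(keyword_reward))  # True
print(is_picklable(lambda_reward))   # False
```

Running a check like this on your reward functions before launching can surface pickling problems earlier than waiting for `trainer.train()` to fail.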
After every `weight_sync_steps` training steps, the updated weights are transferred to the vLLM server via NCCL so that subsequent generations reflect the latest policy.
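As a rough illustration of what this schedule implies, the sketch below computes how many training steps the server's weights lag behind the trainer at each step. It assumes a sync happens exactly when the step count is a multiple of `weight_sync_steps`; the helper `server_policy_step` is hypothetical and not part of the trainer.

```python
def server_policy_step(train_step: int, weight_sync_steps: int) -> int:
    """Last training step whose weights were pushed to the server (assumed schedule)."""
    return (train_step // weight_sync_steps) * weight_sync_steps


weight_sync_steps = 4
lags = [step - server_policy_step(step, weight_sync_steps) for step in range(10)]
print(lags)  # [0, 1, 2, 3, 0, 1, 2, 3, 0, 1]
```

Under this assumption, a larger `weight_sync_steps` means fewer transfers but generations that come from an older policy on average.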
Because generation and training run concurrently, the training samples may have been generated by a slightly older version of the model. The `max_staleness` parameter controls how many weight updates a sample can lag behind before being discarded.
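A staleness filter of this kind might look like the following sketch. The `Sample` class and its `policy_version` field are hypothetical, and treating a lag of exactly `max_staleness` as still acceptable is an assumption, not something confirmed by the trainer's source.

```python
from dataclasses import dataclass


@dataclass
class Sample:
    completion: str
    policy_version: int  # hypothetical: weight-update count when the sample was generated


def drop_stale(samples, current_version, max_staleness):
    """Keep samples that lag at most max_staleness weight updates (assumed inclusive)."""
    return [s for s in samples if current_version - s.policy_version <= max_staleness]


buffer = [Sample("a", 7), Sample("b", 9), Sample("c", 10)]
fresh = drop_stale(buffer, current_version=10, max_staleness=2)
print([s.completion for s in fresh])  # ['b', 'c']
```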
The number of concurrent requests sent to the vLLM server is controlled by `max_inflight_tasks`. By default it is set automatically to `max_staleness × per_device_train_batch_size × gradient_accumulation_steps × num_processes`, which is the maximum number of samples the trainer can consume before they become stale. Generating more than this is wasteful since the excess samples will be discarded.
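The default can be worked out by hand. The helper below simply restates the formula; the function name and the input values are illustrative and are not the config's actual defaults.

```python
def default_max_inflight_tasks(
    max_staleness: int,
    per_device_train_batch_size: int,
    gradient_accumulation_steps: int,
    num_processes: int,
) -> int:
    """Samples the trainer can consume before freshly generated ones go stale."""
    samples_per_update = (
        per_device_train_batch_size * gradient_accumulation_steps * num_processes
    )
    return max_staleness * samples_per_update


# Illustrative values: 4 updates of staleness, batch size 8, 2 accumulation steps, 2 processes
print(default_max_inflight_tasks(4, 8, 2, 2))  # 128
```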
```python
# train_async_grpo.py
from datasets import load_dataset

from trl.experimental.async_grpo import AsyncGRPOTrainer
from trl.rewards import accuracy_reward

dataset = load_dataset("trl-lib/DeepMath-103K", split="train")

trainer = AsyncGRPOTrainer(
    model="Qwen/Qwen3-4B",
    reward_funcs=accuracy_reward,
    train_dataset=dataset,
)
trainer.train()
```
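If you want to set the parameters discussed above explicitly, a configuration along these lines should work. This is a sketch under assumptions: it presumes the trainer accepts an `args` argument and that [`AsyncGRPOConfig`] accepts `output_dir`, as [`GRPOTrainer`] and its config do, and all values are illustrative rather than recommended. Check [`AsyncGRPOConfig`] for the parameters that are actually supported.

```python
from datasets import load_dataset

from trl.experimental.async_grpo import AsyncGRPOConfig, AsyncGRPOTrainer
from trl.rewards import accuracy_reward

dataset = load_dataset("trl-lib/DeepMath-103K", split="train")

# Illustrative values only; `output_dir` and `args=` are assumed to mirror GRPOTrainer.
config = AsyncGRPOConfig(
    output_dir="async-grpo-qwen3-4b",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    max_completion_length=2048,
    weight_sync_steps=1,  # how often weights are pushed to the vLLM server
    max_staleness=4,      # discard samples lagging more than this many updates
)

trainer = AsyncGRPOTrainer(
    model="Qwen/Qwen3-4B",
    reward_funcs=accuracy_reward,
    args=config,
    train_dataset=dataset,
)
trainer.train()
```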
The vLLM server and the trainer must run on separate GPUs. Use `CUDA_VISIBLE_DEVICES` to partition your GPUs. For example, with 2 GPUs, you can run the vLLM server on GPU 0 and the trainer on GPU 1 as follows:
```bash
# Terminal 1: vLLM server on GPU 0 (dev mode + NCCL weight transfer are required)
CUDA_VISIBLE_DEVICES=0 VLLM_SERVER_DEV_MODE=1 vllm serve Qwen/Qwen3-4B \
    --max-model-len 4096 \
    --logprobs-mode processed_logprobs \
    --weight-transfer-config '{"backend":"nccl"}'
```
> [!TIP]
> Set `--max-model-len` to the maximum total sequence length (prompt + completion) you expect. A lower value reduces GPU memory usage on the server, freeing more memory for the KV cache and increasing throughput. A good starting point is the prompt length plus `max_completion_length` from your config.
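
The starting point suggested in the tip is a one-line calculation. The helper name and the token counts below are hypothetical; in practice you would measure prompt lengths with the model's tokenizer.

```python
def suggested_max_model_len(prompt_token_lengths, max_completion_length):
    """Longest prompt (in tokens) plus the completion budget."""
    return max(prompt_token_lengths) + max_completion_length


# Hypothetical prompt lengths in tokens, with max_completion_length=2048
print(suggested_max_model_len([312, 540, 1024], max_completion_length=2048))  # 3072
```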
```bash
# Terminal 2: training on GPU 1
CUDA_VISIBLE_DEVICES=1 accelerate launch train_async_grpo.py
```
This trainer is intentionally kept minimal and is not meant to grow into a general-purpose solution. If you need a feature that is not supported, we recommend cloning the repository and adapting the trainer to your needs directly. New features will only be considered when there is significant community demand.
[[autodoc]] trl.experimental.async_grpo.AsyncGRPOConfig
[[autodoc]] trl.experimental.async_grpo.AsyncGRPOTrainer
[[autodoc]] trl.experimental.async_grpo.async_grpo_trainer.RolloutWorkerProtocol