docs/advance/async-on-policy-distill.md
Authors: Brilliant Hanabi, furunding
Last updated: 2025-11-08
On-policy knowledge distillation (KD) trains a student policy to imitate a stronger teacher using samples drawn from the student's current policy. For each on-policy rollout the teacher returns soft, top-k token distributions and the student is optimized with a token-wise sparse KL objective that focuses learning on the teacher's high-probability modes. Because training examples come from the student's own state distribution, KD reduces distributional mismatch relative to off-policy distillation or supervised fine-tuning (SFT), improving stability and sample efficiency. Compared with reinforcement learning, KD avoids high-variance reward-based optimization and complex reward design by providing dense, informative per-token targets, which typically yields faster convergence and simpler scaling. Recent empirical and implementation-focused writeups (e.g., ThinkingMachines' blog on on-policy distillation) also demonstrate that on-policy distillation can deliver high-quality behavior with substantially lower compute and data requirements than many alternative approaches.
Built on verl’s Ray-based single-controller components, we initially assembled a strictly on-policy KD pipeline where rollout generation, teacher knowledge acquisition, and policy optimization ran in lockstep. In practice, this synchronous design proved highly inefficient: the three stages had to wait for one another, creating pipeline bubbles and underutilized GPUs. To address this, we extend the asynchronous schedulers introduced by the One-Step-Off Policy pipeline to overlap these phases. This overlap preserves the same distillation objective while trading some strict on-policy guarantees for substantial gains in end-to-end throughput and hardware utilization.
This recipe centers on on-policy knowledge distillation: the student policy learns from a stronger teacher on samples generated by the current policy (on-policy). For each input prompt, the student (actor) generates responses; the teacher provides top-k token distributions, and the student is trained to match them token-wise.
Core components:
- Objective: encourage the student distribution $Q$ to cover the teacher's modes $P$ using a token-wise $\mathrm{KL}(P \,\|\, Q)$ computed on the teacher's top-k support (see the sketch below).
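To make the objective concrete, here is a minimal PyTorch sketch of such a loss. It is not the recipe's `megatron_kl_loss.py`; the teacher-side tensors mirror the fields used in this recipe (`teacher_topk_logps`, `teacher_topk_indices`, `attention_mask`), while `student_logits` and the function name are illustrative.

```python
import torch
import torch.nn.functional as F

def sparse_forward_kl(student_logits, teacher_topk_logps, teacher_topk_indices, attention_mask):
    """Token-wise KL(P || Q) restricted to the teacher's top-k support (illustrative sketch).

    student_logits:       [batch, seq, vocab]  raw student logits
    teacher_topk_logps:   [batch, seq, k]      teacher log-probs for its top-k tokens
    teacher_topk_indices: [batch, seq, k]      vocabulary ids of those top-k tokens
    attention_mask:       [batch, seq]         1 for valid (response) token positions
    """
    # Full student log-probs, then gather the entries on the teacher's top-k support.
    student_logps = F.log_softmax(student_logits.float(), dim=-1)
    student_topk_logps = torch.gather(student_logps, dim=-1, index=teacher_topk_indices)

    # KL(P || Q) over the top-k support: sum_k P * (log P - log Q).
    teacher_probs = teacher_topk_logps.exp()
    kl_per_token = (teacher_probs * (teacher_topk_logps - student_topk_logps)).sum(dim=-1)

    # Average over valid token positions only.
    mask = attention_mask.float()
    return (kl_per_token * mask).sum() / mask.sum().clamp(min=1.0)
```

Because only the teacher's top-k entries are transferred and compared, both the teacher payload and the loss stay sparse even for large vocabularies.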
The native (serial) on-policy distillation process is shown in the figure below.
This recipe supports optional schedulers that overlap generation, teacher querying, and updates to improve throughput without changing the distillation objective.
- `one_step_off`: per-step critical path is roughly `sync_rollout_weights` + `wait_prev_gen` + `wait_prev_teacher`.
- `two_step_off`: per-step critical path is roughly `sync_rollout_weights` + max(`wait_prev_gen`, `wait_prev_prev_teacher`).
- Tip: use `two_step_off` when the teacher takes much more time than sync; use `one_step_off` for simpler overlapping. A rough cost comparison is sketched below.
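For a rough sense of the trade-off, the per-step cost under each scheduler can be estimated from the `timing/*` metrics reported later in this doc. The numbers below are made up for illustration only.

```python
# Hypothetical per-step timings in seconds (read yours from the timing/* metrics).
sync_rollout_weights = 5.0
wait_prev_gen = 30.0
wait_prev_teacher = 45.0  # teacher noticeably slower than generation

# one_step_off: generation overlaps the previous update, but the teacher
# query still sits on the critical path.
one_step_off = sync_rollout_weights + wait_prev_gen + wait_prev_teacher

# two_step_off: the teacher query is pushed one more step back, so only the
# slower of (generation, teacher) is paid each step.
two_step_off = sync_rollout_weights + max(wait_prev_gen, wait_prev_teacher)

print(f"one_step_off ~ {one_step_off:.0f}s/step, two_step_off ~ {two_step_off:.0f}s/step")
```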
Practical details:
- The teacher returns `teacher_topk_logps` and `teacher_topk_indices`; the `attention_mask` is used to select valid token positions.
- The pipeline has three phases: rollout generation (`async_generate_sequences`), teacher knowledge acquisition, and the actor update with the sparse KL loss (`megatron_kl_loss.py`).
- Schedulers (`one_step_off_scheduler`, `two_step_off_scheduler`) can overlap these phases (optional, for throughput).

We initially followed the weight synchronization path from the One-Step-Off-Policy recipe (Ray collective broadcast across all actor and rollout ranks, plus Megatron-side allgather of parameter shards). In practice this became the dominant bottleneck, so we made three changes to reduce weight-sync overhead.
```
Driver (TaskRunner)
├─ Initialize Ray, tokenizer, datasets, worker groups
├─ Build ResourcePoolManager (actor vs rollout GPU layouts)
└─ Trainer.fit()
   ├─ init_workers(): build actor + rollout groups, broadcast weight metadata, create nccl collective group
   ├─ continuous_iterator(): epochs → batches
   ├─ scheduler (see Section 6)
   │    • _async_gen_next_batch(): optional weight sync + non-blocking rollout
   │    • _async_get_teacher_knowledge(): submit teacher requests, store future
   └─ For each step:
        • Sync rollout weights
        • Retrieve (batch, gen_output, teacher_output) from futures
        • Merge gen + teacher outputs → DataProto
        • Compute metrics (response length stats, timing, throughput)
        • Update actor (forward_backward_batch + KL loss + optimizer step)
        • (Optional) save checkpoint
```
Note: Schedulers are optional and explained later; the distillation objective is independent of how phases are overlapped.
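To make the overlap concrete, here is a toy, self-contained sketch of the general pattern behind these schedulers, using plain Python threads in place of Ray futures. The real `one_step_off` / `two_step_off` logic, weight sync, and `DataProto` merging are more involved, and names like `gen_then_teacher` are illustrative only.

```python
from concurrent.futures import ThreadPoolExecutor
import time

def generate(batch):            # stand-in for the non-blocking rollout
    time.sleep(0.3); return f"gen({batch})"

def query_teacher(gen_output):  # stand-in for submitting teacher top-k requests
    time.sleep(0.5); return f"topk({gen_output})"

def gen_then_teacher(batch):
    """Rollout followed by teacher top-k acquisition for one batch."""
    gen_output = generate(batch)
    return gen_output, query_teacher(gen_output)

def update_actor(gen_output, teacher_output):   # stand-in for the sparse-KL update
    time.sleep(0.2)

batches = [f"batch{i}" for i in range(4)]

with ThreadPoolExecutor(max_workers=1) as pool:
    future = pool.submit(gen_then_teacher, batches[0])        # prime the pipeline
    for step in range(len(batches)):
        gen_output, teacher_output = future.result()          # wait_prev_gen / wait_prev_teacher
        # (the real trainer syncs rollout weights here before launching the next generation)
        if step + 1 < len(batches):
            future = pool.submit(gen_then_teacher, batches[step + 1])  # overlaps the update below
        update_actor(gen_output, teacher_output)
```

In the actual trainer, the scheduler hands back futures for `(batch, gen_output, teacher_output)` that the driver resolves each step, so generation and teacher querying for later batches proceed while the current actor update runs.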
Key components:

- Trainer (`ray_trainer.py`): `OnPolicyDistillTrainer` builds `GenerationBatchFuture` objects holding rollout and (later) teacher futures.
- Actor: `OnPolicyDistillActor.update_policy()` orchestrates micro-batch forward/backward; the KL loss is applied via a `logits_processor` during the forward pass on the pipeline's last stage.
- Rollout worker: `init_model` builds the model (no optimizer); `async_generate_sequences` returns a Ray future for overlapping.
- Teacher client (`teacher/`): `TeacherClient.submit()` returns a Future; an aggregator composes micro-batches.
- Distillation loss: `megatron_kl_loss.py`.
- Config: `on_policy_distill_trainer.yaml`.

| Section | Purpose | Notable Keys |
|---|---|---|
| actor_rollout_ref.teacher | Teacher server | server_ip, server_port, n_server_workers |
| trainer | Global training control | total_epochs, save_freq, scheduler (one_step_off / two_step_off) |
| rollout | Resource split for rollout | n_gpus_per_node, nnodes |
Remember to set trainer.n_gpus_per_node, trainer.nnodes, rollout.n_gpus_per_node and rollout.nnodes to allocate GPU resources.
Enable by:
```bash
actor_rollout_ref.actor.use_dynamic_bsz=True
actor_rollout_ref.actor.max_token_len=6000  # cap post-group token length
```
Improves utilization under variable sequence lengths.
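As an illustration of what the token-length cap does under variable sequence lengths, here is a generic greedy packing sketch. It is not verl's actual dynamic-batching code, and `group_by_token_budget` is a hypothetical helper.

```python
def group_by_token_budget(seq_lens, max_token_len=6000):
    """Greedy illustration of dynamic micro-batching: pack sequences until the
    token budget is reached, then start a new group."""
    groups, current, used = [], [], 0
    for i, n in enumerate(seq_lens):
        if current and used + n > max_token_len:
            groups.append(current)
            current, used = [], 0
        current.append(i)
        used += n
    if current:
        groups.append(current)
    return groups

print(group_by_token_budget([4000, 2500, 1500, 5800, 300]))
# -> [[0], [1, 2], [3], [4]]
```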
GPU resources are split between training and rollout:

- The actor (training) uses `trainer.nnodes * trainer.n_gpus_per_node` GPUs.
- Rollout uses `rollout.nnodes * rollout.n_gpus_per_node` GPUs.
- Size `n_server_workers` to avoid stalls (monitor `wait_prev_teacher`).

Before training, you must have a teacher server running to provide log-probability (logp) information.
We provide a toy teacher server example backed by vLLM. The launch script uses `telnet` to check the proxy status and the `python` command to run; if `telnet` is not installed, remove that check from `start_server.sh`. Some systems only provide `python3` rather than `python`, so adjust the command accordingly. You can also change the teacher's port if you hit a port conflict.
Three arguments can be set for the vLLM backend in `start_server.sh` / `worker.py`: `--tp-size`, `--n-logprobs`, and `--ckpt-path`. Set them before starting the server.
We also provide a toy multi-node teacher server. Start the main node with `start_server.sh` and the worker nodes with `join_server.sh`. Remember to set the arguments in `join_server.sh`, especially the `$PROXY_IP` and `$PROXY_BACKEND_PORT` of the main node.
During training, the student automatically adopts the teacher's top-k (`--n-logprobs`) as its own top-k argument at line 83 of `recipe/gkd/megatron_kl_loss.py`, so you do not need to set the student's top-k separately.
```bash
cd recipe/gkd/teacher
bash start_server.sh
# Exports ports and launches proxy + worker (default vLLM backend)
```
Verify with:
```bash
telnet localhost 15555
```
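If `telnet` is not installed, an equivalent reachability check can be done with a few lines of Python (assuming the default host and port used above):

```python
import socket

HOST, PORT = "127.0.0.1", 15555  # teacher proxy address and port used above

try:
    # Same TCP reachability check that start_server.sh performs with telnet.
    with socket.create_connection((HOST, PORT), timeout=5):
        print(f"teacher server is reachable at {HOST}:{PORT}")
except OSError as exc:
    print(f"teacher server not reachable: {exc}")
```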
```bash
python3 -m recipe.gkd.main_gkd \
    --config-path=recipe/gkd/config \
    --config-name=on_policy_distill_trainer \
    actor_rollout_ref.model.path=/path/to/MODEL \
    data.train_files=/path/to/train.parquet \
    trainer.total_epochs=2 \
    trainer.n_gpus_per_node=4 rollout.n_gpus_per_node=2 \
    actor_rollout_ref.teacher.server_ip=127.0.0.1 \
    actor_rollout_ref.teacher.server_port=15555 \
    trainer.scheduler=one_step_off
```
(Requires a running teacher server).
See `run_moonlight_dsv3_training.sh` for a full script, including loading from a Megatron distributed checkpoint (`dist_checkpointing_path`).

Submit (after adjusting paths):
```bash
bash recipe/gkd/run_moonlight_dsv3_training.sh
```
Emitted metrics include (prefixes may vary):
- Timing: `timing/wait_prev_gen`, `timing/sync_rollout_weights`, `timing/get_teacher_knowledge`, `timing/update_actor`.
- Response lengths: `response_seq_len/*` (avg, max, min, counts).
- Performance: `perf/mfu/actor`, `perf/max_memory_allocated_gb`, `perf/cpu_memory_used_gb`.
- Training: `actor/kl_loss`, `actor/grad_norm`, `actor/lr`.

Interpretation tips:
- High `wait_prev_teacher` → scale `n_server_workers` and allocate more teacher GPUs, reduce the per-request batch size, or just use `two_step_off`.
- High `wait_prev_gen` with uniform lengths → allocate more rollout GPUs.
- High `sync_rollout_weights` → check the NCCL environment / network congestion and try adjusting `actor_rollout_ref.rollout.update_weights_bucket_megabytes`.

Each scheduler step yields `(epoch, batch, gen_output, teacher_output, timing_dict)`. New distillation signals can be added via `teacher_utils.get_teacher_knowledge` and by modifying the `logits_processor`.

| Category | Supported |
|---|---|
| Train engine | Megatron |
| Rollout engine | vLLM |
| Distillation signal | Teacher top-k logprobs & indices |
| Scheduling | one_step_off, two_step_off |
Troubleshooting checklist:

- Make sure the teacher server is reachable (`telnet <ip> <port>`).
- Make sure `actor_rollout_ref.model.path` contains the correct Megatron/HF config artifacts.
- Make sure `train_files` points to a parquet dataset compatible with this recipe's dataset loader.
- Check the Ray runtime environment settings (`config/runtime_env.yaml`).

Feel free to open issues or PRs to extend scheduler variants, add new distillation objectives, broaden engine support, or contribute other improvements.