Distillation Trainer

Overview

The Distillation Trainer implements on-policy knowledge distillation as described in *On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes* by Rishabh Agarwal, Nino Vieillard, Yongchao Zhou, Piotr Stanczyk, Sabela Ramos, Matthieu Geist, and Olivier Bachem.

The abstract from the paper is the following:

> Knowledge distillation (KD) is widely used for compressing a teacher model to reduce its inference cost and memory footprint, by training a smaller student model. However, current KD methods for auto-regressive sequence models suffer from distribution mismatch between output sequences seen during training and those generated by the student during inference. To address this issue, we introduce Generalized Knowledge Distillation (GKD). Instead of solely relying on a fixed set of output sequences, GKD trains the student on its self-generated output sequences by leveraging feedback from the teacher on such sequences. Unlike supervised KD approaches, GKD also offers the flexibility to employ alternative loss functions between the student and teacher, which can be useful when the student lacks the expressivity to mimic the teacher's distribution.

The DistillationTrainer is designed for distilling teacher models of all sizes into smaller students efficiently. It extends the ideas from the GKDTrainer with three key optimizations:

  1. Generation buffer – decouples the training microbatch size from the generation batch size, letting vLLM batch many prompts in a single call across gradient accumulation steps. This alone can speed up training by up to 40x.
  2. Teacher server support – moves the teacher to an external vLLM server so it does not need to fit on the same GPUs as the student.
  3. Binary-encoded logprob payloads – packs log-probabilities into base64-encoded NumPy arrays instead of nested JSON lists, shrinking transfer payloads by ~5x (illustrated in the sketch below).
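
To make the third optimization concrete, here is a minimal sketch of the idea. It is illustrative only and not the trainer's actual wire format; the array shape and values are made up. The point is simply that raw float32 bytes wrapped in base64 are far more compact than nested JSON lists of decimal strings.

```python
import base64
import json

import numpy as np

# Hypothetical payload: top-20 logprobs for 512 generated tokens (shape is made up).
logprobs = np.random.default_rng(0).standard_normal((512, 20)).astype(np.float32)

# Nested JSON lists: every float is serialized as a ~20-character decimal string.
json_payload = json.dumps(logprobs.tolist())

# Binary encoding: 4 bytes per float32, plus ~33% base64 overhead.
b64_payload = base64.b64encode(logprobs.tobytes()).decode("ascii")

print(len(json_payload), len(b64_payload))  # the base64 payload is several times smaller

# Receiver side: recover the exact array from the base64 string.
decoded = np.frombuffer(base64.b64decode(b64_payload), dtype=np.float32).reshape(512, 20)
assert np.array_equal(decoded, logprobs)
```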

> [!NOTE]
> The Distillation Trainer is currently part of the `trl.experimental` namespace. APIs may change without notice while the feature is iterated on.

Quick start

```python
from datasets import load_dataset
from trl.experimental.distillation import DistillationConfig, DistillationTrainer

# 1. Load dataset and format as prompt-only chat messages
dataset = load_dataset("openai/gsm8k", "main", split="train")
dataset = dataset.map(
    lambda x: {"messages": [{"role": "user", "content": x["question"]}]},
    remove_columns=dataset.column_names,
)

# 2. Configure distillation
config = DistillationConfig(
    output_dir="results/distill-qwen-gsm8k",
    num_train_epochs=1,
    bf16=True,
    save_strategy="no",
    # Distillation
    lmbda=1.0,                      # fully on-policy (student generates)
    beta=1.0,                       # reverse KL
    # Teacher
    teacher_model_init_kwargs={"torch_dtype": "bfloat16"},
)

# 3. Train
trainer = DistillationTrainer(
    model="Qwen/Qwen2.5-1.5B-Instruct",
    teacher_model="Qwen/Qwen2.5-7B-Instruct",
    args=config,
    train_dataset=dataset,
)
trainer.train()
trainer.save_model()
```

Usage tips

The [experimental.distillation.DistillationTrainer] needs three key parameters set via [experimental.distillation.DistillationConfig]:

  • lmbda: controls the student data fraction, i.e., the proportion of on-policy student-generated outputs. When lmbda=0.0, training is fully off-policy (dataset completions only). When lmbda=1.0, training is fully on-policy (student generates all completions). For values in between, each gradient accumulation slice is randomly assigned as on- or off-policy based on lmbda.
  • beta: controls the interpolation in the Generalized Jensen-Shannon Divergence (defined after this list). When beta=0.0 the loss approximates forward KL divergence, while beta=1.0 approximates reverse KL divergence. Values in between interpolate.
  • loss_top_k: number of top tokens to use for the KL/JSD loss. Set to 0 for exact full-vocabulary computation (local teacher only), or to a value > 0 for a top-k approximation. See the external teacher server section below for the constraints top-k imposes there.
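
For reference, the divergence that beta interpolates is the generalized Jensen-Shannon divergence from the GKD paper:

$$
D_{\mathrm{JSD}(\beta)}(P \,\|\, Q) = \beta\, D_{\mathrm{KL}}\big(P \,\|\, \beta P + (1-\beta)\, Q\big) + (1-\beta)\, D_{\mathrm{KL}}\big(Q \,\|\, \beta P + (1-\beta)\, Q\big)
$$

where P is the teacher distribution and Q is the student distribution. As beta approaches 0 the (suitably rescaled) divergence reduces to the forward KL D_KL(P ∥ Q), and as beta approaches 1 it reduces to the reverse KL D_KL(Q ∥ P).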

On-policy vs. off-policy

Setting lmbda=1.0 (fully on-policy) generally outperforms off-policy distillation because the student learns from its own mistakes rather than imitating trajectories it may never produce. The generation buffer ensures on-policy training stays efficient: prompts across gradient accumulation steps are batched into a single vLLM call.
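
As a rough illustration of the batching effect (the numbers are hypothetical and the exact buffering behavior is an internal detail of the trainer), the number of prompts handed to vLLM per generation call scales with the full optimizer step rather than with a single microbatch:

```python
# Hypothetical values for illustration only.
per_device_train_batch_size = 4   # microbatch size used for the loss/backward pass
gradient_accumulation_steps = 8   # microbatches accumulated per optimizer step

# Without a generation buffer, each microbatch would trigger its own generation call of
# 4 prompts. With the buffer, the prompts for a whole optimizer step go out in one call:
prompts_per_generation_call = per_device_train_batch_size * gradient_accumulation_steps
print(prompts_per_generation_call)  # 32
```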

Using an external teacher server

For teachers that do not fit on training GPUs (e.g., 100B+ parameters), host the teacher on a separate vLLM server and set use_teacher_server=True with teacher_model_server_url:

```python
config = DistillationConfig(
    output_dir="distilled-model",
    use_teacher_server=True,
    teacher_model_server_url="http://teacher-host:8000",
    loss_top_k=1,       # required with teacher server when beta > 0
    beta=1.0,
    lmbda=1.0,
)

trainer = DistillationTrainer(
    model="Qwen/Qwen3-4B",
    args=config,
    train_dataset=dataset,
)
trainer.train()
```

When using the teacher server, the following constraints apply (an example configuration follows the list):

  • loss_top_k must be > 0 when beta=0.0 (forward KL)
  • loss_top_k must be exactly 1 when beta > 0 (reverse KL or JSD)
  • reverse_kl_top_1_mode="argmax" is not supported
  • Liger kernel is not supported
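
For instance, a forward-KL setup (beta=0.0) against a teacher server could look like the sketch below. The values are illustrative; only the constraints listed above are required.

```python
from trl.experimental.distillation import DistillationConfig

config = DistillationConfig(
    output_dir="distilled-model",
    use_teacher_server=True,
    teacher_model_server_url="http://teacher-host:8000",
    beta=0.0,        # forward KL
    loss_top_k=20,   # any value > 0 is allowed when beta=0.0
    lmbda=1.0,
)
```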

Expected dataset type

The dataset should be formatted as a conversational language modeling dataset:

```python
{"messages": [{"role": "user", "content": "What color is the sky?"},
              {"role": "assistant", "content": "It is blue."}]}

When using fully on-policy distillation (lmbda=1.0), the assistant turn can be omitted since the student will generate its own completions:

```python
{"messages": [{"role": "user", "content": "What color is the sky?"}]}

Example script

Use trl/experimental/distillation/distillation.py to launch distillation training from the command line. The script supports full training, mixed on/off-policy, and LoRA via the standard ModelConfig flags.

```bash
# Full training (off-policy only, lmbda=0):
python trl/experimental/distillation/distillation.py \
    --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
    --teacher_model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \
    --dataset_name trl-lib/chatbot_arena_completions \
    --learning_rate 2e-5 \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 8 \
    --lmbda 0.0 \
    --output_dir distilled-model \
    --num_train_epochs 1
```

```bash
# Mixed on/off-policy (lmbda=0.5):
python trl/experimental/distillation/distillation.py \
    --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
    --teacher_model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \
    --dataset_name trl-lib/chatbot_arena_completions \
    --learning_rate 2e-5 \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 8 \
    --lmbda 0.5 \
    --beta 0.5 \
    --output_dir distilled-model \
    --num_train_epochs 1
```

DistillationTrainer

[[autodoc]] experimental.distillation.DistillationTrainer
    - train
    - save_model
    - push_to_hub

DistillationConfig

[[autodoc]] experimental.distillation.DistillationConfig