General Online Logit Distillation (GOLD) is an extension of Universal Logit Distillation (ULD) that supports student/teacher pairs with different tokenizers. It aligns the textual spans produced by both tokenizers and merges the associated logits so no completion tokens are dropped. This enables cross-tokenizer knowledge distillation, including mixed model families (for example, LLaMA students with Qwen teachers).
Key capabilities:

- When `uld_use_hybrid_loss` is enabled, GOLD compares exact vocabulary matches directly and falls back to the original sorted-probability ULD loss for unmatched tokens. This improves stability for students whose vocabularies only partially overlap with the teacher.
- GOLD inherits the generalized JSD and on-policy sampling settings of [`experimental.gkd.GKDTrainer`], so you can combine sequence-level KD, generalized JSD, and cross-tokenizer distillation in a single training run.

> [!NOTE]
> GOLD is currently part of the `trl.experimental` namespace. APIs may change without notice while the feature is iterated on.

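
To make the hybrid loss more concrete, here is a minimal pure-Python sketch. It assumes both distributions are dictionaries mapping token strings to probabilities, and it compares matched tokens with a plain absolute difference; that matched-token term is an assumption, and GOLD's actual implementation may use a different divergence. Unmatched tokens use the ULD sorted-probability L1 distance. The helpers `sorted_l1` and `hybrid_uld_loss` are hypothetical; their weight arguments only mirror `uld_hybrid_matched_weight` and `uld_hybrid_unmatched_weight`.

```python
def sorted_l1(p, q):
    """ULD-style distance: sort both probability lists in descending order,
    zero-pad the shorter one, and sum the absolute differences."""
    p_sorted = sorted(p, reverse=True)
    q_sorted = sorted(q, reverse=True)
    size = max(len(p_sorted), len(q_sorted))
    p_sorted += [0.0] * (size - len(p_sorted))
    q_sorted += [0.0] * (size - len(q_sorted))
    return sum(abs(a - b) for a, b in zip(p_sorted, q_sorted))


def hybrid_uld_loss(student, teacher, matched_weight=1.0, unmatched_weight=1.0):
    """Hypothetical hybrid loss over {token string: probability} dictionaries."""
    matched = student.keys() & teacher.keys()
    # Tokens present in both vocabularies are compared directly (assumed L1 term).
    matched_loss = sum(abs(student[t] - teacher[t]) for t in matched)
    # All remaining tokens fall back to the sorted-probability ULD comparison.
    unmatched_loss = sorted_l1(
        [p for t, p in student.items() if t not in matched],
        [p for t, p in teacher.items() if t not in matched],
    )
    return matched_weight * matched_loss + unmatched_weight * unmatched_loss


student = {"the": 0.5, "a": 0.3, "ok": 0.2}
teacher = {"the": 0.4, "a": 0.4, "fine": 0.2}
loss = hybrid_uld_loss(student, teacher, matched_weight=1.0, unmatched_weight=1.0)
```

In this toy case the matched tokens contribute |0.5 - 0.4| + |0.3 - 0.4| = 0.2, while the unmatched tokens "ok" and "fine" carry the same probability and contribute nothing after sorting, so the loss is approximately 0.2.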
The [`GOLDTrainer`] subclasses [`SFTTrainer`] and accepts the same datasets as other TRL trainers (lists of ChatML-style messages). Important configuration flags on [`GOLDConfig`] include:

- `use_uld_loss` – toggles Universal Logit Distillation. Set this to `True` for cross-tokenizer setups.
- `teacher_tokenizer_name_or_path` – required when `use_uld_loss=True`; GOLD uses the teacher tokenizer to align tokens.
- `uld_use_hybrid_loss`, `uld_hybrid_matched_weight`, `uld_hybrid_unmatched_weight` – enable and weight the hybrid matched/unmatched loss.
- `beta`, `lmbda`, `seq_kd` – inherited from [`experimental.gkd.GKDConfig`]; they control the generalized JSD interpolation (`beta`), the on-policy sampling ratio (`lmbda`), and sequence-level KD (`seq_kd`).
- `num_generations`, `generation_batch_size` – control buffered rollout generation across gradient accumulation windows. `generation_batch_size` is the number of unique prompts per worker per optimizer step.
- `model_revision` – controls which student model revision GOLD loads for training and generation.

A minimal end-to-end example:
```python
from datasets import load_dataset
from trl.experimental.gold import GOLDConfig, GOLDTrainer

train_dataset = load_dataset(
    "HuggingFaceTB/OpenR1-Math-220k-default-verified",
    "all",
    split="train[:1024]",
)

trainer = GOLDTrainer(
    model="meta-llama/Llama-3.2-1B-Instruct",
    teacher_model="Qwen/Qwen2.5-0.5B-Instruct",
    args=GOLDConfig(
        output_dir="gold-model",
        use_uld_loss=True,
        teacher_tokenizer_name_or_path="Qwen/Qwen2.5-0.5B-Instruct",
    ),
    train_dataset=train_dataset,
)
trainer.train()
```
For quick-start workflows you can pass string identifiers, as shown above, and the trainer will load the models and tokenizer for you. Explicitly instantiating `AutoModelForCausalLM` or `AutoTokenizer`, or setting additional `GOLDConfig` fields, is recommended only for advanced use cases where you need fine-grained control over initialization.
A more explicit setup might look like this when you need to customise model loading, tokenizer settings, or training arguments:
```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl.experimental.gold import GOLDConfig, GOLDTrainer

student_name = "meta-llama/Llama-3.2-1B-Instruct"
teacher_name = "Qwen/Qwen2.5-0.5B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(student_name)
if tokenizer.pad_token is None:
    # Fall back to EOS for padding when the tokenizer defines no pad token.
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(student_name)
teacher_model = AutoModelForCausalLM.from_pretrained(teacher_name)

train_dataset = load_dataset(
    "HuggingFaceTB/Countdown-Task-GOLD",
    "verified_Qwen2.5-0.5B-Instruct",
    split="train",
)

training_args = GOLDConfig(
    output_dir="gold-model",
    per_device_train_batch_size=1,
    teacher_model_name_or_path=teacher_name,
    # The teacher tokenizer is needed to align the two tokenizations.
    teacher_tokenizer_name_or_path=teacher_name,
    use_uld_loss=True,
    uld_use_hybrid_loss=True,
)

trainer = GOLDTrainer(
    model=model,
    teacher_model=teacher_model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=train_dataset,
)
trainer.train()
```
> [!NOTE]
> GOLD buffers one full optimizer-window generation batch (`per_device_train_batch_size * gradient_accumulation_steps`) and reuses it across accumulation steps. If the final batch is undersized, GOLD warns and drops that last batch (`Dropping last batch due to unexpected batch size`). Set `dataloader_drop_last=True` to avoid this warning.

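
The buffering described in the note can be illustrated in plain Python. This sketch simplifies heavily: it assumes a single worker and a flat list of prompts, and it interprets "reuses it across accumulation steps" as each accumulation step consuming one slice of the buffered batch. `optimizer_windows` and `micro_batches` are hypothetical helpers, not GOLD internals.

```python
import warnings


def optimizer_windows(prompts, per_device_train_batch_size, gradient_accumulation_steps):
    """Group prompts into one generation batch per optimizer step (hypothetical helper)."""
    window = per_device_train_batch_size * gradient_accumulation_steps
    batches = [prompts[i : i + window] for i in range(0, len(prompts), window)]
    if batches and len(batches[-1]) < window:
        # An undersized final batch is dropped with a warning, as the note describes.
        warnings.warn("Dropping last batch due to unexpected batch size")
        batches.pop()
    return batches


def micro_batches(window_batch, per_device_train_batch_size):
    """Split one buffered generation batch into per-accumulation-step slices."""
    return [
        window_batch[i : i + per_device_train_batch_size]
        for i in range(0, len(window_batch), per_device_train_batch_size)
    ]


windows = optimizer_windows(
    list(range(10)), per_device_train_batch_size=2, gradient_accumulation_steps=4
)
for window_batch in windows:
    # Each accumulation step in this optimizer window consumes one slice of the buffer.
    for step, batch in enumerate(micro_batches(window_batch, 2)):
        print(step, batch)
```

With 10 prompts and a window of 2 × 4 = 8, the first 8 prompts form one buffered batch split into four slices of 2, and the remaining 2 prompts are dropped with the warning.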
GOLD requires a conversational language modeling dataset, e.g.:

```json
{"messages": [{"role": "user", "content": "What color is the sky?"},
              {"role": "assistant", "content": "It is blue."}]}
```
GOLDTrainer keeps the raw messages so the ChatML collator can construct prompts and completions with the correct
boundaries.
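
As a rough illustration of why the raw messages matter, the sketch below splits a conversation at its final assistant turn, treating the earlier messages as the prompt and that turn as the completion. The helper `split_prompt_completion` is hypothetical and assumes the conversation ends with the assistant reply; it works on raw message dictionaries only and does no chat templating or tokenization.

```python
def split_prompt_completion(example):
    """Split a conversational example into prompt messages and a completion string."""
    messages = example["messages"]
    if not messages or messages[-1]["role"] != "assistant":
        raise ValueError("expected the conversation to end with an assistant message")
    # Every message before the final assistant turn forms the prompt.
    prompt = messages[:-1]
    completion = messages[-1]["content"]
    return prompt, completion


example = {
    "messages": [
        {"role": "user", "content": "What color is the sky?"},
        {"role": "assistant", "content": "It is blue."},
    ]
}
prompt, completion = split_prompt_completion(example)
```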
When student and teacher use different tokenizers, the same text may be split differently:

- Student tokenizer: "Hugging Face" → 1 token
- Teacher tokenizer: "Hugging", " Face" → 2 tokens

GOLD aligns these sequences and merges the teacher's multi-token probabilities into a single distribution that can be compared with the student's single-token distribution.
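
A simplified way to find aligned spans is to walk both token lists in parallel and close a group whenever the text covered on both sides matches. The sketch assumes tokens are already decoded to strings whose concatenation reproduces the same text on both sides. `align_spans` is a hypothetical helper, and GOLD's own alignment, which works on real tokenizer output, may handle more edge cases.

```python
def align_spans(student_tokens, teacher_tokens):
    """Pair up student and teacher token groups that decode to identical text."""
    groups = []
    i = j = 0
    student_group, teacher_group = [], []
    student_text = teacher_text = ""
    while i < len(student_tokens) or j < len(teacher_tokens):
        # Extend whichever side currently covers less text (student first on ties).
        take_student = j >= len(teacher_tokens) or (
            i < len(student_tokens) and len(student_text) <= len(teacher_text)
        )
        if take_student:
            student_group.append(student_tokens[i])
            student_text += student_tokens[i]
            i += 1
        else:
            teacher_group.append(teacher_tokens[j])
            teacher_text += teacher_tokens[j]
            j += 1
        # Close the group once both sides spell out the same text.
        if student_group and teacher_group and student_text == teacher_text:
            groups.append((student_group, teacher_group))
            student_group, teacher_group = [], []
            student_text = teacher_text = ""
    if student_group or teacher_group:
        raise ValueError("token sequences do not decode to the same text")
    return groups


print(align_spans(["Hugging Face"], ["Hugging", " Face"]))
# [(['Hugging Face'], ['Hugging', ' Face'])]
```

Each resulting group is a unit over which the teacher's probabilities would be merged so they can be compared with the student's.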
For a teacher sequence of tokens [token₀, token₁, ..., tokenₖ] that maps to a single student token, GOLD computes:

$$
P_{\text{merged}}(y) = P(y \mid \text{context}) \times \prod_{i=1}^{k} P(\text{token}_i \mid \text{token}_0, \ldots, \text{token}_{i-1}, \text{context})
$$

where:

- P(y | context) is the marginal probability distribution over all vocabulary tokens at the first position.
- P(tokenᵢ | ..., context) are scalar conditional probabilities of the actual tokens that were generated.

Key insight: only the conditional probabilities of the actual continuation tokens are extracted as scalars. The full marginal distribution at the first position is then scaled by multiplying these scalar probabilities.
This ensures that the merged entry for the first token the teacher actually generated equals the joint probability of the full teacher token sequence.

Given:

- P(x₀): ["HF": 0.6, "is": 0.3, "cool": 0.1]
- P(x₁ | "HF"): ["HF": 0.05, "is": 0.9, "cool": 0.05]

If tokens 0 and 1 are merged, and the actual sequence was ["HF", "is"]:

- P_merged("HF") = 0.6 × 0.9 = 0.54 ✓ (correct joint probability)
- P_merged("is") = 0.3 × 0.9 = 0.27
- P_merged("cool") = 0.1 × 0.9 = 0.09

The merged distribution is unnormalized (it sums to 0.9, the product of the continuation probabilities), but this is intentional and correct for ULD loss computation, which uses sorting and L1 distance.
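
The merge rule and the worked example can be reproduced with a short sketch. It assumes the teacher's first-position distribution is a dictionary of token probabilities and that the conditional probabilities of the actual continuation tokens are already available as scalars. `merge_teacher_probs` is a hypothetical helper, not part of the TRL API.

```python
import math


def merge_teacher_probs(first_position_dist, continuation_probs):
    """Scale the first-position distribution by the product of continuation probabilities.

    first_position_dist: P(y | context) for every vocabulary token y.
    continuation_probs: scalar P(token_i | token_0..token_{i-1}, context) for i = 1..k.
    """
    scale = math.prod(continuation_probs)  # 1 when the group has a single token
    return {token: p * scale for token, p in first_position_dist.items()}


# Worked example: the actual teacher sequence was ["HF", "is"].
p_x0 = {"HF": 0.6, "is": 0.3, "cool": 0.1}
p_is_given_hf = 0.9
merged = merge_teacher_probs(p_x0, [p_is_given_hf])
print({token: round(p, 2) for token, p in merged.items()})
# {'HF': 0.54, 'is': 0.27, 'cool': 0.09}
print(round(sum(merged.values()), 2))
# 0.9
```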
Use trl/experimental/gold/gold.py to launch GOLD training from the command line. The script supports full training and LoRA via the standard ModelConfig flags.
```bash
python trl/experimental/gold/gold.py \
    --model_name_or_path meta-llama/Llama-3.2-1B-Instruct \
    --teacher_model_name_or_path Qwen/Qwen2-1.5B-Instruct \
    --teacher_tokenizer_name_or_path Qwen/Qwen2-1.5B-Instruct \
    --use_uld_loss True \
    --dataset_name trl-lib/chatbot_arena_completions \
    --learning_rate 2e-5 \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 8 \
    --output_dir gold-model \
    --num_train_epochs 1 \
    --push_to_hub
```
[[autodoc]] experimental.gold.GOLDTrainer
    - train
    - generate_on_policy_outputs
    - save_model
    - push_to_hub
[[autodoc]] experimental.gold.GOLDConfig