Generalized Knowledge Distillation (GKD) was proposed in *On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes* by Rishabh Agarwal, Nino Vieillard, Yongchao Zhou, Piotr Stanczyk, Sabela Ramos, Matthieu Geist, and Olivier Bachem.
The abstract from the paper is the following:
> Knowledge distillation (KD) is widely used for compressing a teacher model to reduce its inference cost and memory footprint, by training a smaller student model. However, current KD methods for auto-regressive sequence models suffer from distribution mismatch between output sequences seen during training and those generated by the student during inference. To address this issue, we introduce Generalized Knowledge Distillation (GKD). Instead of solely relying on a fixed set of output sequences, GKD trains the student on its self-generated output sequences by leveraging feedback from the teacher on such sequences. Unlike supervised KD approaches, GKD also offers the flexibility to employ alternative loss functions between the student and teacher, which can be useful when the student lacks the expressivity to mimic the teacher's distribution. Furthermore, GKD facilitates the seamless integration of distillation with RL fine-tuning (RLHF). We demonstrate the efficacy of GKD for distilling auto-regressive language models on summarization, translation, and arithmetic reasoning tasks, and task-agnostic distillation for instruction-tuning.
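The key aspects of GKD are:

1. It addresses the train-inference distribution mismatch in auto-regressive sequence models by training the student on its own self-generated output sequences, using token-level feedback from the teacher on those sequences.
2. It allows flexibility in choosing the divergence between the student and teacher distributions (via the generalized Jensen-Shannon divergence), which is useful when the student lacks the capacity to fully mimic the teacher.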
This post-training method was contributed by Kashif Rasul and Lewis Tunstall.
The [`experimental.gkd.GKDTrainer`] is a wrapper around the [`SFTTrainer`] class that additionally takes a `teacher_model` argument. It relies on three parameters, set via the [`experimental.gkd.GKDConfig`]:
- `lmbda`: controls the student data fraction, i.e., the proportion of on-policy, student-generated outputs. When `lmbda=0.0`, the loss reduces to supervised JSD, where the student is trained with the token-level probabilities of the teacher. When `lmbda=1.0`, the loss reduces to on-policy JSD, where the student generates output sequences and receives token-specific feedback on these sequences from the teacher. For values in between 0 and 1, each batch is randomly chosen to be on-policy with probability `lmbda`.
- `seq_kd`: controls whether to perform Sequence-Level KD (which can be viewed as supervised fine-tuning on teacher-generated outputs). When `seq_kd=True` and `lmbda=0.0`, the loss reduces to supervised JSD, where the teacher generates output sequences and the student receives token-specific feedback on these sequences from the teacher.
- `beta`: controls the interpolation in the generalized Jensen-Shannon divergence. When `beta=0.0`, the loss approximates forward KL divergence, while for `beta=1.0` it approximates reverse KL divergence. For values in between 0 and 1, it interpolates between the two.

The authors find that on-policy data (high `lmbda`) performs better and that the optimal `beta` varies depending on the task and the evaluation method.
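To make the roles of these parameters more concrete, here is a minimal, framework-free sketch of two decisions involved in a GKD step: where a batch's completions come from (`lmbda`, `seq_kd`), and how the per-token loss interpolates between forward and reverse KL (`beta`). It assumes the usual generalized JSD definition, with mixture `M = beta * teacher + (1 - beta) * student` and loss `beta * KL(teacher || M) + (1 - beta) * KL(student || M)`, and treats the endpoints as the pure KL divergences described above. All helper names (`pick_completion_source`, `generalized_jsd`, and so on) are hypothetical. The sketch works on a single token's logits in plain Python, whereas the actual trainer operates on batched tensors and may differ in details such as temperature scaling and loss reduction.

```python
import math
import random


def pick_completion_source(lmbda, seq_kd, rng=random):
    """Hypothetical helper: decide where one batch's completions come from.

    With probability `lmbda` the student samples its own completions (on-policy).
    Otherwise the teacher generates them if `seq_kd` is True, else the
    completions already in the dataset are used (supervised).
    """
    if rng.random() < lmbda:
        return "student"
    return "teacher" if seq_kd else "dataset"


def log_softmax(logits):
    # Numerically stable log-softmax over one token's vocabulary
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def kl(p_log, q_log):
    # KL(P || Q) with both distributions given as log-probabilities
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(p_log, q_log))


def log_mixture(t_log, s_log, beta):
    # log(beta * teacher + (1 - beta) * student), computed in log space
    x, y = t_log + math.log(beta), s_log + math.log(1.0 - beta)
    m = max(x, y)
    return m + math.log(math.exp(x - m) + math.exp(y - m))


def generalized_jsd(student_logits, teacher_logits, beta=0.5):
    """Generalized JSD between student and teacher at a single token position."""
    s = log_softmax(student_logits)
    t = log_softmax(teacher_logits)
    if beta == 0.0:
        return kl(t, s)  # forward KL: KL(teacher || student)
    if beta == 1.0:
        return kl(s, t)  # reverse KL: KL(student || teacher)
    mix = [log_mixture(ti, si, beta) for si, ti in zip(s, t)]
    return beta * kl(t, mix) + (1.0 - beta) * kl(s, mix)


# Usage: with lmbda=0.5, roughly half of the batches are on-policy.
rng = random.Random(0)
sources = [pick_completion_source(0.5, seq_kd=False, rng=rng) for _ in range(1000)]
print(sources.count("student") / len(sources))

student, teacher = [2.0, 0.5, -1.0], [1.0, 1.0, 0.0]
for beta in (0.0, 0.5, 1.0):
    print(beta, generalized_jsd(student, teacher, beta))
```

Special-casing `beta=0.0` and `beta=1.0` mirrors the description above, where the endpoints correspond to forward and reverse KL. For intermediate values, the mixture is computed in log space so that very small probabilities do not underflow.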
> [!WARNING]
> Make sure to set `attn_implementation="kernels-community/flash-attn2"` when training Gemma models. Otherwise, you will encounter NaNs in the logits due to the soft capping technique adopted by this architecture.
The basic API is as follows:
```python
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl.experimental.gkd import GKDConfig, GKDTrainer

NUM_DUMMY_SAMPLES = 100

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
# The model to optimise
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
# The teacher model to compute the divergence against
teacher_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-1.5B-Instruct")

train_dataset = Dataset.from_dict(
    {
        "messages": [
            [
                {"role": "user", "content": "Hi, how are you?"},
                {"role": "assistant", "content": "I'm great thanks"},
            ]
        ]
        * NUM_DUMMY_SAMPLES
    }
)
eval_dataset = Dataset.from_dict(
    {
        "messages": [
            [
                {"role": "user", "content": "What colour is the sky?"},
                {"role": "assistant", "content": "The sky is blue"},
            ]
        ]
        * NUM_DUMMY_SAMPLES
    }
)

training_args = GKDConfig(output_dir="gkd-model", per_device_train_batch_size=1)
trainer = GKDTrainer(
    model=model,
    teacher_model=teacher_model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()
```
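The dummy datasets above store each conversation as a list of turns under a single `"messages"` column. Before handing your own data to the trainer, it can help to sanity-check that shape. The helper below, `is_valid_conversation`, is hypothetical (not part of TRL) and only checks the `role` and `content` keys documented on this page, not tokenization or chat-template details.

```python
VALID_ROLES = {"system", "user", "assistant"}


def is_valid_conversation(example):
    """Return True if `example` matches the conversational "messages" format.

    Hypothetical helper, not part of TRL: it only checks that every turn is a
    dict with a known `role` and a string `content`.
    """
    messages = example.get("messages")
    if not isinstance(messages, list) or not messages:
        return False
    for message in messages:
        if not isinstance(message, dict):
            return False
        if message.get("role") not in VALID_ROLES:
            return False
        if not isinstance(message.get("content"), str):
            return False
    return True


good = {
    "messages": [
        {"role": "user", "content": "Hi, how are you?"},
        {"role": "assistant", "content": "I'm great thanks"},
    ]
}
bad = {"messages": [{"role": "bot", "content": "Hello"}]}
print(is_valid_conversation(good))  # True
print(is_valid_conversation(bad))  # False
```

With the `datasets` library, the same function can be passed to `Dataset.filter` (for example `train_dataset.filter(is_valid_conversation)`) to drop malformed rows.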
Each example in the dataset should have a `"messages"` field containing a list of dictionaries (one per conversation turn) with the following keys:

- `role`: either `system`, `assistant` or `user`
- `content`: the message content

[[autodoc]] experimental.gkd.GKDTrainer
    - train
    - save_model
    - push_to_hub
[[autodoc]] experimental.gkd.GKDConfig