# SFT Trainer
TRL supports the Supervised Fine-Tuning (SFT) Trainer for training language models.
This post-training method was contributed by Younes Belkada.
This example demonstrates how to train a language model using the [SFTTrainer] from TRL. We train a Qwen 3 0.6B model on the Capybara dataset, a compact, diverse multi-turn dataset for benchmarking reasoning and generalization.
```python
from datasets import load_dataset
from trl import SFTTrainer

trainer = SFTTrainer(
    model="Qwen/Qwen3-0.6B",
    train_dataset=load_dataset("trl-lib/Capybara", split="train"),
)
trainer.train()
```
SFT supports both language modeling and prompt-completion datasets. The [SFTTrainer] is compatible with both standard and conversational dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
```python
# Standard language modeling
{"text": "The sky is blue."}

# Conversational language modeling
{"messages": [{"role": "user", "content": "What color is the sky?"},
              {"role": "assistant", "content": "It is blue."}]}

# Standard prompt-completion
{"prompt": "The sky is",
 "completion": " blue."}

# Conversational prompt-completion
{"prompt": [{"role": "user", "content": "What color is the sky?"}],
 "completion": [{"role": "assistant", "content": "It is blue."}]}
```
If your dataset is not in one of these formats, you can preprocess it to convert it into the expected format. Here is an example with the FreedomIntelligence/medical-o1-reasoning-SFT dataset:
```python
from datasets import load_dataset

dataset = load_dataset("FreedomIntelligence/medical-o1-reasoning-SFT", "en")

def preprocess_function(example):
    return {
        "prompt": [{"role": "user", "content": example["Question"]}],
        "completion": [
            {"role": "assistant", "content": f"<think>{example['Complex_CoT']}</think>{example['Response']}"}
        ],
    }

dataset = dataset.map(preprocess_function, remove_columns=["Question", "Response", "Complex_CoT"])
print(next(iter(dataset["train"])))
```
```python
{
    "prompt": [
        {
            "content": "Given the symptoms of sudden weakness in the left arm and leg, recent long-distance travel, and the presence of swollen and tender right lower leg, what specific cardiac abnormality is most likely to be found upon further evaluation that could explain these findings?",
            "role": "user",
        }
    ],
    "completion": [
        {
            "content": "<think>Okay, let's see what's going on here. We've got sudden weakness [...] clicks into place!</think>The specific cardiac abnormality most likely to be found in [...] the presence of a PFO facilitating a paradoxical embolism.",
            "role": "assistant",
        }
    ],
}
```
Supervised Fine-Tuning (SFT) is the simplest and most commonly used method to adapt a language model to a target dataset. The model is trained in a fully supervised fashion using pairs of input and output sequences. The goal is to minimize the negative log-likelihood (NLL) of the target sequence, conditioning on the input.
This section breaks down how SFT works in practice, covering the key steps: preprocessing, tokenization, and loss computation.
During training, each example is expected to contain a text field or a (prompt, completion) pair, depending on the dataset format. For more details on the expected formats, see Dataset formats.
The [SFTTrainer] tokenizes each input using the model's tokenizer. If both prompt and completion are provided separately, they are concatenated before tokenization.
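To make the tokenization step concrete, here is a minimal sketch of the idea for a prompt-completion pair (illustrative only, not the trainer's exact implementation):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")

prompt, completion = "The sky is", " blue."

# The pair is joined before tokenization, so token boundaries match what the
# model sees at training time (tokenizing the parts separately can differ).
full_ids = tokenizer(prompt + completion)["input_ids"]

# The prompt length can then be used to locate the completion tokens, e.g.
# to mask the prompt out of the loss (see completion-only loss below).
prompt_len = len(tokenizer(prompt)["input_ids"])
print(prompt_len, len(full_ids))
```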
The loss used in SFT is the token-level cross-entropy loss, defined as:
$$ \mathcal{L}_{\text{SFT}}(\theta) = - \sum_{t=1}^{T} \log p_\theta(y_t \mid y_{<t}), $$
where \( y_t \) is the target token at timestep \( t \), and the model is trained to predict the next token given the previous ones. In practice, padding tokens are masked out during loss computation.
> [!TIP]
> The paper On the Generalization of SFT: A Reinforcement Learning Perspective with Reward Rectification proposes an alternative loss function, called Dynamic Fine-Tuning (DFT), which aims to improve generalization by rectifying the reward signal. This method can be enabled by setting `loss_type="dft"` in the [SFTConfig]. For more details, see Paper Index - Dynamic Fine-Tuning.
During training, the loss is computed using a one-token shift: the model is trained to predict each token in the sequence based on all previous tokens. Specifically, the target labels are the input sequence shifted one position to the left, so that the logits at position \( t \) are compared against the token at position \( t + 1 \).
Padding tokens (if present) are ignored in the loss computation by applying an ignore index (default: -100) to the corresponding positions. This ensures that the loss focuses only on meaningful, non-padding tokens.
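Putting the shift and the ignore index together, the loss computation looks roughly like the following plain-PyTorch sketch (the token IDs and logits are made up for illustration; the trainer operates on real model outputs):

```python
import torch
import torch.nn.functional as F

# One sequence of token IDs; -100 marks positions excluded from the loss
# (padding, or prompt tokens when only the completion is trained on).
labels = torch.tensor([[101, 2023, 3185, 2003, 2307, -100, -100]])

# Stand-in for a model forward pass: (batch, seq_len, vocab_size)
logits = torch.randn(1, 7, 30000)

# One-token shift: the logits at position t predict the label at position t + 1.
shift_logits = logits[:, :-1, :]
shift_labels = labels[:, 1:]

# ignore_index=-100 skips the masked positions entirely.
loss = F.cross_entropy(
    shift_logits.reshape(-1, shift_logits.size(-1)),
    shift_labels.reshape(-1),
    ignore_index=-100,
)
print(loss)
```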
During training and evaluation, we record the following metrics:

- `global_step`: The total number of optimizer steps taken so far.
- `epoch`: The current epoch number, based on dataset iteration.
- `num_tokens`: The total number of tokens processed so far.
- `loss`: The average cross-entropy loss computed over non-masked tokens in the current logging interval.
- `entropy`: The average entropy of the model's predicted token distribution over non-masked tokens.
- `mean_token_accuracy`: The proportion of non-masked tokens for which the model's top-1 prediction matches the ground-truth token.
- `learning_rate`: The current learning rate, which may change dynamically if a scheduler is used.
- `grad_norm`: The L2 norm of the gradients, computed before gradient clipping.

You can directly pass the kwargs of the [~transformers.AutoModelForCausalLM.from_pretrained()] method to the [SFTConfig]. For example, if you want to load a model in a different precision, analogous to
```python
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", dtype=torch.bfloat16)
```

you can do so by passing the `model_init_kwargs={"dtype": torch.bfloat16}` argument to the [SFTConfig]:
```python
import torch

from trl import SFTConfig

training_args = SFTConfig(
    model_init_kwargs={"dtype": torch.bfloat16},
)
```
Note that all keyword arguments of [~transformers.AutoModelForCausalLM.from_pretrained()] are supported.
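Because any `from_pretrained` kwarg is accepted, you can also combine several of them. For example (illustrative; `attn_implementation` is a standard `from_pretrained` argument and requires FlashAttention to be installed):

```python
import torch

from trl import SFTConfig

training_args = SFTConfig(
    model_init_kwargs={
        "dtype": torch.bfloat16,
        "attn_implementation": "flash_attention_2",  # any from_pretrained kwarg works here
    },
)
```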
[SFTTrainer] supports example packing, where multiple examples are packed in the same input sequence to increase training efficiency. To enable packing, simply pass packing=True to the [SFTConfig] constructor.
```python
training_args = SFTConfig(packing=True)
```
For more details on packing, see Packing.
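Conceptually, packing concatenates tokenized examples into a stream and slices it into fixed-length sequences, so little capacity is wasted on padding. The following sketch illustrates the idea only; TRL's actual implementation uses a more careful packing strategy:

```python
# Three tokenized examples of different lengths (IDs are arbitrary).
examples = [[5, 8, 13], [21, 34], [55, 89, 144, 233]]
max_length = 4

# Concatenate everything, then cut into chunks of max_length tokens.
stream = [tok for ex in examples for tok in ex]
packed = [stream[i : i + max_length] for i in range(0, len(stream), max_length)]
print(packed)  # [[5, 8, 13, 21], [34, 55, 89, 144], [233]]
```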
To train on assistant messages only, use a conversational dataset and set assistant_only_loss=True in the [SFTConfig]. This setting ensures that loss is computed only on the assistant responses, ignoring user or system messages.
```python
training_args = SFTConfig(assistant_only_loss=True)
```
> [!WARNING]
> This functionality requires the chat template to include the `{% generation %}` and `{% endgeneration %}` keywords. For known model families (e.g. Qwen3), TRL automatically patches the template when `assistant_only_loss=True`. See Chat Templates for the full list of bundled training templates. For other models, check that your chat template includes these keywords. See HuggingFaceTB/SmolLM3-3B for an example.
To train on completion only, use a prompt-completion dataset. By default, the trainer computes the loss on the completion tokens only, ignoring the prompt tokens. If you want to train on the full sequence, set completion_only_loss=False in the [SFTConfig].
```python
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

# Load a prompt-completion dataset; loss is computed on the completion only by default
dataset = load_dataset("trl-lib/kto-mix-14k", split="train")

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    args=SFTConfig(completion_only_loss=True),  # True by default for prompt-completion datasets
    train_dataset=dataset,
)
trainer.train()
```
> [!TIP]
> Training on completion only is compatible with training on assistant messages only. In this case, use a conversational prompt-completion dataset and set `assistant_only_loss=True` in the [SFTConfig].
We support tight integration with the 🤗 PEFT library, allowing you to conveniently train adapters and share them on the Hub, instead of training the entire model.
```python
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer

dataset = load_dataset("trl-lib/Capybara", split="train")

trainer = SFTTrainer(
    model="Qwen/Qwen3-0.6B",
    train_dataset=dataset,
    peft_config=LoraConfig(),
)
trainer.train()
```
You can also continue training a [~peft.PeftModel]. To do so, load the PeftModel outside the [SFTTrainer] and pass it directly to the trainer, without passing a peft_config argument.
```python
from datasets import load_dataset
from peft import AutoPeftModelForCausalLM
from trl import SFTTrainer

model = AutoPeftModelForCausalLM.from_pretrained("trl-lib/Qwen3-4B-LoRA", is_trainable=True)
dataset = load_dataset("trl-lib/Capybara", split="train")

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
)
trainer.train()
```
> [!TIP]
> When training adapters, you typically use a higher learning rate (≈1e-4) since only new parameters are being learned.
>
> ```python
> SFTConfig(learning_rate=1e-4, ...)
> ```
Liger Kernel is a collection of Triton kernels for LLM training that boosts multi-GPU throughput by 20%, cuts memory use by 60% (enabling up to 4× longer context), and works seamlessly with tools like FlashAttention, PyTorch FSDP, and DeepSpeed. For more information, see Liger Kernel Integration.
RapidFire AI is an open-source experimentation engine that sits on top of TRL and lets you launch multiple SFT configurations at once, even on a single GPU. Instead of trying configurations sequentially, RapidFire lets you see all their learning curves earlier, stop underperforming runs, and clone promising ones with new settings in flight without restarting. For more information, see RapidFire AI Integration.
Unsloth is an open‑source framework for fine‑tuning and reinforcement learning that trains LLMs (like Llama, Mistral, Gemma, DeepSeek, and more) up to 2× faster with up to 70% less VRAM, while providing a streamlined, Hugging Face–compatible workflow for training, evaluation, and deployment. For more information, see Unsloth Integration.
Instruction tuning teaches a base language model to follow user instructions and engage in conversations. This requires two things: a chat template, which defines how conversations are structured, and a conversational dataset of instruction-response pairs.
This example shows how to transform the Qwen 3 0.6B Base model into an instruction-following model using the Capybara dataset and a chat template from HuggingFaceTB/SmolLM3-3B. The SFT Trainer automatically handles tokenizer updates and special token configuration.
```python
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

trainer = SFTTrainer(
    model="Qwen/Qwen3-0.6B-Base",
    args=SFTConfig(
        output_dir="Qwen3-0.6B-Instruct",
        chat_template_path="HuggingFaceTB/SmolLM3-3B",
    ),
    train_dataset=load_dataset("trl-lib/Capybara", split="train"),
)
trainer.train()
```
> [!WARNING]
> Some base models, like those from Qwen, have a predefined chat template in the model's tokenizer. In these cases, it is not necessary to apply [clone_chat_template()], as the tokenizer already handles the formatting. However, it is necessary to align the EOS token with the chat template to ensure the model's responses terminate correctly. In these cases, specify `eos_token` in [SFTConfig]; for example, for `Qwen/Qwen2.5-1.5B`, set `eos_token="<|im_end|>"`.
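> For example (a minimal sketch using the `eos_token` argument described above):
>
> ```python
> from trl import SFTConfig
>
> training_args = SFTConfig(eos_token="<|im_end|>")  # e.g. for Qwen/Qwen2.5-1.5B
> ```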
Once trained, your model can now follow instructions and engage in conversations using its new chat template.
```python
>>> from transformers import pipeline

>>> pipe = pipeline("text-generation", model="Qwen3-0.6B-Instruct/checkpoint-5000")
>>> prompt = "<|im_start|>user\nWhat is the capital of France? Answer in one word.<|im_end|>\n<|im_start|>assistant\n"
>>> response = pipe(prompt)
>>> response[0]["generated_text"]
'<|im_start|>user\nWhat is the capital of France? Answer in one word.<|im_end|>\n<|im_start|>assistant\nThe capital of France is Paris.'
```
Alternatively, use the structured conversation format (recommended):
```python
>>> prompt = [{"role": "user", "content": "What is the capital of France? Answer in one word."}]
>>> response = pipe(prompt)
>>> response[0]["generated_text"]
[{'role': 'user', 'content': 'What is the capital of France? Answer in one word.'}, {'role': 'assistant', 'content': 'The capital of France is Paris.'}]
```
The [SFTTrainer] fully supports fine-tuning models with tool calling capabilities. In this case, each dataset example should include:
- the conversation messages, including assistant messages with tool calls (`tool_calls`) and tool responses (`tool` role messages)
- the list of available tools, typically provided in a `tools` column as JSON schemas

For details on the expected dataset structure, see the Dataset Format — Tool Calling section.
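To illustrate, a single example might look like the following sketch (the field names follow the conversational tool-calling format; the tool name and schema are hypothetical):

```python
example = {
    "messages": [
        {"role": "user", "content": "What is the weather in Paris?"},
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "type": "function",
                    "function": {"name": "get_weather", "arguments": {"city": "Paris"}},
                }
            ],
        },
        {"role": "tool", "name": "get_weather", "content": "Sunny, 21°C."},
        {"role": "assistant", "content": "It is currently sunny and 21°C in Paris."},
    ],
    # Available tools, described as JSON schemas.
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get the current weather for a city.",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ],
}
```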
[SFTTrainer] fully supports training Vision-Language Models (VLMs). To train a VLM, provide a dataset with either an image column (single image per sample) or an images column (list of images per sample). For more information on the expected dataset structure, see the Dataset Format — Vision Dataset section.
An example of such a dataset is the LLaVA Instruct Mix.
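A row of such a dataset might look like the following sketch (illustrative; see the linked section for the exact schema):

```python
from PIL import Image

example = {
    "images": [Image.new("RGB", (64, 64))],  # list of PIL images for this sample
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "What is shown in this image?"},
            ],
        },
        {
            "role": "assistant",
            "content": [{"type": "text", "text": "A solid black square."}],
        },
    ],
}
```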
```python
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-VL-3B-Instruct",
    args=SFTConfig(max_length=None),
    train_dataset=load_dataset("trl-lib/llava-instruct-mix", split="train"),
)
trainer.train()
```
> [!TIP]
> For VLMs, truncating may remove image tokens, leading to errors during training. To avoid this, set `max_length=None` in the [SFTConfig]. This allows the model to process the full sequence length without truncating image tokens.
>
> ```python
> SFTConfig(max_length=None, ...)
> ```
>
> Only use `max_length` when you've verified that truncation won't remove image tokens for the entire dataset.
[[autodoc]] SFTTrainer
    - train
    - save_model
    - push_to_hub

[[autodoc]] SFTConfig