Teaching Tool Calling with Supervised Fine-Tuning (SFT) using TRL on a Free Colab Notebook

examples/notebooks/sft_tool_calling.ipynb

Learn how to teach a language model to perform tool calling using Supervised Fine-Tuning (SFT) with LoRA/QLoRA and the TRL library.

The model used in this notebook does not have native tool-calling support. We extend its Jinja2 chat template (via tiny_aya_chat_template.jinja) to serialize tool schemas into the system preamble and render tool calls as structured <tool_call> XML inside the model's native <|START_RESPONSE|> / <|END_RESPONSE|> delimiters. The modified template is saved with the tokenizer, making inference reproducible: just load the tokenizer from the output directory and call apply_chat_template with tools=TOOLS.

Key concepts

  • SFT: Trains a model on example input-output pairs to align its behavior with a desired task.
  • Tool Calling: The ability of a model to respond with a structured function call instead of free-form text.
  • LoRA: Updates only a small set of low-rank parameters, reducing training cost and memory usage.
  • QLoRA: A quantized variant of LoRA that enables fine-tuning larger models on limited hardware.
  • TRL: The Hugging Face library that makes fine-tuning and reinforcement learning simple and efficient.

Install dependencies

We'll install TRL with the PEFT extra, which brings in the main dependencies such as Transformers and PEFT (parameter-efficient fine-tuning). We also install trackio for experiment logging, bitsandbytes for 4-bit quantization, and liger-kernel for optimized training kernels.

python
!pip install -Uq "trl[peft]" trackio bitsandbytes liger-kernel

Log in to Hugging Face

Log in to your Hugging Face account to push the fine-tuned model to the Hub and access gated models. You can find your access token on your account settings page.

python
from huggingface_hub import notebook_login

notebook_login()

Load Dataset

We load the bebechien/SimpleToolCalling dataset, which contains user queries paired with the correct tool call to handle each request. Each sample provides a user_content, a tool_name, and tool_arguments.

python
from datasets import load_dataset

dataset_name = "bebechien/SimpleToolCalling"
dataset = load_dataset(dataset_name, split="train")
python
dataset

Prepare Tool-Calling Data

We define two tools: search_knowledge_base for internal company documents and search_google for public information. We then use a custom Jinja2 chat template that extends the model's default template with two additions:

  1. A Tool Use section is appended to the system preamble when tools is passed to apply_chat_template.
  2. Assistant turns with tool_calls render the call as structured <tool_call> inside the model's existing <|START_RESPONSE|> / <|END_RESPONSE|> delimiters.
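
As a rough illustration of these two additions, the plain-Python sketch below mimics what such a template might emit. The helper names `render_system_preamble` and `render_tool_call`, the "## Tool Use" heading text, and the JSON layout inside `<tool_call>` are assumptions made for illustration only; the real output is defined by tiny_aya_chat_template.jinja and may differ.

```python
import json

# Response delimiters named in the text above.
START, END = "<|START_RESPONSE|>", "<|END_RESPONSE|>"


def render_system_preamble(base_preamble, tools=None):
    """Append a tool section to the preamble only when tools are passed."""
    if not tools:
        return base_preamble
    # One serialized schema per line (hypothetical layout).
    schemas = "\n".join(json.dumps(tool["function"]) for tool in tools)
    return f"{base_preamble}\n\n## Tool Use\n{schemas}"


def render_tool_call(name, arguments):
    """Wrap a tool call in <tool_call> tags inside the response delimiters."""
    payload = json.dumps({"name": name, "arguments": arguments})
    return f"{START}<tool_call>{payload}</tool_call>{END}"


example_tools = [
    {
        "type": "function",
        "function": {"name": "search_google", "description": "Search public information."},
    }
]

print(render_system_preamble("You are a helpful assistant.", example_tools))
print(render_tool_call("search_google", {"query": "weather in Paris"}))
```

Without tools, the preamble is returned unchanged, which mirrors the idea that the tool section only appears when tools is passed.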

Each training sample uses the standard tool_calls message format together with a tools key; SFTTrainer passes the tools to apply_chat_template automatically.

python
import json

# These are the tool schemas that are used in the dataset
TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "search_knowledge_base",
            "description": "Search internal company documents, policies and project data.",
            "parameters": {
                "type": "object",
                "properties": {"query": {"type": "string", "description": "query string"}},
                "required": ["query"],
            },
            "return": {"type": "string"},
        },
    },
    {
        "type": "function",
        "function": {
            "name": "search_google",
            "description": "Search public information.",
            "parameters": {
                "type": "object",
                "properties": {"query": {"type": "string", "description": "query string"}},
                "required": ["query"],
            },
            "return": {"type": "string"},
        },
    },
]

def create_conversation(sample):
    return {
        "prompt": [{"role": "user", "content": sample["user_content"]}],
        "completion": [
            {
                "role": "assistant",
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": sample["tool_name"],
                            "arguments": json.loads(sample["tool_arguments"]),
                        },
                    }
                ],
            },
        ],
        "tools": TOOLS,
    }
python
dataset = dataset.map(create_conversation, remove_columns=dataset.features)

# Split dataset into 50% training samples and 50% test samples
dataset = dataset.train_test_split(test_size=0.5, shuffle=True)

Let's inspect an example from the training set to verify the format:

python
dataset['train'][0]
python
dataset
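
Before training, it can help to confirm that every record has the expected shape. The sketch below is one possible check, not part of TRL: `check_sample` is a hypothetical helper, and the example record is made up to resemble the output of create_conversation.

```python
import json


def check_sample(sample):
    """Return a list of problems found in a prompt/completion/tools record."""
    problems = []
    for key in ("prompt", "completion", "tools"):
        if key not in sample:
            problems.append(f"missing key: {key}")
    for message in sample.get("completion", []):
        for call in message.get("tool_calls", []):
            arguments = call.get("function", {}).get("arguments")
            # Arguments should already be parsed, not left as a JSON string.
            if not isinstance(arguments, dict):
                problems.append("arguments should be a dict, not a JSON string")
    return problems


# Hypothetical record, shaped like the output of create_conversation.
raw_arguments = '{"query": "vacation policy"}'
sample = {
    "prompt": [{"role": "user", "content": "What is our vacation policy?"}],
    "completion": [
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "type": "function",
                    "function": {
                        "name": "search_knowledge_base",
                        "arguments": json.loads(raw_arguments),
                    },
                }
            ],
        }
    ],
    "tools": [],
}

print(check_sample(sample))
```

An empty list means no problems were detected by this particular check.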

Load Model and Configure LoRA/QLoRA

Choose the model you want to fine-tune. This notebook uses CohereLabs/tiny-aya-global by default.

python
model_id, output_dir = "CohereLabs/tiny-aya-global", "tiny-aya-global-SFT"     # ✅ ~9.1 GB VRAM

Load the model with 4-bit quantization using BitsAndBytesConfig (QLoRA). To use standard LoRA without quantization, comment out the quantization_config parameter. We also download the custom chat template file so that it can be applied during training through the chat_template_path setting.

python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    attn_implementation="sdpa",                   # Change to Flash Attention if GPU has support
    dtype=torch.float16,                          # Change to bfloat16 if GPU has support
    use_cache=True,                               # Whether to cache attention outputs to speed up inference
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,                        # Load the model in 4-bit precision to save memory
        bnb_4bit_compute_dtype=torch.float16,     # Data type used for internal computations in quantization
        bnb_4bit_use_double_quant=True,           # Also quantize the quantization constants to save extra memory
        bnb_4bit_quant_type="nf4"                 # Type of quantization. "nf4" is recommended for recent LLMs
    )
)
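
To get a feel for why 4-bit loading matters, here is a back-of-the-envelope estimate of the memory taken by the weights alone. The parameter count below is a hypothetical placeholder, not the size of the model used in this notebook, and the estimate ignores quantization constants, layers that stay unquantized, activations, and optimizer state.

```python
def weight_memory_gb(num_params, bits_per_param):
    """Approximate memory for the weights alone, in GiB."""
    return num_params * bits_per_param / 8 / 1024**3


num_params = 3_000_000_000  # hypothetical parameter count, for illustration only

for label, bits in [("float16", 16), ("4-bit (nf4)", 4)]:
    print(f"{label}: ~{weight_memory_gb(num_params, bits):.2f} GiB")
```

Under these assumptions, 4-bit weights take roughly a quarter of the float16 footprint.
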
python
!wget https://raw.githubusercontent.com/huggingface/trl/refs/heads/main/examples/scripts/tiny_aya_chat_template.jinja

Configure LoRA. Instead of updating the model's original weights, we fine-tune a lightweight LoRA adapter. The target_modules argument specifies which layers receive the adapter; update it if you use a different model architecture.

python
from peft import LoraConfig

# You may need to update `target_modules` depending on the architecture of your chosen model.
# For example, different LLMs might have different attention/projection layer names.
peft_config = LoraConfig(
    r=32,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)
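
The effect of r and lora_alpha can be made concrete with a little arithmetic. In the standard LoRA formulation, an adapter on a weight of shape (d_out, d_in) adds two matrices, A of shape (r, d_in) and B of shape (d_out, r), and the update is scaled by lora_alpha / r. The layer dimensions below are hypothetical and are not taken from the model in this notebook.

```python
def lora_extra_params(d_in, d_out, r):
    """Trainable parameters added by one adapter: A is (r, d_in), B is (d_out, r)."""
    return r * (d_in + d_out)


def lora_scaling(lora_alpha, r):
    """Factor applied to the low-rank update B @ A (standard LoRA scaling)."""
    return lora_alpha / r


d_in = d_out = 2048  # hypothetical projection size, for illustration only
r, lora_alpha = 32, 32  # values from the LoraConfig above

added = lora_extra_params(d_in, d_out, r)
full = d_in * d_out
print(f"Adapter parameters: {added} ({added / full:.2%} of the full matrix)")
print(f"Scaling factor: {lora_scaling(lora_alpha, r)}")
```

With lora_alpha equal to r, the scaling factor is 1, so the low-rank update is added to the frozen weight without rescaling.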

Train Model

Configure the training run with SFTConfig. The settings below are tuned for low memory usage. For full details on available parameters, see the TRL SFTConfig documentation.

python
from trl import SFTConfig

training_args = SFTConfig(
    # Training schedule / optimization
    per_device_train_batch_size=1,        # Batch size per GPU
    gradient_accumulation_steps=4,        # Effective batch size = 1 * 4 = 4
    warmup_steps=5,
    learning_rate=2e-4,                   # Learning rate for the optimizer
    optim="paged_adamw_8bit",             # Optimizer
    chat_template_path="tiny_aya_chat_template.jinja",  # Use the tool-aware chat template

    # Logging / reporting
    logging_steps=1,                      # Log training metrics every N steps
    report_to="trackio",                  # Experiment tracking tool
    trackio_space_id=output_dir,          # HF Space where the experiment tracking will be saved
    output_dir=output_dir,                # Where to save model checkpoints and logs

    max_length=1024,                      # Maximum input sequence length
    activation_offloading=True,           # Offload activations to CPU to reduce GPU memory usage

    # Hub integration
    push_to_hub=True,                     # Automatically push the trained model to the Hugging Face Hub
                                          # The model will be saved under your Hub account in the repository named `output_dir`
)
python
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset['train'],
    peft_config=peft_config
)
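
The effective batch size of 4 comes from gradient accumulation: gradients from several small batches are summed before a single optimizer step. The toy sketch below uses a one-parameter model and made-up data to show that accumulating four micro-batches of size 1 gives the same gradient as one batch of 4. It is an illustration of the idea only; real trainers normalize the loss per token, so the match may be approximate when sequences differ in length.

```python
import math


def grad(w, x, y):
    """Gradient of the squared error (w * x - y) ** 2 with respect to w."""
    return 2 * (w * x - y) * x


w = 0.5
batch = [(1.0, 2.0), (2.0, 3.0), (3.0, 5.0), (4.0, 9.0)]  # toy (x, y) pairs

# One large batch: gradient of the mean loss over all four samples.
full_batch_grad = sum(grad(w, x, y) for x, y in batch) / len(batch)

# Four micro-batches of size 1: scale each gradient by 1 / accumulation_steps and sum.
accumulation_steps = 4
accumulated_grad = 0.0
for x, y in batch:
    accumulated_grad += grad(w, x, y) / accumulation_steps

print(full_batch_grad, accumulated_grad)
assert math.isclose(full_batch_grad, accumulated_grad)
```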

Show memory stats before training:

python
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)

print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

And train!

python
trainer_stats = trainer.train()

Show memory stats after training:

python
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory / max_memory * 100, 3)
lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)

print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
print(f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.")
print(f"Peak reserved memory = {used_memory} GB.")
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")

Save the Fine-Tuned Model

Save the trained LoRA adapter locally and push it to the Hugging Face Hub.

python
trainer.save_model(output_dir)
trainer.push_to_hub(dataset_name=dataset_name)

Load the Fine-Tuned Model and Run Inference

Load the trained LoRA adapter on top of the base model and merge it into the weights for efficient inference.

python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

# Load from output_dir to get the tokenizer with the updated chat template
tokenizer = AutoTokenizer.from_pretrained(output_dir)

base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    attn_implementation="sdpa",
    dtype=torch.float16,
    device_map="auto",
)

model = PeftModel.from_pretrained(base_model, output_dir)
model = model.merge_and_unload()
model.eval()

Define a prediction function that uses apply_chat_template with tools=TOOLS to construct the prompt. The model generates a JSON tool call inside its native response delimiters; skip_special_tokens=True strips those delimiters, leaving just the JSON string.

python
def generate_prediction(prompt):
    text = tokenizer.apply_chat_template(
        prompt, tools=TOOLS, tokenize=False, add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=512,
    )
    output_ids = generated_ids[0][len(model_inputs.input_ids[0]):]
    return tokenizer.decode(output_ids, skip_special_tokens=True)

Let's test the fine-tuned model on an example from the test set:

python
sample_test_data = dataset["test"][0] # Get a sample from the test set

user_content = sample_test_data["prompt"]

print(f"User Query: {user_content}")

predicted_output = generate_prediction(user_content)
print(f"Predicted Output: {predicted_output}")
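
To act on the prediction, the generated string has to be turned back into a tool name and arguments. The sketch below is one way to do that, assuming the output is a JSON object with name and arguments keys, optionally wrapped in `<tool_call>` tags. Both the `parse_tool_call` helper and that output layout are assumptions; adjust them to whatever your template actually produces.

```python
import json
import re


def parse_tool_call(text):
    """Return (name, arguments) if text contains a JSON tool call, else None."""
    # Use the content between <tool_call> tags if present, otherwise the whole text.
    match = re.search(r"<tool_call>(.*?)</tool_call>", text, flags=re.DOTALL)
    candidate = match.group(1) if match else text
    try:
        payload = json.loads(candidate.strip())
    except json.JSONDecodeError:
        return None  # Free-form text, not a tool call
    if not isinstance(payload, dict) or "name" not in payload:
        return None
    return payload["name"], payload.get("arguments", {})


# Hypothetical model output, for illustration only.
example_output = '<tool_call>{"name": "search_google", "arguments": {"query": "ikigai"}}</tool_call>'
print(parse_tool_call(example_output))
print(parse_tool_call("This is a plain text answer."))
```

Returning None for free-form text lets the caller fall back to showing the answer directly instead of dispatching a tool.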

You can also check that the fine-tuned model still handles ordinary multilingual prompts:

python
user_content = "Explica en español qué significa la palabra japonesa 'ikigai' y da un ejemplo práctico." # Spanish question
user_content = [{"role": "user", "content": user_content}]

print(f"User Query: {user_content}")

predicted_output = generate_prediction(user_content)
print(f"Predicted Output: {predicted_output}")