# Customizing Trainer

The Ultralytics training pipeline is built around `BaseTrainer` and task-specific trainers like `DetectionTrainer`. These classes handle the training loop, validation, checkpointing, and logging out of the box. When you need more control, such as tracking custom metrics, adjusting loss weighting, or implementing learning rate schedules, you can subclass the trainer and override specific methods.

This guide walks through seven common customizations:

  1. Logging custom metrics (F1 score) at the end of each epoch
  2. Adding class weights to handle class imbalance
  3. Saving the best model based on a different metric
  4. Freezing the backbone for the first N epochs, then unfreezing
  5. Specifying per-layer learning rates
  6. Synchronizing BatchNorm across GPUs for multi-GPU training
  7. Configuring gradient clipping for stability tuning

!!! tip "Prerequisites"

    Before reading this guide, make sure you're familiar with the basics of [training YOLO models](../modes/train.md) and the [Advanced Customization](../usage/engine.md) page, which covers the `BaseTrainer` architecture.

## How Custom Trainers Work

The `YOLO` model class accepts a `trainer` parameter in the `train()` method. This allows you to pass your own trainer class that extends the default behavior:

```python
from ultralytics import YOLO
from ultralytics.models.yolo.detect import DetectionTrainer


class CustomTrainer(DetectionTrainer):
    """A custom trainer that extends DetectionTrainer with additional functionality."""

    pass  # Add your customizations here


model = YOLO("yolo26n.pt")
model.train(data="coco8.yaml", epochs=10, trainer=CustomTrainer)
```

Your custom trainer inherits all functionality from `DetectionTrainer`, so you only need to override the specific methods you want to customize.

## Logging Custom Metrics

The validation step computes precision, recall, and mAP. If you need additional metrics, like per-class F1 score, override `validate()`:

```python
import numpy as np

from ultralytics import YOLO
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.utils import LOGGER


class MetricsTrainer(DetectionTrainer):
    """Custom trainer that computes and logs F1 score at the end of each epoch."""

    def validate(self):
        """Run validation and compute per-class F1 scores."""
        metrics, fitness = super().validate()
        if metrics is None:
            return metrics, fitness

        if hasattr(self.validator, "metrics") and hasattr(self.validator.metrics, "box"):
            box = self.validator.metrics.box
            f1_per_class = box.f1
            class_indices = box.ap_class_index
            names = self.validator.names

            valid_f1 = f1_per_class[f1_per_class > 0]
            mean_f1 = np.mean(valid_f1) if len(valid_f1) > 0 else 0.0

            LOGGER.info(f"Mean F1 Score: {mean_f1:.4f}")
            per_class_str = [
                f"{names[i]}: {f1_per_class[j]:.3f}" for j, i in enumerate(class_indices) if f1_per_class[j] > 0
            ]
            LOGGER.info(f"Per-class F1: {per_class_str}")

        return metrics, fitness


model = YOLO("yolo26n.pt")
model.train(data="coco8.yaml", epochs=5, trainer=MetricsTrainer)
```

This logs the mean F1 score across all classes and a per-class breakdown after each validation run.

!!! note "Available Metrics"

    The validator provides access to many metrics through `self.validator.metrics.box`:

    | Attribute | Description |
    |---|---|
    | `f1` | F1 score per class |
    | `image_metrics` | Per-image metrics dictionary with precision, recall, F1, TP, FP, and FN |
    | `p` | Precision per class |
    | `r` | Recall per class |
    | `ap50` | AP at IoU 0.5 per class |
    | `ap` | AP at IoU 0.5:0.95 per class |
    | `mp`, `mr` | Mean precision and recall |
    | `map50`, `map` | Mean AP metrics |
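
    For example, inside a `validate()` override you can read the summary values directly (a minimal sketch of attribute access, not a complete trainer):

    ```python
    box = self.validator.metrics.box
    print(box.mp, box.mr, box.map50, box.map)  # mean precision, mean recall, mAP@0.5, mAP@0.5:0.95
    ```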

## Adding Class Weights

If your dataset has imbalanced classes (e.g., a rare defect in manufacturing inspection), you can upweight underrepresented classes in the loss function. This makes the model penalize misclassifications on rare classes more heavily.

To customize the loss, subclass the loss classes, model, and trainer:

```python
import torch
from torch import nn

from ultralytics import YOLO
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.nn.tasks import DetectionModel
from ultralytics.utils import RANK
from ultralytics.utils.loss import E2ELoss, v8DetectionLoss


class WeightedDetectionLoss(v8DetectionLoss):
    """Detection loss with class weights applied to BCE classification loss."""

    def __init__(self, model, class_weights=None, tal_topk=10, tal_topk2=None):
        """Initialize loss with optional per-class weights for BCE."""
        super().__init__(model, tal_topk=tal_topk, tal_topk2=tal_topk2)
        if class_weights is not None:
            self.bce = nn.BCEWithLogitsLoss(
                pos_weight=class_weights.to(self.device),
                reduction="none",
            )


class WeightedE2ELoss(E2ELoss):
    """E2E Loss with class weights for YOLO26."""

    def __init__(self, model, class_weights=None):
        """Initialize E2E loss with weighted detection loss."""

        def weighted_loss_fn(model, tal_topk=10, tal_topk2=None):
            return WeightedDetectionLoss(model, class_weights=class_weights, tal_topk=tal_topk, tal_topk2=tal_topk2)

        super().__init__(model, loss_fn=weighted_loss_fn)


class WeightedDetectionModel(DetectionModel):
    """Detection model that uses class-weighted loss."""

    def init_criterion(self):
        """Initialize weighted loss criterion with per-class weights."""
        class_weights = torch.ones(self.nc)
        class_weights[0] = 2.0  # upweight class 0
        class_weights[1] = 3.0  # upweight rare class 1
        return WeightedE2ELoss(self, class_weights=class_weights)


class WeightedTrainer(DetectionTrainer):
    """Trainer that returns a WeightedDetectionModel."""

    def get_model(self, cfg=None, weights=None, verbose=True):
        """Return a WeightedDetectionModel."""
        model = WeightedDetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)
        return model


model = YOLO("yolo26n.pt")
model.train(data="coco8.yaml", epochs=10, trainer=WeightedTrainer)
```

!!! tip "Computing Weights from Dataset"

    You can compute class weights automatically from your dataset's label distribution. A common approach is inverse frequency weighting:

    ```python
    import numpy as np

    # class_counts: number of instances per class
    class_counts = np.array([5000, 200, 3000])
    # Inverse frequency: rarer classes get higher weight
    class_weights = class_counts.max() / class_counts
    # Result: [1.0, 25.0, 1.67]
    ```
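
    To wire dataset-derived weights into the trainer above, you can compute them inside `init_criterion` (a sketch; the counts array is a hypothetical stand-in for your label statistics):

    ```python
    import numpy as np
    import torch


    class AutoWeightedDetectionModel(WeightedDetectionModel):
        """Sketch: derive class weights from per-class instance counts instead of hardcoding them."""

        def init_criterion(self):
            """Build the weighted loss from inverse-frequency class weights."""
            class_counts = np.array([5000, 200, 3000])  # hypothetical; derive from your labels
            class_weights = torch.tensor(class_counts.max() / class_counts, dtype=torch.float32)
            return WeightedE2ELoss(self, class_weights=class_weights)
    ```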

## Saving the Best Model by Custom Metric

The trainer saves `best.pt` based on fitness, which defaults to `0.9 × mAP@0.5:0.95 + 0.1 × mAP@0.5`. To use a different metric (like mAP@0.5 or recall), override `validate()` and return your chosen metric as the fitness value. The built-in `save_model()` will then use it automatically:

```python
from ultralytics import YOLO
from ultralytics.models.yolo.detect import DetectionTrainer


class CustomSaveTrainer(DetectionTrainer):
    """Trainer that saves the best model based on [email protected] instead of default fitness."""

    def validate(self):
        """Override fitness to use [email protected] for best model selection."""
        metrics, fitness = super().validate()
        if metrics:
            fitness = metrics.get("metrics/mAP50(B)", fitness)
            if self.best_fitness is None or fitness > self.best_fitness:
                self.best_fitness = fitness
        return metrics, fitness


model = YOLO("yolo26n.pt")
model.train(data="coco8.yaml", epochs=20, trainer=CustomSaveTrainer)
```

!!! note "Available Metrics"

    Common metrics available in `self.metrics` after validation include:

    | Key | Description |
    |---|---|
    | `metrics/precision(B)` | Precision |
    | `metrics/recall(B)` | Recall |
    | `metrics/mAP50(B)` | mAP at IoU 0.5 |
    | `metrics/mAP50-95(B)` | mAP at IoU 0.5:0.95 |

## Freezing and Unfreezing the Backbone

Transfer learning workflows often benefit from freezing the pretrained backbone for the first N epochs, allowing the detection head to adapt before fine-tuning the entire network. Ultralytics provides a `freeze` parameter to freeze layers at the start of training, and you can use a callback to unfreeze them after N epochs:

```python
from ultralytics import YOLO
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.utils import LOGGER

FREEZE_EPOCHS = 5


def unfreeze_backbone(trainer):
    """Callback to unfreeze all layers after FREEZE_EPOCHS."""
    if trainer.epoch == FREEZE_EPOCHS:
        LOGGER.info(f"Epoch {trainer.epoch}: Unfreezing all layers for fine-tuning")
        for name, param in trainer.model.named_parameters():
            if not param.requires_grad:
                param.requires_grad = True
                LOGGER.info(f"  Unfroze: {name}")
        trainer.freeze_layer_names = [".dfl"]


class FreezingTrainer(DetectionTrainer):
    """Trainer with backbone freezing for first N epochs."""

    def __init__(self, *args, **kwargs):
        """Initialize and register the unfreeze callback."""
        super().__init__(*args, **kwargs)
        self.add_callback("on_train_epoch_start", unfreeze_backbone)


model = YOLO("yolo26n.pt")
model.train(data="coco8.yaml", epochs=20, freeze=10, trainer=FreezingTrainer)
```

The `freeze=10` parameter freezes the first 10 layers (the backbone) at training start. The `on_train_epoch_start` callback fires at the beginning of each epoch and unfreezes all parameters once the freeze period is complete.

!!! tip "Choosing What to Freeze"

    - `freeze=10` freezes the first 10 layers (typically the backbone in YOLO architectures)
    - `freeze=[0, 1, 2, 3]` freezes specific layers by index (see the one-line sketch below)
    - Higher `FREEZE_EPOCHS` values give the head more time to adapt before the backbone changes
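
    A one-line sketch of index-based freezing with the same trainer:

    ```python
    model.train(data="coco8.yaml", epochs=20, freeze=[0, 1, 2, 3], trainer=FreezingTrainer)
    ```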

## Per-Layer Learning Rates

Different parts of the network can benefit from different learning rates. A common strategy is to use a lower learning rate for the pretrained backbone to preserve learned features, while allowing the detection head to adapt more quickly with a higher rate:

```python
import torch

from ultralytics import YOLO
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.utils import LOGGER
from ultralytics.utils.torch_utils import unwrap_model


class PerLayerLRTrainer(DetectionTrainer):
    """Trainer with different learning rates for backbone and head."""

    def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
        """Build optimizer with separate learning rates for backbone and head."""
        backbone_params = []
        head_params = []

        for k, v in unwrap_model(model).named_parameters():
            if not v.requires_grad:
                continue
            is_backbone = any(k.startswith(f"model.{i}.") for i in range(10))
            if is_backbone:
                backbone_params.append(v)
            else:
                head_params.append(v)

        backbone_lr = lr * 0.1

        optimizer = torch.optim.AdamW(
            [
                {"params": backbone_params, "lr": backbone_lr, "weight_decay": decay},
                {"params": head_params, "lr": lr, "weight_decay": decay},
            ],
        )

        LOGGER.info(
            f"PerLayerLR optimizer: backbone ({len(backbone_params)} params, lr={backbone_lr}) "
            f"| head ({len(head_params)} params, lr={lr})"
        )
        return optimizer


model = YOLO("yolo26n.pt")
model.train(data="coco8.yaml", epochs=20, trainer=PerLayerLRTrainer)
```

### RT-DETR Variant

For RT-DETR the pattern is the same, with two refinements. The backbone length is read from `model.yaml["backbone"]`, so the same trainer works across RT-DETR variants (RT-DETR-L, RT-DETR-X, ResNet-50/101 backbones) without hardcoding layer counts. Parameters are also split into weight, BatchNorm, and bias groups within each section, so that weight decay is excluded from BatchNorm parameters and biases, matching the default trainer's policy. This is especially useful for RT-DETR fine-tuning, where the decoder head is typically randomly initialized while the backbone carries pretrained features that benefit from a lower learning rate:

```python
import torch
from torch import nn

from ultralytics import RTDETR
from ultralytics.models.rtdetr.train import RTDETRTrainer
from ultralytics.utils import LOGGER, colorstr
from ultralytics.utils.torch_utils import unwrap_model


class RTDETRBackboneLRTrainer(RTDETRTrainer):
    """RT-DETR trainer with a lower learning rate for backbone parameters."""

    backbone_lr_ratio = 0.1  # backbone learning rate as a fraction of head learning rate

    def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
        """Build an AdamW optimizer with six param groups: head and backbone x {weight, bn, bias}."""
        # Resolve optimizer name; "auto" maps to AdamW with RT-DETR-style defaults
        canonical = {"Adam", "Adamax", "AdamW", "NAdam", "RAdam", "auto"}
        name = {x.lower(): x for x in canonical}.get(name.lower(), name)
        if name == "auto":
            name, lr, momentum = "AdamW", 1e-4, 0.9
        self.args.warmup_bias_lr = 0.0  # RT-DETR warms biases from 0, unlike YOLO's 0.1
        if name not in {"Adam", "Adamax", "AdamW", "NAdam", "RAdam"}:
            raise NotImplementedError(f"This trainer only supports AdamW-family optimizers; got {name}")

        # Identify backbone parameters from model.yaml and route each param into a (section, kind) group
        unwrapped = unwrap_model(model)
        backbone_len = len(unwrapped.yaml["backbone"])
        norm_types = tuple(v for k, v in nn.__dict__.items() if "Norm" in k)
        groups = {f"{s}_{k}": [] for s in ("head", "backbone") for k in ("weight", "bn", "bias")}

        for module_name, module in unwrapped.named_modules():
            for param_name, param in module.named_parameters(recurse=False):
                if not param.requires_grad:
                    continue
                fullname = f"{module_name}.{param_name}" if module_name else param_name
                parts = fullname.split(".")
                section = (
                    "backbone"
                    if len(parts) > 1 and parts[0] == "model" and parts[1].isdigit() and int(parts[1]) < backbone_len
                    else "head"
                )
                if "bias" in param_name:
                    kind = "bias"
                elif isinstance(module, norm_types) or "logit_scale" in fullname:
                    kind = "bn"
                else:
                    kind = "weight"
                groups[f"{section}_{kind}"].append(param)

        # Build the optimizer with per-group lr and weight decay; backbone groups use lr * backbone_lr_ratio
        backbone_lr = lr * self.backbone_lr_ratio
        param_groups = [
            {"params": groups["head_weight"], "lr": lr, "weight_decay": decay, "param_group": "weight"},
            {"params": groups["head_bn"], "lr": lr, "weight_decay": 0.0, "param_group": "bn"},
            {"params": groups["head_bias"], "lr": lr, "weight_decay": 0.0, "param_group": "bias"},
            {"params": groups["backbone_weight"], "lr": backbone_lr, "weight_decay": decay, "param_group": "weight"},
            {"params": groups["backbone_bn"], "lr": backbone_lr, "weight_decay": 0.0, "param_group": "bn"},
            {"params": groups["backbone_bias"], "lr": backbone_lr, "weight_decay": 0.0, "param_group": "bias"},
        ]
        param_groups = [pg for pg in param_groups if pg["params"]]  # drop empty groups
        optimizer = getattr(torch.optim, name)(param_groups, betas=(momentum, 0.999))

        LOGGER.info(
            f"{colorstr('optimizer:')} {name}(lr={lr}, backbone_lr={backbone_lr}) with parameter groups\n"
            f"  Head:     {len(groups['head_bn'])} bn, {len(groups['head_weight'])} weight(decay={decay}), "
            f"{len(groups['head_bias'])} bias (lr={lr})\n"
            f"  Backbone: {len(groups['backbone_bn'])} bn, {len(groups['backbone_weight'])} weight(decay={decay}), "
            f"{len(groups['backbone_bias'])} bias (lr={backbone_lr})"
        )
        return optimizer


model = RTDETR("rtdetr-l.pt")
model.train(data="coco8.yaml", epochs=20, trainer=RTDETRBackboneLRTrainer)
```

!!! tip "Choosing backbone_lr_ratio"

    A common starting point is `backbone_lr_ratio = 0.1`, matching the original RT-DETR setup with its HGNetV2 backbone. The literature suggests scaling the ratio inversely with backbone size and pretraining-data scale: large backbones pretrained on very large datasets (for example, ViT-L/H trained with DINO, CLIP, or MAE on hundreds of millions of images) typically use smaller ratios such as `0.01` or below to preserve well-learned features, while smaller backbones with lighter pretraining tolerate larger ratios such as `0.5` or higher.

!!! note "Learning Rate Scheduler"

    The built-in learning rate scheduler (`cosine` or `linear`) still applies on top of the per-group base learning rates. Both the backbone and head learning rates follow the same decay schedule, maintaining the ratio between them throughout training, as illustrated below.
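
    As a plain-PyTorch illustration (not the trainer's internal code), `LambdaLR` multiplies every group's initial learning rate by the same factor, which is why the ratio holds:

    ```python
    import torch

    head, backbone = torch.nn.Parameter(torch.zeros(1)), torch.nn.Parameter(torch.zeros(1))
    opt = torch.optim.AdamW([{"params": [head], "lr": 1e-3}, {"params": [backbone], "lr": 1e-4}])
    sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda epoch: 1 - 0.9 * epoch / 20)
    sched.step()  # both groups scale by the same factor, preserving the 10:1 ratio
    ```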

!!! tip "Combining Techniques"

    These customizations can be combined into a single trainer class by overriding multiple methods and adding callbacks as needed, as sketched below.
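
    For example, a minimal sketch combining two of the trainers defined above (each overrides a different method, so multiple inheritance composes cleanly):

    ```python
    class CombinedTrainer(MetricsTrainer, PerLayerLRTrainer):
        """Logs F1 scores and applies per-layer learning rates."""

        def __init__(self, *args, **kwargs):
            """Initialize and register the backbone-unfreeze callback."""
            super().__init__(*args, **kwargs)
            self.add_callback("on_train_epoch_start", unfreeze_backbone)
    ```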

## Synchronized BatchNorm for Multi-GPU Training

When training on multiple GPUs with `DistributedDataParallel`, the default `BatchNorm2d` layers compute statistics independently on each GPU. For RT-DETR fine-tuning and other recipes that use small per-GPU batch sizes, per-GPU batch statistics can be noisy. PyTorch's `SyncBatchNorm` synchronizes mean and variance across all ranks to produce a single set of global batch statistics, which often improves convergence at the cost of a small inter-GPU communication overhead.

The conversion has to happen after the model is on the GPU but before DDP wraps it. The cleanest hook for this is `set_model_attributes()`, which `BaseTrainer` calls in exactly that window:

```python
from torch import nn

from ultralytics import RTDETR
from ultralytics.models.rtdetr.train import RTDETRTrainer


class SyncBNTrainer(RTDETRTrainer):
    """RT-DETR trainer that converts BatchNorm to SyncBatchNorm for multi-GPU training."""

    def set_model_attributes(self):
        """Run the parent setup, then convert BN to SyncBatchNorm when training on multiple GPUs."""
        super().set_model_attributes()
        if self.world_size > 1:
            self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model)


model = RTDETR("rtdetr-l.pt")
model.train(data="coco8.yaml", epochs=20, device=[0, 1], trainer=SyncBNTrainer)
```

The `world_size > 1` guard makes the trainer safe for single-GPU runs as well; on a single GPU the conversion is skipped and training proceeds with regular `BatchNorm2d`. The same pattern works for YOLO by switching the parent class to `DetectionTrainer`, as sketched below.
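
A minimal YOLO counterpart (a sketch; only the parent class and checkpoint change):

```python
from torch import nn

from ultralytics import YOLO
from ultralytics.models.yolo.detect import DetectionTrainer


class YOLOSyncBNTrainer(DetectionTrainer):
    """YOLO trainer that converts BatchNorm to SyncBatchNorm for multi-GPU training."""

    def set_model_attributes(self):
        """Run the parent setup, then convert BN to SyncBatchNorm on multi-GPU runs."""
        super().set_model_attributes()
        if self.world_size > 1:
            self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model)


model = YOLO("yolo26n.pt")
model.train(data="coco8.yaml", epochs=20, device=[0, 1], trainer=YOLOSyncBNTrainer)
```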

!!! tip "When to use SyncBatchNorm"

    | Scenario                                       | Recommendation           |
    | ---------------------------------------------- | ------------------------ |
    | Multi-GPU training, small per-GPU batch (≤ 16) | Enable                   |
    | Multi-GPU training, large per-GPU batch (≥ 32) | Optional; minor benefit  |
    | Single-GPU training                            | Not applicable (skipped) |

## Configurable Gradient Clipping

The default trainer clips gradients to `max_norm=10.0` in `optimizer_step()`, a loose value tuned for YOLO models, where gradients rarely exceed it. DETR-family detectors (RT-DETR, DEIM, DINO) typically use much tighter values, such as `0.1`, to stabilize the decoder's cross-attention layers, where gradient magnitudes can spike. To change the clip value, subclass the trainer and override `optimizer_step()`:

```python
import torch

from ultralytics import RTDETR
from ultralytics.models.rtdetr.train import RTDETRTrainer


class CustomClipTrainer(RTDETRTrainer):
    """RT-DETR trainer with configurable gradient clipping."""

    clip_grad_norm = 0.1  # max gradient norm; set to 0 to disable clipping

    def optimizer_step(self):
        """Run an optimizer step with a configurable gradient-norm clip."""
        self.scaler.unscale_(self.optimizer)
        if self.clip_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.clip_grad_norm)
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()
        if self.ema:
            self.ema.update(self.model)


model = RTDETR("rtdetr-l.pt")
model.train(data="coco8.yaml", epochs=20, trainer=CustomClipTrainer)
```

The same trainer works for YOLO by switching the parent class to `DetectionTrainer` (`from ultralytics.models.yolo.detect import DetectionTrainer`) and loading a YOLO checkpoint with `YOLO("yolo26n.pt")`. The `optimizer_step()` body is unchanged.
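
A minimal sketch of that YOLO variant (the `clip_grad_norm` value here is a hypothetical experiment, not a recommendation):

```python
import torch

from ultralytics import YOLO
from ultralytics.models.yolo.detect import DetectionTrainer


class YOLOClipTrainer(DetectionTrainer):
    """YOLO counterpart of CustomClipTrainer with the same optimizer_step body."""

    clip_grad_norm = 1.0  # hypothetical value for experimentation; YOLO's default is 10.0

    def optimizer_step(self):
        """Run an optimizer step with a configurable gradient-norm clip."""
        self.scaler.unscale_(self.optimizer)
        if self.clip_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.clip_grad_norm)
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()
        if self.ema:
            self.ema.update(self.model)


model = YOLO("yolo26n.pt")
model.train(data="coco8.yaml", epochs=20, trainer=YOLOClipTrainer)
```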

!!! tip "Typical clip_grad_norm values"

    | Architecture family          | Typical `max_norm` |
    | ---------------------------- | ------------------ |
    | RT-DETR / DEIM / DETR family | `0.1`              |
    | YOLO (Ultralytics default)   | `10.0`             |
    | Disable clipping             | `0`                |

## FAQ

### How do I pass a custom trainer to YOLO?

Pass your custom trainer class (not an instance) to the `trainer` parameter in `model.train()`:

```python
from ultralytics import YOLO

model = YOLO("yolo26n.pt")
model.train(data="coco8.yaml", trainer=MyCustomTrainer)  # MyCustomTrainer is your trainer subclass
```

The `YOLO` class handles trainer instantiation internally. See the [Advanced Customization](../usage/engine.md) page for more details on the trainer architecture.

### Which `BaseTrainer` methods can I override?

Key methods available for customization:

| Method | Purpose |
|---|---|
| `validate()` | Run validation and return metrics |
| `build_optimizer()` | Construct the optimizer |
| `save_model()` | Save training checkpoints |
| `get_model()` | Return the model instance |
| `get_validator()` | Return the validator instance |
| `get_dataloader()` | Build the dataloader |
| `preprocess_batch()` | Preprocess the input batch |
| `label_loss_items()` | Format loss items for logging |
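
For example, a minimal sketch overriding `preprocess_batch()` to add a custom hook after the default preprocessing runs:

```python
from ultralytics.models.yolo.detect import DetectionTrainer


class BatchHookTrainer(DetectionTrainer):
    """Sketch: hook into batch preprocessing after the default transforms run."""

    def preprocess_batch(self, batch):
        """Run default preprocessing, then apply any custom transform to the images."""
        batch = super().preprocess_batch(batch)
        # batch["img"] is a preprocessed float tensor on the training device at this point
        return batch
```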

For the full API reference, see the `BaseTrainer` documentation.

### Can I use callbacks instead of subclassing the trainer?

Yes. For simpler customizations, callbacks are often sufficient. Available callback events include `on_train_start`, `on_train_epoch_start`, `on_train_epoch_end`, `on_fit_epoch_end`, and `on_model_save`. These let you hook into the training loop without subclassing. The backbone-freezing example above registers a callback inside a trainer subclass; a callback can also be attached directly to the model, as sketched below.
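
For example (`log_epoch` is a hypothetical function name):

```python
from ultralytics import YOLO


def log_epoch(trainer):
    """Print a short message at the end of each training epoch."""
    print(f"Finished epoch {trainer.epoch}")


model = YOLO("yolo26n.pt")
model.add_callback("on_train_epoch_end", log_epoch)
model.train(data="coco8.yaml", epochs=5)
```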

### How do I customize the loss function without subclassing the model?

If your change is simpler (such as adjusting loss gains), you can modify the hyperparameters directly:

```python
model.train(data="coco8.yaml", box=10.0, cls=1.5, dfl=2.0)  # adjust box, cls, and dfl loss gains
```

For structural changes to the loss (such as adding class weights), you need to subclass the loss and model as shown in the class weights section.