# Train Args

docs/macros/train-args.md

| Argument | Type | Default | Description |
| --- | --- | --- | --- |
| `model` | `str` | `None` | Specifies the model file for training. Accepts a path to either a `.pt` pretrained model or a `.yaml` configuration file. Essential for defining the model structure or initializing weights. |
| `data` | `str` | `None` | Path to the dataset configuration file (e.g., `coco8.yaml`). This file contains dataset-specific parameters, including paths to training and validation data, class names, and number of classes. |
| `epochs` | `int` | `100` | Total number of training epochs. Each epoch represents a full pass over the entire dataset. Adjusting this value can affect training duration and model performance. |
| `time` | `float` | `None` | Maximum training time in hours. If set, this overrides the `epochs` argument, allowing training to stop automatically after the specified duration. Useful for time-constrained training scenarios. |
| `patience` | `int` | `100` | Number of epochs to wait without improvement in validation metrics before early stopping the training. Helps prevent overfitting by stopping training when performance plateaus. |
| `batch` | `int` or `float` | `16` | Batch size, with three modes: set as an integer (e.g., `batch=16`), auto mode for 60% GPU memory utilization (`batch=-1`), or auto mode with a specified utilization fraction (`batch=0.70`). |
| `imgsz` | `int` | `640` | Target image size for training. Images are resized to squares with sides equal to the specified value (if `rect=False`), preserving aspect ratio for YOLO models but not RT-DETR. Affects model accuracy and computational complexity. |
| `save` | `bool` | `True` | Enables saving of training checkpoints and final model weights. Useful for resuming training or model deployment. |
| `save_period` | `int` | `-1` | Frequency of saving model checkpoints, specified in epochs. A value of `-1` disables this feature. Useful for saving interim models during long training sessions. |
| `cache` | `bool` | `False` | Enables caching of dataset images in memory (`True`/`ram`), on disk (`disk`), or disables it (`False`). Improves training speed by reducing disk I/O at the cost of increased memory usage. |
| `device` | `int` or `str` or `list` | `None` | Specifies the computational device(s) for training: a single GPU (`device=0`), multiple GPUs (`device=[0,1]`), CPU (`device=cpu`), MPS for Apple silicon (`device=mps`), Huawei Ascend NPU (`device=npu` or `device=npu:0`), or auto-selection of the most idle GPU (`device=-1`) or multiple idle GPUs (`device=[-1,-1]`). |
| `workers` | `int` | `8` | Number of worker threads for data loading (per `RANK` if multi-GPU training). Influences the speed of data preprocessing and feeding into the model, especially useful in multi-GPU setups. |
| `project` | `str` | `None` | Name of the project directory where training outputs are saved. Allows for organized storage of different experiments. |
| `name` | `str` | `None` | Name of the training run. Used for creating a subdirectory within the project folder, where training logs and outputs are stored. |
| `exist_ok` | `bool` | `False` | If `True`, allows overwriting of an existing project/name directory. Useful for iterative experimentation without needing to manually clear previous outputs. |
| `pretrained` | `bool` or `str` | `True` | Determines whether to start training from pretrained weights. Can be a boolean value or a string path to weights to load. `pretrained=False` trains from randomly initialized weights while keeping the model architecture. |
| `optimizer` | `str` | `'auto'` | Choice of optimizer for training. Options include `SGD`, `MuSGD`, `Adam`, `Adamax`, `AdamW`, `NAdam`, `RAdam`, `RMSProp`, or `auto` for automatic selection based on model configuration. Affects convergence speed and stability. |
| `seed` | `int` | `0` | Sets the random seed for training, ensuring reproducibility of results across runs with the same configurations. |
| `deterministic` | `bool` | `True` | Forces deterministic algorithm use, ensuring reproducibility, but may affect performance and speed due to the restriction on non-deterministic algorithms. |
| `verbose` | `bool` | `True` | Enables verbose output during training, displaying progress bars, per-epoch metrics, and additional training information in the console. |
| `single_cls` | `bool` | `False` | Treats all classes in multi-class datasets as a single class during training. Useful for binary classification tasks or when focusing on object presence rather than classification. |
| `classes` | `list[int]` | `None` | Specifies a list of class IDs to train on. Useful for filtering out and focusing only on certain classes during training. |
| `rect` | `bool` | `False` | Enables a minimum-padding strategy: images in a batch are minimally padded to reach a common size, with the longest side equal to `imgsz`. Can improve efficiency and speed but may affect model accuracy. |
| `multi_scale` | `float` | `0.0` | Randomly varies `imgsz` for each batch by up to ±`multi_scale` (e.g. `0.25` → 0.75x to 1.25x), rounding to multiples of the model stride; `0.0` disables multi-scale training. |
| `cos_lr` | `bool` | `False` | Utilizes a cosine learning rate scheduler, adjusting the learning rate following a cosine curve over epochs. Helps in managing the learning rate for better convergence. |
| `close_mosaic` | `int` | `10` | Disables mosaic data augmentation in the last N epochs to stabilize training before completion. Setting to `0` disables this feature. |
| `resume` | `bool` | `False` | Resumes training from the last saved checkpoint. Automatically loads model weights, optimizer state, and epoch count, continuing training seamlessly. |
| `amp` | `bool` | `True` | Enables Automatic Mixed Precision (AMP) training, reducing memory usage and possibly speeding up training with minimal impact on accuracy. |
| `fraction` | `float` | `1.0` | Specifies the fraction of the dataset to use for training. Allows training on a subset of the full dataset, useful for experiments or when resources are limited. |
| `profile` | `bool` | `False` | Enables profiling of ONNX and TensorRT speeds during training, useful for optimizing model deployment. |
| `freeze` | `int` or `list` | `None` | Freezes the first N layers of the model, or the specified layers by index, reducing the number of trainable parameters. Useful for fine-tuning or transfer learning. |
| `lr0` | `float` | `0.01` | Initial learning rate (e.g. `SGD=1E-2`, `Adam=1E-3`). Adjusting this value is crucial for the optimization process, influencing how rapidly model weights are updated. |
| `lrf` | `float` | `0.01` | Final learning rate as a fraction of the initial rate (final LR = `lr0 * lrf`), used in conjunction with schedulers to adjust the learning rate over time. |
| `momentum` | `float` | `0.937` | Momentum factor for SGD, or beta1 for Adam optimizers, influencing the incorporation of past gradients in the current update. |
| `weight_decay` | `float` | `0.0005` | L2 regularization term, penalizing large weights to prevent overfitting. |
| `warmup_epochs` | `float` | `3.0` | Number of epochs for learning rate warmup, gradually increasing the learning rate from a low value to the initial learning rate to stabilize training early on. |
| `warmup_momentum` | `float` | `0.8` | Initial momentum for the warmup phase, gradually adjusting to the set momentum over the warmup period. |
| `warmup_bias_lr` | `float` | `0.1` | Learning rate for bias parameters during the warmup phase, helping stabilize model training in the initial epochs. |
| `box` | `float` | `7.5` | Weight of the box loss component in the loss function, influencing how much emphasis is placed on accurately predicting bounding box coordinates. |
| `cls` | `float` | `0.5` | Weight of the classification loss in the total loss function, affecting the importance of correct class prediction relative to other components. |
| `cls_pw` | `float` | `0.0` | Power for class weighting to handle class imbalance using inverse class frequency. `0.0` disables class weighting, `1.0` applies full inverse-frequency weighting, and values between 0 and 1 provide partial weighting. |
| `dfl` | `float` | `1.5` | Weight of the distribution focal loss, used in certain YOLO versions for fine-grained classification. |
| `pose` | `float` | `12.0` | Weight of the pose loss in models trained for pose estimation, influencing the emphasis on accurately predicting pose keypoints. |
| `kobj` | `float` | `1.0` | Weight of the keypoint objectness loss in pose estimation models, balancing detection confidence with pose accuracy. |
| `rle` | `float` | `1.0` | Weight of the residual log-likelihood estimation loss in pose estimation models, affecting the precision of keypoint localization. |
| `angle` | `float` | `1.0` | Weight of the angle loss in OBB models, affecting the precision of oriented bounding box angle predictions. |
| `nbs` | `int` | `64` | Nominal batch size for normalization of loss. |
| `overlap_mask` | `bool` | `True` | Determines whether object masks should be merged into a single mask for training or kept separate for each object. In case of overlap, the smaller mask is overlaid on top of the larger mask during the merge. |
| `mask_ratio` | `int` | `4` | Downsample ratio for segmentation masks, affecting the resolution of masks used during training. |
| `dropout` | `float` | `0.0` | Dropout rate for regularization in classification tasks, preventing overfitting by randomly omitting units during training. |
| `val` | `bool` | `True` | Enables validation during training, allowing for periodic evaluation of model performance on a separate dataset. |
| `plots` | `bool` | `True` | Generates and saves plots of training and validation metrics, as well as prediction examples, providing visual insights into model performance and learning progression. |
| `compile` | `bool` or `str` | `False` | Enables PyTorch 2.x `torch.compile` graph compilation with `backend='inductor'`. Accepts `True` (equivalent to `"default"`), `False` (disables), or a string mode such as `"default"`, `"reduce-overhead"`, or `"max-autotune-no-cudagraphs"`. Falls back to eager execution with a warning if unsupported. |
| `max_det` | `int` | `300` | Specifies the maximum number of objects retained during the validation phase of training. |
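As a worked illustration of how `lr0`, `lrf`, and `epochs` interact under `cos_lr`, the cosine schedule described above can be sketched in plain Python. This is a simplified reconstruction for illustration, not the library's internal scheduler:

```python
import math


def cosine_lr(epoch: int, epochs: int, lr0: float, lrf: float) -> float:
    """Sketch of a cosine schedule: decays from lr0 at epoch 0
    toward the final rate lr0 * lrf at the last epoch."""
    # Interpolation factor: 1.0 at epoch 0, lrf at epoch == epochs.
    factor = lrf + (1 - lrf) * 0.5 * (1 + math.cos(math.pi * epoch / epochs))
    return lr0 * factor


# With the defaults lr0=0.01 and lrf=0.01, the rate starts at 0.01
# and decays smoothly toward 0.01 * 0.01 = 0.0001 at the final epoch.
print(cosine_lr(0, 100, 0.01, 0.01))
print(cosine_lr(100, 100, 0.01, 0.01))
```

This also shows why `lrf` is documented as a fraction: the final learning rate is the product `lr0 * lrf`, not an absolute value.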
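The three `batch` modes (fixed integer, `-1` for 60% auto-sizing, or a utilization fraction) can be made concrete with a small helper that resolves the argument into either a fixed size or a target GPU-memory fraction. `resolve_batch` is a hypothetical illustration of the documented semantics, not part of the Ultralytics API:

```python
def resolve_batch(batch):
    """Interpret the `batch` training argument (hypothetical helper).

    Returns ("fixed", n) for an explicit batch size, or
    ("auto", fraction) for automatic sizing targeting a GPU-memory fraction.
    """
    if isinstance(batch, bool):
        raise TypeError("batch must be an int or float, not bool")
    if isinstance(batch, int):
        if batch == -1:
            return ("auto", 0.60)  # auto mode: target ~60% GPU memory
        if batch > 0:
            return ("fixed", batch)
    elif isinstance(batch, float) and 0.0 < batch < 1.0:
        return ("auto", batch)  # auto mode with an explicit utilization fraction
    raise ValueError(f"unsupported batch value: {batch!r}")


print(resolve_batch(16))    # -> ('fixed', 16)
print(resolve_batch(-1))    # -> ('auto', 0.6)
print(resolve_batch(0.70))  # -> ('auto', 0.7)
```

In practice these values are simply passed as keyword arguments, e.g. `model.train(data="coco8.yaml", batch=-1)`.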
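The `multi_scale` behavior (vary `imgsz` per batch by up to the given fraction, rounded to stride multiples) can be sketched as follows. This is a simplified illustration assuming a model stride of 32, not the library's exact sampling code:

```python
import random


def sample_imgsz(imgsz: int, multi_scale: float, stride: int = 32) -> int:
    """Pick a random training size in [imgsz*(1-multi_scale), imgsz*(1+multi_scale)],
    snapped to a multiple of the model stride. Illustrative sketch only."""
    if multi_scale == 0.0:
        return imgsz  # multi-scale training disabled
    lo = int(imgsz * (1 - multi_scale)) // stride * stride
    hi = int(imgsz * (1 + multi_scale)) // stride * stride
    return random.randrange(lo, hi + stride, stride)


random.seed(0)
# With imgsz=640 and multi_scale=0.25, sampled sizes fall between
# 480 (0.75x) and 800 (1.25x), always divisible by the stride.
print(sample_imgsz(640, 0.25))
```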