<!--Copyright 2020 The HuggingFace Team. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be rendered properly in your Markdown viewer. -->

Trainer

The [Trainer] class provides an API for feature-complete training in PyTorch. It supports distributed training on multiple GPUs/TPUs and mixed precision on NVIDIA GPUs, on AMD GPUs, and through torch.amp in PyTorch. [Trainer] goes hand-in-hand with the [TrainingArguments] class, which offers a wide range of options to customize how a model is trained. Together, these two classes provide a complete training API.

[Seq2SeqTrainer] and [Seq2SeqTrainingArguments] inherit from the [Trainer] and [TrainingArguments] classes, respectively, and are adapted for training models on sequence-to-sequence tasks such as summarization or translation.

<Tip warning={true}>

The [Trainer] class is optimized for 🤗 Transformers models and can have surprising behaviors when used with other models. When using it with your own model, make sure:

  • your model always returns tuples or subclasses of [~utils.ModelOutput]
  • your model can compute the loss if a labels argument is provided, and returns that loss as the first element of the tuple (if your model returns tuples)
  • your model can accept multiple label arguments (use label_names in [TrainingArguments] to indicate their names to the [Trainer]), but none of them should be named "label"
</Tip>
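The requirements above can be sketched with a small custom PyTorch module. `TinyClassifier` and its `features` argument are hypothetical names invented for this example; the points being illustrated are that the model always returns a tuple, that the loss comes first when `labels` is given, and that the label argument is not named "label".

```python
import torch
from torch import nn


class TinyClassifier(nn.Module):
    """Hypothetical custom model that follows the conventions listed above."""

    def __init__(self, num_features=4, num_classes=3):
        super().__init__()
        self.linear = nn.Linear(num_features, num_classes)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, features, labels=None):
        logits = self.linear(features)
        if labels is None:
            # No labels: still return a tuple, without a loss.
            return (logits,)
        # Labels provided: the loss is the first element of the tuple.
        loss = self.loss_fn(logits, labels)
        return (loss, logits)


model = TinyClassifier()
features = torch.randn(5, 4)
labels = torch.tensor([0, 1, 2, 1, 0])

with_labels = model(features, labels=labels)
without_labels = model(features)
print(len(with_labels), len(without_labels))
```

If a model like this took its targets under different or additional argument names, those names would be listed in the label_names option of [TrainingArguments] so that the [Trainer] knows which inputs are labels.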

Trainer[[api-reference]]

[[autodoc]] Trainer - all

Seq2SeqTrainer

[[autodoc]] Seq2SeqTrainer - evaluate - predict

TrainingArguments

[[autodoc]] TrainingArguments - all

Seq2SeqTrainingArguments

[[autodoc]] Seq2SeqTrainingArguments - all