
.. testsetup:: *

    from lightning.pytorch.callbacks.early_stopping import EarlyStopping, EarlyStoppingReason
    from lightning.pytorch import Trainer, LightningModule

.. _early_stopping:

##############
Early Stopping
##############

.. video:: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/yt/Trainer+flags+19-+early+stopping_1.mp4
    :poster: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/yt_thumbs/thumb_earlystop.png
    :width: 400
    :muted:


***********************
Stopping an Epoch Early
***********************


You can stop and skip the rest of the current epoch early by overriding :meth:`~lightning.pytorch.core.hooks.ModelHooks.on_train_batch_start` to return ``-1`` when some condition is met.

If you do this repeatedly, for every epoch you had originally requested, this will stop your entire training.


**********************
EarlyStopping Callback
**********************


The :class:`~lightning.pytorch.callbacks.early_stopping.EarlyStopping` callback can be used to monitor a metric and stop the training when no improvement is observed.

To enable it:

* Import the :class:`~lightning.pytorch.callbacks.early_stopping.EarlyStopping` callback.
* Log the metric you want to monitor using the :meth:`~lightning.pytorch.core.LightningModule.log` method.
* Init the callback, and set ``monitor`` to the logged metric of your choice.
* Set ``mode`` based on whether the monitored metric should be minimized (``"min"``) or maximized (``"max"``).
* Pass the :class:`~lightning.pytorch.callbacks.early_stopping.EarlyStopping` callback to the :class:`~lightning.pytorch.trainer.trainer.Trainer` ``callbacks`` flag.

.. code-block:: python

    from lightning.pytorch.callbacks.early_stopping import EarlyStopping


    class LitModel(LightningModule):
        def validation_step(self, batch, batch_idx):
            loss = ...
            self.log("val_loss", loss)


    model = LitModel()
    trainer = Trainer(callbacks=[EarlyStopping(monitor="val_loss", mode="min")])
    trainer.fit(model)

You can customize the callback's behaviour by changing its parameters.

.. testcode::

    early_stop_callback = EarlyStopping(monitor="val_accuracy", min_delta=0.00, patience=3, verbose=False, mode="max")
    trainer = Trainer(callbacks=[early_stop_callback])

Additional parameters that stop training at extreme points:

* ``stopping_threshold``: Stops training immediately once the monitored quantity reaches this threshold. It is useful when we know that going beyond a certain optimal value does not further benefit us.
* ``divergence_threshold``: Stops training as soon as the monitored quantity becomes worse than this threshold. When a value this bad is reached, we believe the model cannot recover anymore, and it is better to stop early and run with different initial conditions.
* ``check_finite``: When turned on, it stops training if the monitored metric becomes NaN or infinite.
* ``check_on_train_epoch_end``: When turned on, it checks the metric at the end of a training epoch. Use this only when you are monitoring a metric logged within training-specific hooks on epoch level.

After training completes, you can programmatically check why early stopping occurred using the ``stopping_reason`` attribute, which returns an :class:`~lightning.pytorch.callbacks.early_stopping.EarlyStoppingReason` enum value.

.. code-block:: python

    from lightning.pytorch.callbacks import EarlyStopping
    from lightning.pytorch.callbacks.early_stopping import EarlyStoppingReason

    early_stopping = EarlyStopping(monitor="val_loss", patience=3)
    trainer = Trainer(callbacks=[early_stopping])
    trainer.fit(model)

    # Check why training stopped
    if early_stopping.stopping_reason == EarlyStoppingReason.PATIENCE_EXHAUSTED:
        print("Training stopped due to patience exhaustion")
    elif early_stopping.stopping_reason == EarlyStoppingReason.STOPPING_THRESHOLD:
        print("Training stopped due to reaching stopping threshold")
    elif early_stopping.stopping_reason == EarlyStoppingReason.NOT_STOPPED:
        print("Training completed normally without early stopping")

    # Access human-readable message
    if early_stopping.stopping_reason_message:
        print(f"Details: {early_stopping.stopping_reason_message}")

The available stopping reasons are:

* ``NOT_STOPPED``: Training completed normally without early stopping
* ``STOPPING_THRESHOLD``: Training stopped because the monitored metric reached the stopping threshold
* ``DIVERGENCE_THRESHOLD``: Training stopped because the monitored metric exceeded the divergence threshold
* ``PATIENCE_EXHAUSTED``: Training stopped because the metric didn't improve for the specified patience
* ``NON_FINITE_METRIC``: Training stopped because the monitored metric became NaN or infinite

In case you need early stopping in a different part of training, subclass :class:`~lightning.pytorch.callbacks.early_stopping.EarlyStopping` and change where it is called:

.. testcode::

    class MyEarlyStopping(EarlyStopping):
        def on_validation_end(self, trainer, pl_module):
            # override this to disable early stopping at the end of val loop
            pass

        def on_train_end(self, trainer, pl_module):
            # instead, do it at the end of training loop
            self._run_early_stopping_check(trainer)

.. note::
    The :class:`~lightning.pytorch.callbacks.early_stopping.EarlyStopping` callback runs at the end of every validation epoch by default. However, the frequency of validation can be modified by setting various parameters in the :class:`~lightning.pytorch.trainer.trainer.Trainer`, for example :paramref:`~lightning.pytorch.trainer.trainer.Trainer.check_val_every_n_epoch` and :paramref:`~lightning.pytorch.trainer.trainer.Trainer.val_check_interval`. Note that the ``patience`` parameter counts the number of validation checks with no improvement, and not the number of training epochs. Therefore, with parameters ``check_val_every_n_epoch=10`` and ``patience=3``, the trainer will perform at least 40 training epochs before being stopped.