docs/source-pytorch/common/early_stopping.rst
.. testsetup:: *

    from lightning.pytorch.callbacks.early_stopping import EarlyStopping, EarlyStoppingReason
    from lightning.pytorch import Trainer, LightningModule

.. _early_stopping:

##############
Early Stopping
##############

.. video:: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/yt/Trainer+flags+19-+early+stopping_1.mp4
    :poster: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/yt_thumbs/thumb_earlystop.png
    :width: 400
    :muted:

***********************
Stopping an Epoch Early
***********************

You can stop and skip the rest of the current epoch early by overriding :meth:`~lightning.pytorch.core.hooks.ModelHooks.on_train_batch_start` to return ``-1`` when some condition is met.

If you do this repeatedly, for every epoch you had originally requested, then this will stop your entire training.

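A minimal sketch of this pattern (the batch-index cutoff is a hypothetical condition, chosen only for illustration):

.. code-block:: python

    class LitModel(LightningModule):
        def on_train_batch_start(self, batch, batch_idx):
            # Hypothetical condition: end the epoch after 100 batches.
            # Returning -1 from this hook skips the rest of the current epoch.
            if batch_idx >= 100:
                return -1
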
**********************
EarlyStopping Callback
**********************

The :class:`~lightning.pytorch.callbacks.early_stopping.EarlyStopping` callback can be used to monitor a metric and stop the training when no improvement is observed.

To enable it:

-   Import the :class:`~lightning.pytorch.callbacks.early_stopping.EarlyStopping` callback.
-   Log the metric you want to monitor using the :meth:`~lightning.pytorch.core.LightningModule.log` method.
-   Init the callback, and set ``monitor`` to the logged metric of your choice.
-   Set the ``mode`` based on whether the monitored metric should be minimized or maximized.
-   Pass the :class:`~lightning.pytorch.callbacks.early_stopping.EarlyStopping` callback to the :class:`~lightning.pytorch.trainer.trainer.Trainer` ``callbacks`` flag.

.. code-block:: python

    from lightning.pytorch.callbacks.early_stopping import EarlyStopping


    class LitModel(LightningModule):
        def validation_step(self, batch, batch_idx):
            loss = ...
            self.log("val_loss", loss)


    model = LitModel()
    trainer = Trainer(callbacks=[EarlyStopping(monitor="val_loss", mode="min")])
    trainer.fit(model)

You can customize the callback's behaviour by changing its parameters.

.. testcode::

    early_stop_callback = EarlyStopping(monitor="val_accuracy", min_delta=0.00, patience=3, verbose=False, mode="max")
    trainer = Trainer(callbacks=[early_stop_callback])

Additional parameters that stop training at extreme points:

-   ``stopping_threshold``: Stops training immediately once the monitored quantity reaches this threshold.
    It is useful when we know that going beyond a certain optimal value does not further benefit us.
-   ``divergence_threshold``: Stops training as soon as the monitored quantity becomes worse than this threshold.
    When reaching a value this bad, we believe the model cannot recover anymore and it is better to stop early and run with different initial conditions.
-   ``check_finite``: When turned on, it stops training if the monitored metric becomes NaN or infinite.
-   ``check_on_train_epoch_end``: When turned on, it checks the metric at the end of a training epoch. Use this only when you are monitoring a metric logged within
    training-specific hooks on the epoch level.

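For example, a sketch combining these parameters (the threshold values are arbitrary, chosen only for illustration):

.. testcode::

    early_stop_callback = EarlyStopping(
        monitor="val_loss",
        mode="min",
        stopping_threshold=0.05,  # illustrative: stop as soon as val_loss is this good
        divergence_threshold=10.0,  # illustrative: give up once val_loss gets this bad
        check_finite=True,  # stop if val_loss becomes NaN or infinite
    )
    trainer = Trainer(callbacks=[early_stop_callback])
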
After training completes, you can programmatically check why early stopping occurred using the ``stopping_reason``
attribute, which returns an :class:`~lightning.pytorch.callbacks.early_stopping.EarlyStoppingReason` enum value.

.. code-block:: python

    from lightning.pytorch.callbacks import EarlyStopping
    from lightning.pytorch.callbacks.early_stopping import EarlyStoppingReason

    early_stopping = EarlyStopping(monitor="val_loss", patience=3)
    trainer = Trainer(callbacks=[early_stopping])
    trainer.fit(model)

    # Check why training stopped
    if early_stopping.stopping_reason == EarlyStoppingReason.PATIENCE_EXHAUSTED:
        print("Training stopped due to patience exhaustion")
    elif early_stopping.stopping_reason == EarlyStoppingReason.STOPPING_THRESHOLD:
        print("Training stopped due to reaching stopping threshold")
    elif early_stopping.stopping_reason == EarlyStoppingReason.NOT_STOPPED:
        print("Training completed normally without early stopping")

    # Access human-readable message
    if early_stopping.stopping_reason_message:
        print(f"Details: {early_stopping.stopping_reason_message}")

The available stopping reasons are:

-   ``NOT_STOPPED``: Training completed normally without early stopping
-   ``STOPPING_THRESHOLD``: Training stopped because the monitored metric reached the stopping threshold
-   ``DIVERGENCE_THRESHOLD``: Training stopped because the monitored metric exceeded the divergence threshold
-   ``PATIENCE_EXHAUSTED``: Training stopped because the metric didn't improve for the specified patience
-   ``NON_FINITE_METRIC``: Training stopped because the monitored metric became NaN or infinite

In case you need early stopping in a different part of training, subclass :class:`~lightning.pytorch.callbacks.early_stopping.EarlyStopping`
and change where it is called:

.. testcode::

    class MyEarlyStopping(EarlyStopping):
        def on_validation_end(self, trainer, pl_module):
            # override this to disable early stopping at the end of val loop
            pass

        def on_train_end(self, trainer, pl_module):
            # instead, do it at the end of training loop
            self._run_early_stopping_check(trainer)

.. note::
    The :class:`~lightning.pytorch.callbacks.early_stopping.EarlyStopping` callback runs
    at the end of every validation epoch by default. However, the frequency of validation
    can be modified by setting various parameters in the :class:`~lightning.pytorch.trainer.trainer.Trainer`,
    for example :paramref:`~lightning.pytorch.trainer.trainer.Trainer.check_val_every_n_epoch`
    and :paramref:`~lightning.pytorch.trainer.trainer.Trainer.val_check_interval`.
    It must be noted that the ``patience`` parameter counts the number of
    validation checks with no improvement, and not the number of training epochs.
    Therefore, with parameters ``check_val_every_n_epoch=10`` and ``patience=3``, the trainer
    will perform at least 40 training epochs before being stopped.
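
    For instance, a sketch of that configuration (assuming a model that logs ``val_loss``):

    .. code-block:: python

        # Validation runs every 10 epochs, and patience counts validation checks,
        # so at least 4 checks (40 training epochs) occur before stopping can trigger.
        trainer = Trainer(
            check_val_every_n_epoch=10,
            callbacks=[EarlyStopping(monitor="val_loss", patience=3)],
        )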