docs/source-pytorch/common/hooks.rst
########################## Hooks in PyTorch Lightning ##########################
Hooks in Pytorch Lightning allow you to customize the training, validation, and testing logic of your models. They provide a way to insert custom behavior at specific points during the training process without modifying the core training loop. There are several categories of hooks available in PyTorch Lightning:
Nearly all hooks can be implemented in three places within your code:
Importantly, because logic can be place in the same hook but in different places the call order of hooks is in important to understand. The following order is always used:
LightningModule.. testcode::
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.demos import BoringModel
class MyModel(BoringModel):
def on_train_start(self):
print("Model: Training is starting!")
class MyCallback(Callback):
def on_train_start(self, trainer, pl_module):
print("Callback: Training is starting!")
model = MyModel()
callback = MyCallback()
trainer = Trainer(callbacks=[callback], logger=False, max_epochs=1)
trainer.fit(model)
.. testoutput:: :hide: :options: +ELLIPSIS, +NORMALIZE_WHITESPACE
┏━━━┳━━━━━━━┳━━━━━━━━┳━━━━━━━━┳━━━━━━━┳━━━━━━━┓
┃ ┃ Name ┃ Type ┃ Params ┃ Mode ┃ FLOPs ┃
┡━━━╇━━━━━━━╇━━━━━━━━╇━━━━━━━━╇━━━━━━━╇━━━━━━━┩
│ 0 │ layer │ Linear │ 66 │ train │ 0 │
└───┴───────┴────────┴────────┴───────┴───────┘
...
Callback: Training is starting!
Model: Training is starting!
Epoch 0/0 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 64/64 ...
.. note:: There are a few exceptions to this pattern:
LightningModule, then monitoring callbacksLightningModule are calledLightningModule or StrategyTraining Loop Hook Order
The following diagram shows the execution order of hooks during a typical training loop e.g. calling trainer.fit(),
with the source of each hook indicated:
.. code-block:: text
Training Process Flow:
trainer.fit()
│
├── setup(stage="fit")
│ ├── [LightningDataModule]
│ ├── [Callbacks]
│ ├── [LightningModule]
│ ├── [LightningModule.configure_shared_model()]
│ ├── [LightningModule.configure_model()]
│ ├── Strategy.restore_checkpoint_before_setup
│ │ ├── [LightningModule.on_load_checkpoint()]
│ │ ├── [LightningModule.load_state_dict()]
│ │ ├── [LightningDataModule.load_state_dict()]
│ │ ├── [Callbacks.on_load_checkpoint()]
│ │ └── [Callbacks.load_state_dict()]
│ └── [Strategy]
│
├── on_fit_start()
│ ├── [Callbacks]
│ └── [LightningModule]
│
├── Strategy.restore_checkpoint_after_setup
│ ├── [LightningModule.on_load_checkpoint()]
│ ├── [LightningModule.load_state_dict()]
│ ├── [LightningDataModule.load_state_dict()]
│ ├── [Callbacks.on_load_checkpoint()]
│ └── [Callbacks.load_state_dict()]
│
├── on_sanity_check_start()
│ ├── [Callbacks]
│ ├── [LightningModule]
│ └── [Strategy]
│ ├── on_validation_start()
│ │ ├── [Callbacks]
│ │ ├── [LightningModule]
│ │ └── [Strategy]
│ ├── on_validation_epoch_start()
│ │ ├── [Callbacks]
│ │ ├── [LightningModule]
│ │ └── [Strategy]
│ │ ├── [for each validation batch]
│ │ │ ├── on_validation_batch_start()
│ │ │ │ ├── [Callbacks]
│ │ │ │ ├── [LightningModule]
│ │ │ │ └── [Strategy]
│ │ │ └── on_validation_batch_end()
│ │ │ ├── [Callbacks]
│ │ │ ├── [LightningModule]
│ │ │ └── [Strategy]
│ │ └── [end validation batches]
│ ├── on_validation_epoch_end()
│ │ ├── [Callbacks]
│ │ ├── [LightningModule]
│ │ └── [Strategy]
│ └── on_validation_end()
│ ├── [Callbacks]
│ ├── [LightningModule]
│ └── [Strategy]
├── on_sanity_check_end()
│ ├── [Callbacks]
│ ├── [LightningModule]
│ └── [Strategy]
│
├── on_train_start()
│ ├── [Callbacks]
│ ├── [LightningModule]
│ └── [Strategy]
│
├── [Training Epochs Loop]
│ │
│ ├── on_train_epoch_start()
│ │ ├── [Callbacks]
│ │ └── [LightningModule]
│ │
│ ├── [Training Batches Loop]
│ │ │
│ │ ├── on_train_batch_start()
│ │ │ ├── [Callbacks]
│ │ │ ├── [LightningModule]
│ │ │ └── [Strategy]
│ │ │
│ │ ├── [Forward Pass - training_step()]
│ │ │ └── [Strategy only]
│ │ │
│ │ ├── on_before_zero_grad()
│ │ │ ├── [Callbacks]
│ │ │ └── [LightningModule]
│ │ │
│ │ ├── optimizer_zero_grad()
│ │ │ └── [LightningModule only - optimizer_zero_grad()]
│ │ │
│ │ ├── [Backward Pass - Strategy.backward()]
│ │ │ ├── on_before_backward()
│ │ │ │ ├── [Callbacks]
│ │ │ │ └── [LightningModule]
│ │ │ ├── LightningModule.backward()
│ │ │ └── on_after_backward()
│ │ │ ├── [Callbacks]
│ │ │ └── [LightningModule]
│ │ │
│ │ ├── on_before_optimizer_step()
│ │ │ ├── [Callbacks]
│ │ │ └── [LightningModule]
│ │ │
│ │ ├── [Optimizer Step]
│ │ │ └── [LightningModule only - optimizer_step()]
│ │ │
│ │ └── on_train_batch_end()
│ │ ├── [Callbacks]
│ │ └── [LightningModule]
│ │
│ │ [Optional: Validation during training]
│ │ ├── on_validation_start()
│ │ │ ├── [Callbacks]
│ │ │ ├── [LightningModule]
│ │ │ └── [Strategy]
│ │ ├── on_validation_epoch_start()
│ │ │ ├── [Callbacks]
│ │ │ ├── [LightningModule]
│ │ │ └── [Strategy]
│ │ │ ├── [for each validation batch]
│ │ │ │ ├── on_validation_batch_start()
│ │ │ │ │ ├── [Callbacks]
│ │ │ │ │ ├── [LightningModule]
│ │ │ │ │ └── [Strategy]
│ │ │ │ └── on_validation_batch_end()
│ │ │ │ ├── [Callbacks]
│ │ │ │ ├── [LightningModule]
│ │ │ │ └── [Strategy]
│ │ │ └── [end validation batches]
│ │ ├── on_validation_epoch_end()
│ │ │ ├── [Callbacks]
│ │ │ ├── [LightningModule]
│ │ │ └── [Strategy]
│ │ └── on_validation_end()
│ │ ├── [Callbacks]
│ │ ├── [LightningModule]
│ │ └── [Strategy]
│ │
│ └── on_train_epoch_end() **SPECIAL CASE**
│ ├── [Callbacks - Non-monitoring only]
│ ├── [LightningModule]
│ └── [Callbacks - Monitoring only]
│
├── [End Training Epochs]
│
├── on_train_end()
│ ├── [Callbacks]
│ ├── [LightningModule]
│ └── [Strategy]
│
└── teardown(stage="fit")
├── [Strategy]
├── on_fit_end()
│ ├── [Callbacks]
│ └── [LightningModule]
├── [LightningDataModule]
├── [Callbacks]
└── [LightningModule]
Testing Loop Hook Order
When running tests with trainer.test():
.. code-block:: text
trainer.test()
│
├── setup(stage="test")
│ └── [Callbacks only]
├── on_test_start()
│ ├── [Callbacks]
│ ├── [LightningModule]
│ └── [Strategy]
│
├── [Test Epochs Loop]
│ │
│ ├── on_test_epoch_start()
│ │ ├── [Callbacks]
│ │ ├── [LightningModule]
│ │ └── [Strategy]
│ │
│ ├── [Test Batches Loop]
│ │ │
│ │ ├── on_test_batch_start()
│ │ │ ├── [Callbacks]
│ │ │ ├── [LightningModule]
│ │ │ └── [Strategy]
│ │ │
│ │ └── on_test_batch_end()
│ │ ├── [Callbacks]
│ │ ├── [LightningModule]
│ │ └── [Strategy]
│ │
│ └── on_test_epoch_end()
│ ├── [Callbacks]
│ ├── [LightningModule]
│ └── [Strategy]
│
├── on_test_end()
│ ├── [Callbacks]
│ ├── [LightningModule]
│ └── [Strategy]
└── teardown(stage="test")
└── [Callbacks only]
Prediction Loop Hook Order
When running predictions with trainer.predict():
.. code-block:: text
trainer.predict()
│
├── setup(stage="predict")
│ └── [Callbacks only]
├── on_predict_start()
│ ├── [Callbacks]
│ ├── [LightningModule]
│ └── [Strategy]
│
├── [Prediction Epochs Loop]
│ │
│ ├── on_predict_epoch_start()
│ │ ├── [Callbacks]
│ │ └── [LightningModule]
│ │
│ ├── [Prediction Batches Loop]
│ │ │
│ │ ├── on_predict_batch_start()
│ │ │ ├── [Callbacks]
│ │ │ └── [LightningModule]
│ │ │
│ │ └── on_predict_batch_end()
│ │ ├── [Callbacks]
│ │ └── [LightningModule]
│ │
│ └── on_predict_epoch_end()
│ ├── [Callbacks]
│ └── [LightningModule]
│
├── on_predict_end()
│ ├── [Callbacks]
│ ├── [LightningModule]
│ └── [Strategy]
└── teardown(stage="predict")
└── [Callbacks only]