docs/source-pytorch/extensions/strategy.rst
###################
What is a Strategy?
###################
Strategy controls the model distribution across training, evaluation, and prediction to be used by the :doc:`Trainer <../common/trainer>`. It can be selected by passing a
strategy alias (``"ddp"``, ``"ddp_spawn"``, ``"deepspeed"``, and so on) or a custom strategy instance to the ``strategy`` parameter of the Trainer.
The Strategy in PyTorch Lightning handles the following responsibilities:
- Launch and teardown of training processes (if applicable).
- Setup communication between processes (NCCL, GLOO, MPI, and so on).
- Provide a unified communication interface for reduction, broadcast, and so on.
- Owns the :class:`~lightning.pytorch.core.LightningModule`.
- Handles/owns optimizers and schedulers.

Strategy is a composition of one :doc:`Accelerator <../extensions/accelerator>`, one :ref:`Precision Plugin <extensions/plugins:Precision Plugins>`, a :ref:`CheckpointIO <extensions/plugins:CheckpointIO Plugins>`
plugin and other optional plugins such as the :ref:`ClusterEnvironment <extensions/plugins:Cluster Environments>`.
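To make the composition concrete, here is a minimal sketch that builds a DDP strategy from an explicit accelerator and plugins; exact class names and import paths may differ slightly between Lightning versions:

.. code-block:: python

    from lightning.pytorch import Trainer
    from lightning.pytorch.accelerators import CUDAAccelerator
    from lightning.pytorch.plugins import MixedPrecision, TorchCheckpointIO
    from lightning.pytorch.plugins.environments import LightningEnvironment
    from lightning.pytorch.strategies import DDPStrategy

    strategy = DDPStrategy(
        accelerator=CUDAAccelerator(),  # which hardware each process runs on
        precision_plugin=MixedPrecision("16-mixed", "cuda"),  # how forward/backward are autocast
        checkpoint_io=TorchCheckpointIO(),  # how checkpoints are saved and loaded
        cluster_environment=LightningEnvironment(),  # how processes discover each other
    )
    trainer = Trainer(strategy=strategy, devices=2)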
.. image:: https://pl-public-data.s3.amazonaws.com/docs/static/images/strategies/overview.jpeg
   :alt: Illustration of the Strategy as a composition of the Accelerator and several plugins
We expose Strategies mainly for expert users that want to extend Lightning for new hardware support or new distributed backends (e.g. a backend not yet supported by `PyTorch <https://pytorch.org/docs/stable/distributed.html#backends>`_ itself).
*****************************
Selecting a Built-in Strategy
*****************************
Built-in strategies can be selected in two ways.
1. Pass the short-hand name to the ``strategy`` Trainer argument
2. Import a Strategy from :mod:`lightning.pytorch.strategies`, instantiate it and pass it to the ``strategy`` Trainer argument

The latter allows you to configure further options on the specific strategy. Here are some examples:
.. code-block:: python

    from lightning.pytorch import Trainer
    from lightning.pytorch.strategies import DDPStrategy

    # Training with the DistributedDataParallel strategy on 4 GPUs
    trainer = Trainer(strategy="ddp", accelerator="gpu", devices=4)

    # Training with the DistributedDataParallel strategy on 4 GPUs, with options configured
    trainer = Trainer(strategy=DDPStrategy(static_graph=True), accelerator="gpu", devices=4)

    # Training with the DDP Spawn strategy using auto accelerator selection
    trainer = Trainer(strategy="ddp_spawn", accelerator="auto", devices=4)

    # Training with the DeepSpeed strategy on available GPUs
    trainer = Trainer(strategy="deepspeed", accelerator="gpu", devices="auto")

    # Training with the DDP strategy using 3 CPU processes
    trainer = Trainer(strategy="ddp", accelerator="cpu", devices=3)

    # Training with the DDP Spawn strategy on 8 TPU cores
    trainer = Trainer(strategy="ddp_spawn", accelerator="tpu", devices=8)
The table below lists all relevant strategies available in Lightning with their corresponding short-hand name:
.. list-table:: Strategy Classes and Nicknames
   :widths: 20 20 20
   :header-rows: 1

   * - Name
     - Class
     - Description
   * - fsdp
     - :class:`~lightning.pytorch.strategies.FSDPStrategy`
     - Strategy for Fully Sharded Data Parallel training. :doc:`Learn more. <../advanced/model_parallel/fsdp>`
   * - ddp
     - :class:`~lightning.pytorch.strategies.DDPStrategy`
     - Strategy for multi-process single-device training across one or multiple nodes. :ref:`Learn more. <accelerators/gpu_intermediate:Distributed Data Parallel>`
   * - ddp_spawn
     - :class:`~lightning.pytorch.strategies.DDPStrategy`
     - Same as "ddp" but launches processes using the :func:`torch.multiprocessing.spawn` method and joins processes after training finishes. :ref:`Learn more. <accelerators/gpu_intermediate:Distributed Data Parallel Spawn>`
   * - deepspeed
     - :class:`~lightning.pytorch.strategies.DeepSpeedStrategy`
     - Runs training with the DeepSpeed library, with optimizations for large billion-parameter models. :doc:`Learn more. <../advanced/model_parallel/deepspeed>`
   * - xla
     - :class:`~lightning.pytorch.strategies.XLAStrategy`
     - Strategy for training on multiple TPU devices using the :func:`torch_xla.distributed.xla_multiprocessing.spawn` method. :doc:`Learn more. <../accelerators/tpu>`
   * - single_xla
     - :class:`~lightning.pytorch.strategies.SingleXLAStrategy`
     - Strategy for training on a single XLA device, like TPUs. :doc:`Learn more. <../accelerators/tpu>`
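As a quick check of the mapping above, the short-hand name resolves to the corresponding strategy class when the Trainer is constructed, and the resolved instance is available as ``trainer.strategy``. A minimal sketch:

.. code-block:: python

    from lightning.pytorch import Trainer
    from lightning.pytorch.strategies import DDPStrategy

    # "ddp" resolves to DDPStrategy during Trainer construction
    trainer = Trainer(strategy="ddp", accelerator="cpu", devices=2)
    assert isinstance(trainer.strategy, DDPStrategy)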
**********************
Third-party Strategies
**********************

There are powerful third-party strategies that integrate well with Lightning but aren't maintained as part of the ``lightning`` package.
Check out the gallery :doc:`here <../integrations/strategies/index>`.
************************
Create a Custom Strategy
************************
Every strategy in Lightning is a subclass of one of the main base classes: :class:`~lightning.pytorch.strategies.Strategy`, :class:`~lightning.pytorch.strategies.SingleDeviceStrategy` or :class:`~lightning.pytorch.strategies.ParallelStrategy`.
.. image:: https://pl-public-data.s3.amazonaws.com/docs/static/images/strategies/hierarchy.jpeg
   :alt: Strategy base classes
As an expert user, you may choose to extend either an existing built-in Strategy or create a completely new one by subclassing the base classes.
.. code-block:: python

    from lightning.pytorch.strategies import DDPStrategy


    class CustomDDPStrategy(DDPStrategy):
        def configure_ddp(self):
            # wrap the model with a custom DistributedDataParallel implementation
            self.model = MyCustomDistributedDataParallel(
                self.model,
                device_ids=...,
            )

        def setup(self, trainer):
            # you can access the accelerator and plugins directly
            self.accelerator.setup(trainer)
            self.precision_plugin.connect(...)
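If you prefer to start from one of the base classes instead of an existing strategy, a minimal sketch could look like the following. The ``LoggedSingleDeviceStrategy`` name and its behavior are purely illustrative, not part of Lightning:

.. code-block:: python

    from lightning.pytorch.strategies import SingleDeviceStrategy


    class LoggedSingleDeviceStrategy(SingleDeviceStrategy):
        """Illustrative strategy that logs its device during setup."""

        def setup(self, trainer):
            # runs once per process before fit/validate/test/predict begins
            print(f"Setting up training on {self.root_device}")
            super().setup(trainer)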
The custom strategy can then be passed into the ``Trainer`` directly via the ``strategy`` parameter.
.. code-block:: python

    # custom strategy
    trainer = Trainer(strategy=CustomDDPStrategy())
Since the strategy also hosts the Accelerator and various plugins, you can customize all of them to work together as you like:
.. code-block:: python

    # custom strategy, with new accelerator and plugins
    accelerator = MyAccelerator()
    precision_plugin = MyPrecisionPlugin()
    strategy = CustomDDPStrategy(accelerator=accelerator, precision_plugin=precision_plugin)
    trainer = Trainer(strategy=strategy)