.. _elastic_train_script:

Train script
-------------

If your train script works with ``torch.distributed.launch`` it will continue
working with ``torchrun`` with these differences:
1. No need to manually pass ``RANK``, ``WORLD_SIZE``,
   ``MASTER_ADDR``, and ``MASTER_PORT``.
2. ``rdzv_backend`` and ``rdzv_endpoint`` can be provided. For most users
   this will be set to ``c10d`` (see `rendezvous <rendezvous.html>`_). The default
   ``rdzv_backend`` creates a non-elastic rendezvous where ``rdzv_endpoint`` holds
   the master address.
3. Make sure you have ``load_checkpoint(path)`` and
   ``save_checkpoint(path)`` logic in your script. When any number of
   workers fail, all workers are restarted with the same program
   arguments, so you will lose progress up to the most recent checkpoint
   (see `elastic launch <run.html>`_).
4. The ``use_env`` flag has been removed. If you were obtaining the local rank
   by parsing the ``--local-rank`` option, get it from the environment
   variable ``LOCAL_RANK`` instead (e.g. ``int(os.environ["LOCAL_RANK"])``).
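
The values from points 1 and 4 can be read straight from the environment
inside the train script. A minimal sketch (the ``dist_env`` helper name is
ours for illustration, not part of any torch API):

.. code-block:: python

  import os

  def dist_env():
      # torchrun exports these variables into every worker's
      # environment, so no command-line parsing is needed.
      return {
          "rank": int(os.environ["RANK"]),
          "world_size": int(os.environ["WORLD_SIZE"]),
          "local_rank": int(os.environ["LOCAL_RANK"]),
      }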
Below is an expository example of a training script that checkpoints on each
epoch; the worst-case progress lost on failure is therefore one full epoch's
worth of training.
.. code-block:: python

  def main():
      args = parse_args(sys.argv[1:])
      state = load_checkpoint(args.checkpoint_path)
      initialize(state)

      # torch.distributed.run ensures that this will work
      # by exporting all the env vars needed to initialize the process group
      torch.distributed.init_process_group(backend=args.backend)

      for epoch in range(state.epoch, state.total_num_epochs):
          for batch in iter(state.dataset):
              train(batch, state.model)

          state.epoch += 1
          save_checkpoint(state)
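
The checkpoint helpers in the example are left abstract. One possible shape,
sketched here with ``pickle`` and a plain dict for brevity (a real training
script would typically use ``torch.save``/``torch.load`` and persist model and
optimizer state as well):

.. code-block:: python

  import os
  import pickle

  def save_checkpoint(state, path):
      # Write to a temporary file and rename atomically, so a worker
      # killed mid-write never leaves a corrupt checkpoint behind for
      # the restarted workers to load.
      tmp = path + ".tmp"
      with open(tmp, "wb") as f:
          pickle.dump(state, f)
      os.replace(tmp, path)

  def load_checkpoint(path):
      # First run: no checkpoint exists yet, start from epoch 0.
      if not os.path.exists(path):
          return {"epoch": 0}
      with open(path, "rb") as f:
          return pickle.load(f)

The atomic-rename step matters under torchelastic's failure model: a restart
can happen at any point, and every worker will reload the same file.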
For concrete examples of torchelastic-compliant train scripts, visit
our `examples <examples.html>`_ page.