docs/scaling/zero3/finetune_neox.html
This script trains the bias parameters of the GPT-NeoX model on multiple devices with Zero-DP (ZeRO stage 3) memory optimization.
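The core idea, as a minimal single-process sketch (illustrative only, not the labml_nn implementation): each rank permanently stores just a 1/world_size shard of every parameter, and the full tensor is rebuilt only for the moment it is needed.

import torch

world_size = 4
# A toy "parameter" with 12 elements
param = torch.arange(12, dtype=torch.float32)

# Shard it: rank r would permanently keep only chunk r
shards = list(param.chunk(world_size))

# Just before a layer runs, the shards are all-gathered to rebuild the
# full parameter (torch.distributed.all_gather in a real multi-process run)
full = torch.cat(shards)
assert torch.equal(full, param)
# After the layer runs, the full copy is dropped; only the shard remains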
import datetime

import torch
import torch.distributed

from labml import experiment, monit, tracker
from labml.configs import option
from labml.logger import inspect
from labml_nn.neox.samples.finetune import PipelineParallelTrainerConf
Use the Pipeline Parallel Trainer configurations and adapt them for the Zero3 memory optimizer.
class Configs(PipelineParallelTrainerConf):
    rank: int
    world_size: int
Set the optimizer for fine-tuning. Note that we pass the sharded parameters from get_trainable_chunk(); a conceptual sketch of what that returns follows the code.
@option(Configs.optimizer, 'Zero3Adam')
def _optimizer(c: Configs):
    from labml_nn.optimizers.adam_fp16 import AdamFP16
    return AdamFP16(c.model.get_trainable_chunk(), lr=c.learning_rate)
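get_trainable_chunk is provided by the Zero3-wrapped model from labml_nn.scaling.zero3. Conceptually, it hands the optimizer only this rank's shard of the trainable parameters, so the Adam state is sharded across ranks as well. A hypothetical sketch of that behavior (the function name comes from the source, but this body and sharding scheme are illustrative, not the labml_nn implementation):

import torch.nn as nn

def get_trainable_chunk(layer: nn.Module, rank: int, world_size: int) -> list:
    chunks = []
    for p in layer.parameters():
        # Skip frozen parameters; the fine-tuner marks only the biases as trainable
        if not p.requires_grad:
            continue
        # Keep only the flat shard this rank owns
        shard = p.detach().flatten().chunk(world_size)[rank]
        chunks.append(nn.Parameter(shard.clone()))
    return chunks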
Create the model with the Zero3 memory optimizer.

@option(Configs.model, 'Zero3')
def _model(c: Configs):
    from labml_nn.scaling.zero3 import Zero3Layer, Zero3Sequential
Access the fine-tuner first, to make sure it sets the trainable parameters.
    _ = c.fine_tuner
Wrap the layers with Zero3Layer (a toy illustration of this pattern follows the function).
    modules = []
    for m in monit.iterate('Zero3', c.layers):
        modules.append(Zero3Layer(m.to(c.device),
                                  c.rank, c.world_size, c.device, c.dtype))
Create a sequential model
    model = Zero3Sequential(modules)
    return model
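The gather-compute-release pattern that Zero3Layer follows can be illustrated with a single-process toy. ShardedLinear is a hypothetical class; the real Zero3Layer in labml_nn.scaling.zero3 all-gathers shards with torch.distributed across ranks and also handles the backward pass:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ShardedLinear(nn.Module):
    # Hypothetical, single-process stand-in for the Zero3Layer pattern
    def __init__(self, layer: nn.Linear, world_size: int):
        super().__init__()
        self.out_features, self.in_features = layer.weight.shape
        # In a real run each rank keeps one shard; here we keep all of
        # them so the example runs in one process
        self.shards = list(layer.weight.detach().flatten().chunk(world_size))
        self.bias = layer.bias

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 1. Rebuild the full weight (an all-gather across ranks in the real layer)
        w = torch.cat(self.shards).view(self.out_features, self.in_features)
        # 2. Run the layer with the full weight
        y = F.linear(x, w, self.bias)
        # 3. The full weight goes out of scope here; only the shards persist
        return y

layer = ShardedLinear(nn.Linear(8, 4), world_size=4)
print(layer(torch.randn(2, 8)).shape)  # torch.Size([2, 4])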
Run the training on the device with the given rank.

def main(rank: int, world_size: int, init_method: str = 'tcp://localhost:23456'):
Initialize PyTorch distributed process group
    with monit.section('Distributed'):
        torch.distributed.init_process_group('nccl',
                                             timeout=datetime.timedelta(seconds=30),
                                             init_method=init_method,
                                             rank=rank,
                                             world_size=world_size)
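The process group can equivalently be initialized from environment variables, which is standard torch.distributed usage (rank and world_size here are the arguments of main; the address values are placeholders):

import os

os.environ['MASTER_ADDR'] = 'localhost'  # placeholder for the rank-0 host
os.environ['MASTER_PORT'] = '23456'      # placeholder port
torch.distributed.init_process_group('nccl', init_method='env://',
                                     rank=rank, world_size=world_size)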
Set current device
    device = torch.device(f'cuda:{rank}')
    torch.cuda.set_device(device)
Create the experiment
    experiment.create(name='zero3_neox', writers={'screen', 'labml'},
                      distributed_world_size=world_size,
                      distributed_rank=rank)
Create configurations
    conf = Configs()
Load configurations
    experiment.configs(conf, {
        'model': 'Zero3',
        'optimizer': 'Zero3Adam',

        'device': device,
        'rank': rank,
        'world_size': world_size,

        'learning_rate': 3e-4,
        'max_seq_len': 128,
        'batch_size': 16,
    })
Start the experiment
    with experiment.start():
Initialize the model. Do this before the loop for cleaner logs.
        _ = conf.model
Train the model
        for epoch in monit.loop(conf.epochs):
            conf.train_epoch()
            tracker.new_line()
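To keep per-epoch snapshots, one hedged option is to save from a single rank inside this loop with plain torch.save; note that under Zero3 each rank's state_dict most likely holds only that rank's parameter shards, so a complete checkpoint would require gathering them first (the filename below is illustrative):

        for epoch in monit.loop(conf.epochs):
            conf.train_epoch()
            if rank == 0:
                # Saves rank 0's (sharded) view only; illustrative filename
                torch.save(conf.model.state_dict(), f'zero3_neox_epoch_{epoch}.pt')
            tracker.new_line()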
if __name__ == '__main__':
Log the machine configurations
    inspect([torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())])
    inspect(
        n_gpus=torch.cuda.device_count(),
        mpi=torch.distributed.is_mpi_available(),
        nccl=torch.distributed.is_nccl_available(),
    )

    n_gpu = torch.cuda.device_count()
Start a process for each GPU. You will need a separate launcher if you are using multiple machines; a sketch of one follows the code.
    torch.multiprocessing.spawn(main, args=(n_gpu,), nprocs=n_gpu, join=True)
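A hedged sketch of such a launcher: every machine spawns one process per local GPU and derives the global rank from its node index. NODE_RANK and N_NODES are assumed environment variables set on each machine, and MASTER_HOST is a placeholder for the rank-0 host. Note that main above uses the global rank as the CUDA device index, so for multiple machines it would also need the local rank to select the device.

import os

def launch(local_rank: int, node_rank: int, n_nodes: int, gpus_per_node: int):
    # Global rank across all machines
    global_rank = node_rank * gpus_per_node + local_rank
    world_size = n_nodes * gpus_per_node
    main(global_rank, world_size, init_method='tcp://MASTER_HOST:23456')

if __name__ == '__main__':
    node_rank = int(os.environ['NODE_RANK'])  # assumed, set per machine
    n_nodes = int(os.environ['N_NODES'])      # assumed, same on every machine
    n_gpu = torch.cuda.device_count()
    torch.multiprocessing.spawn(launch, args=(node_rank, n_nodes, n_gpu),
                                nprocs=n_gpu, join=True)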