docs/neox/utils/trainer.html
from typing import Optional, Set, List

import torch.nn as nn
import torch.optim
import torch.utils.data
from torch.cuda import amp
from torch.cuda.amp import GradScaler

from labml import monit, tracker
from labml.configs import BaseConfigs, option
from labml_nn.neox.utils.finetune import FineTuner
model is the model to train
Returns a list of parameters for training
def get_trainable_params(model: nn.Module):
Get all parameters
    params = list(model.parameters())
Filter parameters that require gradients
    trainable_params = [p for p in params if p.requires_grad]

    return trainable_params
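For example, only parameters that still require gradients are handed to the optimizer. A minimal usage sketch, assuming a made-up two-layer model with its first layer frozen:

import torch.nn as nn
import torch.optim

# Hypothetical stand-in model, for illustration only
model = nn.Sequential(nn.Linear(16, 32), nn.Linear(32, 4))

# Freeze the first layer; its parameters no longer require gradients
for p in model[0].parameters():
    p.requires_grad = False

# Only the second layer's parameters are returned
params = get_trainable_params(model)
optimizer = torch.optim.Adam(params, lr=3e-4)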
class TrainerConf(BaseConfigs):
    model: nn.Module
    layers: List[nn.Module]
    optimizer: torch.optim.Optimizer = 'Adam'
    train_loader: torch.utils.data.DataLoader
    valid_loader: Optional[torch.utils.data.DataLoader] = None
    device: torch.device = torch.device('cuda:0')
    scaler: Optional[GradScaler] = 'Default'
    is_amp: bool = True
    dtype: torch.dtype = torch.float16

    is_clone_layers: bool = True

    loss_func: nn.Module = nn.CrossEntropyLoss()
    checkpoints_per_epoch: int = 0
    samples_per_epoch: int = 0

    grad_norm: Optional[float] = 1.0
    learning_rate: float = 3e-4
    max_seq_len: int = 1024
    batch_size: int = 64
    epochs: int = 16

    n_gpus: int = torch.cuda.device_count()

    filter_layers: Optional[Set] = None
sample is the sample
dataset_split is train/valid
Returns the loss, output and the target
    def get_loss(self, sample, dataset_split: str):
        data, target = sample
Forward pass
        with monit.section('Forward pass'):
            output = self.model(data.to(self.device))
Move targets to the same device as output
        target = target.to(output.device)
Calculate loss
        loss = self.loss_func(output.view(target.numel(), -1), target.view(-1))

        return loss, output, target
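The view calls flatten the batch and sequence dimensions so that cross entropy sees one row of logits per token. A minimal sketch with made-up shapes (batch 2, sequence length 8, vocabulary 100), assuming nothing beyond standard PyTorch:

import torch
import torch.nn as nn

output = torch.randn(2, 8, 100)            # [batch, seq_len, n_vocab] logits
target = torch.randint(0, 100, (2, 8))     # [batch, seq_len] token ids

loss_func = nn.CrossEntropyLoss()
# output.view(target.numel(), -1) -> [16, 100]; target.view(-1) -> [16]
loss = loss_func(output.view(target.numel(), -1), target.view(-1))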
    def train(self):
        for epoch in monit.loop(self.epochs):
            self.train_epoch()
            tracker.new_line()
    def sample(self, idx):
        pass
    def save_checkpoint(self, idx):
        pass
    def get_iterators(self):
Create the list of iterators to mix during the epoch
        iterators = [('train', self.train_loader)]
        if self.valid_loader is not None:
            iterators.append(('valid', self.valid_loader))

        if self.samples_per_epoch > 0:
            iterators.append((self.sample, [i for i in range(self.samples_per_epoch)]))

        if self.checkpoints_per_epoch > 0:
            iterators.append((self.save_checkpoint, [i for i in range(self.checkpoints_per_epoch)]))

        return iterators
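monit.mix then interleaves these named iterators across the epoch. The toy generator below only illustrates the idea of merging (name, iterable) pairs into a single (name, item) stream; it is not labml's implementation and ignores scheduling and progress reporting:

def toy_mix(*iterators):
    # Round-robin over (name, iterable) pairs, yielding (name, item)
    active = [(name, iter(it)) for name, it in iterators]
    while active:
        remaining = []
        for name, it in active:
            try:
                item = next(it)
            except StopIteration:
                continue
            yield name, item
            remaining.append((name, it))
        active = remaining

for split_name, item in toy_mix(('train', range(3)), ('valid', range(2))):
    print(split_name, item)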
    def train_epoch(self):
Set the model to training mode
        self.model.train()

        iterators = self.get_iterators()
        for split_name, sample in monit.mix(1024, *iterators):
            if split_name == 'train':
Set gradients to zero
                self.optimizer.zero_grad()
                tracker.add_global_step()

            with torch.set_grad_enabled(split_name == 'train'):
                if self.is_amp:
Forward pass
                    with amp.autocast():
                        loss, output, target = self.get_loss(sample, split_name)
                else:
                    loss, output, target = self.get_loss(sample, split_name)
Get predictions
                pred = output.argmax(dim=-1)
Calculate accuracy
                accuracy = pred.eq(target).sum().item() / (target != -100).sum()

                tracker.add({f'loss.{split_name}': loss, f'acc.{split_name}': accuracy * 100})

            if split_name == 'train':
                if self.scaler is not None:
Scale the loss when using a grad scaler
                    loss = self.scaler.scale(loss)
tracker.add({'loss.scaled': loss})
                with monit.section('Backward pass'):
                    loss.backward()
Optimize
                with monit.section('Optimize'):
                    if self.scaler is None:
                        self.optimizer.step()
                    else:
                        self.scaler.unscale_(self.optimizer)
                        if self.grad_norm is not None:
                            torch.nn.utils.clip_grad_norm_(get_trainable_params(self.model), self.grad_norm)
                        self.scaler.step(self.optimizer)
                        self.scaler.update()

                tracker.save()
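The scaler branch above follows the standard PyTorch AMP recipe: scale the loss before backward, unscale before gradient clipping, then step and update. A minimal standalone sketch of the same pattern, with a placeholder model, data and optimizer (assumes a CUDA device):

import torch
import torch.nn as nn
from torch.cuda import amp

model = nn.Linear(16, 4).cuda()                    # placeholder model
optimizer = torch.optim.Adam(model.parameters())   # placeholder optimizer
scaler = amp.GradScaler()
data = torch.randn(8, 16, device='cuda')
target = torch.randint(0, 4, (8,), device='cuda')

optimizer.zero_grad()
with amp.autocast():
    loss = nn.functional.cross_entropy(model(data), target)
scaler.scale(loss).backward()
scaler.unscale_(optimizer)                         # unscale before clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
scaler.step(optimizer)
scaler.update()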
@option(TrainerConf.optimizer, 'Adam')
def adam_optimizer(c: TrainerConf):
    if c.dtype == torch.float32:
        return torch.optim.Adam(get_trainable_params(c.model), lr=c.learning_rate)
    elif c.dtype == torch.float16:
        from labml_nn.optimizers.adam_fp16 import AdamFP16
        return AdamFP16(get_trainable_params(c.model), lr=c.learning_rate)
    else:
        raise NotImplementedError()


@option(TrainerConf.optimizer, 'SGD')
def sgd_optimizer(c: TrainerConf):
    return torch.optim.SGD(get_trainable_params(c.model), lr=c.learning_rate)


@option(TrainerConf.scaler, 'Default')
def grad_scaler(c: TrainerConf):
    if not c.is_amp:
        return None

    if c.dtype == torch.float16:
        from labml_nn.optimizers.adam_fp16 import GradScalerFP16
        return GradScalerFP16()
    else:
        return GradScaler()


class PipelineParallelTrainerConf(TrainerConf):
    is_checkpointing: bool = False
    chunks: int

    fine_tuner: FineTuner
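Additional optimizers can be registered the same way and then selected by name. A minimal sketch that adds a hypothetical 'AdamW' option (not part of the original file), following the same pattern as the 'Adam' and 'SGD' options above:

@option(TrainerConf.optimizer, 'AdamW')
def adamw_optimizer(c: TrainerConf):
    # Hypothetical extra option, for illustration only
    return torch.optim.AdamW(get_trainable_params(c.model), lr=c.learning_rate)

Setting the optimizer config value to the string 'AdamW' should then resolve to this function, the same way the 'Adam' default resolves to adam_optimizer.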