Back to Annotated Deep Learning Paper Implementations

trainer.py

docs/helpers/trainer.html

latest · 13.7 KB
Original Source

home / helpers

View code on Github

#

import signal
import typing
from typing import Dict, List, Callable
from typing import Optional, Tuple, Any, Collection

import torch.optim
import torch.optim
import torch.utils.data
import torch.utils.data
from torch import nn

from labml import tracker, logger, monit
from labml.configs import BaseConfigs, meta_config, option
from labml.internal.monitor import Loop
from labml.logger import Text

from .device import DeviceConfigs
from .metrics import StateModule

#

19classTrainingLoopIterator(Collection):

#

20def\_\_init\_\_(self,start:int,total:int,step:Optional[int]):21self.step=step22self.total=total23self.start=start24self.i=None

#

26def\_\_iter\_\_(self):27self.i=None28returnself

#

30def\_\_next\_\_(self):31ifself.stepisnotNone:32ifself.iisNone:33self.i=self.start34else:35self.i+=self.step36else:37ifself.iisNone:38self.i=039else:40self.i+=14142ifself.i\>=self.total:43raiseStopIteration()4445ifself.stepisNone:46returntracker.get\_global\_step()47else:48returnself.i

#

50def\_\_len\_\_(self)-\>int:51ifself.stepisnotNone:52return(self.total-self.start)//self.step53else:54returnself.total

#

56def\_\_contains\_\_(self,x:object)-\>bool:57returnFalse

#

60classTrainingLoop:61\_iter:Optional[TrainingLoopIterator]62\_\_loop:Loop63\_\_signal\_received:Optional[Tuple[Any,Any]]

#

65def\_\_init\_\_(self,\*,66loop\_count:int,67loop\_step:Optional[int],68log\_new\_line\_interval:int,69log\_write\_interval:int,70is\_loop\_on\_interrupt:bool):71self.\_\_loop\_count=loop\_count72self.\_\_loop\_step=loop\_step73self.\_\_log\_new\_line\_interval=log\_new\_line\_interval74self.\_\_log\_write\_interval=log\_write\_interval75self.\_\_last\_write\_step=076self.\_\_last\_new\_line\_step=077self.\_\_last\_save\_step=078self.\_\_signal\_received=None79self.\_\_is\_loop\_on\_interrupt=is\_loop\_on\_interrupt80self.\_iter=None

#

82def\_\_iter\_\_(self):83self.\_iter=TrainingLoopIterator(tracker.get\_global\_step(),84self.\_\_loop\_count,85self.\_\_loop\_step)8687self.\_\_loop=monit.loop(typing.cast(Collection,self.\_iter))8889iter(self.\_\_loop)90try:91self.old\_handler=signal.signal(signal.SIGINT,self.\_\_handler)92exceptValueError:93pass94returnself

#

96@property97defidx(self):98ifnotself.\_iter:99return0100ifnotself.\_iter.i:101return0102ifself.\_\_loop\_stepisNone:103returnself.\_iter.i104returnself.\_iter.i/self.\_\_loop\_step105106def\_\_finish(self):107try:108signal.signal(signal.SIGINT,self.old\_handler)109exceptValueError:110pass111tracker.save()112tracker.new\_line()113114def\_\_next\_\_(self):115ifself.\_\_signal\_receivedisnotNone:116logger.log('\nKilling Loop.',Text.danger)117monit.finish\_loop()118self.\_\_finish()119raiseStopIteration("SIGINT")120121try:122global\_step=next(self.\_\_loop)123exceptStopIterationase:124self.\_\_finish()125raisee126127tracker.set\_global\_step(global\_step)128129ifglobal\_step-self.\_\_last\_write\_step\>=self.\_\_log\_write\_interval:130tracker.save()131self.\_\_last\_write\_step=global\_step132ifglobal\_step-self.\_\_last\_new\_line\_step\>=self.\_\_log\_new\_line\_interval:133tracker.new\_line()134self.\_\_last\_new\_line\_step=global\_step135136returnglobal\_step137138def\_\_handler(self,sig,frame):

#

Pass second interrupt without delaying

140ifself.\_\_signal\_receivedisnotNone:141logger.log('\nSIGINT received twice. Stopping...',Text.danger)142self.old\_handler(\*self.\_\_signal\_received)143return144145ifself.\_\_is\_loop\_on\_interrupt:

#

Store the interrupt signal for later

147self.\_\_signal\_received=(sig,frame)148logger.log('\nSIGINT received. Delaying KeyboardInterrupt.',Text.danger)149else:150self.\_\_finish()151logger.log('Killing loop...',Text.danger)152self.old\_handler(sig,frame)

#

154def\_\_str\_\_(self):155return"LabTrainingLoop"

#

This is a configurable training loop. You can extend this class for your configurations if it involves a training loop.

>>> for step in conf.training_loop: >>> ...

Arguments: loop_count (int): Total number of steps. Defaults to 10 . loop_step (int): Number of steps to increment per iteration. Defaults to 1 . log_new_line_interval (int): The interval (in steps) to print a new line to the screen. Defaults to 1 . log_write_interval (int): The interval (in steps) to call :func:labml.tracker.save . Defaults to 1 . is_loop_on_interrupt (bool): Whether to handle keyboard interrupts and wait until a iteration is complete. Defaults to False .

158classTrainingLoopConfigs(BaseConfigs):

#

176loop\_count:int=10177loop\_step:int=1178log\_new\_line\_interval:int=1179log\_write\_interval:int=1180is\_loop\_on\_interrupt:bool=False181182training\_loop:TrainingLoop

#

185@option(TrainingLoopConfigs.training\_loop)186def\_loop\_configs(c:TrainingLoopConfigs):187returnTrainingLoop(loop\_count=c.loop\_count,188loop\_step=c.loop\_step,189log\_new\_line\_interval=c.log\_new\_line\_interval,190log\_write\_interval=c.log\_write\_interval,191is\_loop\_on\_interrupt=c.is\_loop\_on\_interrupt)192193194meta\_config(TrainingLoopConfigs.loop\_step,195TrainingLoopConfigs.loop\_count,196TrainingLoopConfigs.log\_new\_line\_interval,197TrainingLoopConfigs.log\_write\_interval,198TrainingLoopConfigs.is\_loop\_on\_interrupt)199200201classModeState:202def\_\_init\_\_(self):203self.\_rollback\_stack=[]204205self.is\_train=False206self.is\_optimize=False207208def\_enter(self,mode:Dict[str,any]):209rollback={}210fork,vinmode.items():211ifvisNone:212continue213rollback[k]=getattr(self,k)214setattr(self,k,v)215216self.\_rollback\_stack.append(rollback)217218returnlen(self.\_rollback\_stack)219220def\_exit(self,n:int):221assertn==len(self.\_rollback\_stack)222223rollback=self.\_rollback\_stack[-1]224self.\_rollback\_stack.pop(-1)225226fork,vinrollback.items():227setattr(self,k,v)228229defupdate(self,\*,230is\_train:Optional[bool]=None,231is\_optimize:Optional[bool]=None):232returnMode(self,233is\_train=is\_train,234is\_optimize=is\_optimize)235236237classMode:238def\_\_init\_\_(self,mode:ModeState,\*\*kwargs:any):239self.mode=mode240self.update={}241fork,vinkwargs.items():242ifvisnotNone:243self.update[k]=v244245self.idx=-1246247def\_\_enter\_\_(self):248self.idx=self.mode.\_enter(self.update)249250def\_\_exit\_\_(self,exc\_type,exc\_val,exc\_tb):251self.mode.\_exit(self.idx)252253254classTrainer:255def\_\_init\_\_(self,\*,256name:str,257mode:ModeState,258data\_loader:torch.utils.data.DataLoader,259inner\_iterations:int,260state\_modules:List[StateModule],261is\_track\_time:bool,262step:Callable[[any,'BatchIndex'],None]):263self.is\_track\_time=is\_track\_time264self.mode=mode265self.name=name266self.step=step267self.state\_modules=state\_modules268self.\
_\_iterable=None269self.\_\_states=[sm.create\_state()forsminself.state\_modules]270self.inner\_iterations=inner\_iterations271self.data\_loader=data\_loader272self.\_batch\_index=BatchIndex(len(self.data\_loader),self.inner\_iterations)273274defset\_data\_loader(self,data\_loader:torch.utils.data.DataLoader):275self.data\_loader=data\_loader276self.\_batch\_index=BatchIndex(len(data\_loader),self.inner\_iterations)277self.\_\_iterable=None278279def\_\_call\_\_(self):280forsm,sinzip(self.state\_modules,self.\_\_states):281sm.set\_state(s)282283ifself.\_\_iterableisNoneorself.\_batch\_index.completed:284self.\_\_iterable=iter(self.data\_loader)285self.\_batch\_index.reset(len(self.data\_loader),self.inner\_iterations)286forsminself.state\_modules:287sm.on\_epoch\_start()288withtorch.set\_grad\_enabled(self.mode.is\_train):289self.\_\_iterate()290291ifself.\_batch\_index.completed:292forsminself.state\_modules:293sm.on\_epoch\_end()294295def\_\_iterate(self):296withmonit.section(self.name,is\_partial=True,is\_track=self.is\_track\_time):297ifself.\_batch\_index.idx==0:298monit.progress(0)299whilenotself.\_batch\_index.iteration\_completed:300batch=next(self.\_\_iterable)301302self.step(batch,self.\_batch\_index)303304self.\_batch\_index.step()305monit.progress(self.\_batch\_index.epoch\_progress)306307self.\_batch\_index.step\_inner()308309310classBatchIndex:311idx:int312total:int313iteration:int314total\_iterations:int315316def\_\_init\_\_(self,total:int,total\_iterations:int):317self.total\_iterations=total\_iterations318self.total=total319320defis\_interval(self,interval:int):321ifinterval\<=0:322returnFalse323ifself.idx+1==self.total:324returnTrue325else:326return(self.idx+1)%interval==0327328@property329defis\_last(self):330returnself.idx+1==self.total331332@property333defcompleted(self):334returnself.iteration\>=self.total\_iterations335336@property337defiteration\_completed(self):

#

// is important so that the last step happens on the last iteration

339returnself.idx\>=(self.iteration+1)\*self.total//self.total\_iterations

#

This is a configurable module that you can extend for experiments that involve a training and validation datasets (i.e. most DL experiments).

Arguments: epochs (int): Number of epochs to train on. Defaults to 10 . train_loader (torch.utils.data.DataLoader): Training data loader. valid_loader (torch.utils.data.DataLoader): Training data loader. inner_iterations (int): Number of times to switch between training and validation within an epoch. Defaults to 1 .

You can override init , step functions. There is also a sample function that you can override to generate samples ever time it switches between training and validation.

341@property342defepoch\_progress(self):343returnself.idx/self.total344345defstep(self):346self.idx+=1347348defstep\_inner(self):349self.iteration+=1350351defreset(self,total:int,total\_iterations:int):352self.total=total353self.total\_iterations=total\_iterations354self.idx=0355self.iteration=0356357358classTrainValidConfigs(TrainingLoopConfigs):

#

373state\_modules:List[StateModule]374375mode:ModeState376377epochs:int=10378379trainer:Trainer380validator:Trainer381train\_loader:torch.utils.data.DataLoader382valid\_loader:torch.utils.data.DataLoader383384loop\_count='\_data\_loop\_count'385loop\_step=None386387inner\_iterations:int=1388389is\_track\_time:bool=False

#

391definit(self):392pass

#

394defstep(self,batch:Any,batch\_idx:BatchIndex):395raiseNotImplementedError

#

397defrun\_step(self):398foriinrange(self.inner\_iterations):399withtracker.namespace('sample'):400self.sample()401withself.mode.update(is\_train=True):402withtracker.namespace('train'):403self.trainer()404ifself.validator:405withtracker.namespace('valid'):406self.validator()407tracker.save()

#

409defrun(self):410withmonit.section("Initialize"):411self.init()412\_=self.validator413\_=self.trainer414for\_inself.training\_loop:415self.run\_step()

#

417defsample(self):418pass

#

This is a configurable module that works for many standard DL experiments.

Arguments: model: A PyTorch model. optimizer: A PyTorch optimizer to update model. device: The device to train the model on. This defaults to a configurable device loss_function: A function to calculate the loss. This should accept model\_output, target as arguments. update_batches (int): Number of batches to accumulate before taking an optimizer step. Defaults to 1 . log_save_batches (int): How often to call :func:labml.tracker.save .

421@option(TrainValidConfigs.trainer)422def\_default\_trainer(c:TrainValidConfigs):423returnTrainer(name='Train',424mode=c.mode,425data\_loader=c.train\_loader,426inner\_iterations=c.inner\_iterations,427state\_modules=c.state\_modules,428is\_track\_time=c.is\_track\_time,429step=c.step)430431432@option(TrainValidConfigs.validator)433def\_default\_validator(c:TrainValidConfigs):434returnTrainer(name='Valid',435mode=c.mode,436data\_loader=c.valid\_loader,437inner\_iterations=c.inner\_iterations,438state\_modules=c.state\_modules,439is\_track\_time=c.is\_track\_time,440step=c.step)441442443@option(TrainValidConfigs.loop\_count)444def\_data\_loop\_count(c:TrainValidConfigs):445returnc.epochs446447448classSimpleTrainValidConfigs(TrainValidConfigs):

#

462optimizer:torch.optim.Adam463model:nn.Module464device:torch.device=DeviceConfigs()465466loss\_func:nn.Module467468update\_batches:int=1469log\_save\_batches:int=1470471state\_modules:List[StateModule]=[]

#

473definit(self):474pass

#

476defstep(self,batch:Any,batch\_idx:BatchIndex):477self.model.train(self.mode.is\_train)478data,target=batch[0].to(self.device),batch[1].to(self.device)479480ifself.mode.is\_train:481tracker.add\_global\_step(len(data))482483withmonit.section("model"):484output=self.model(data)485486loss=self.loss\_func(output,target)487tracker.add("loss.",loss)488489ifself.mode.is\_train:490withmonit.section('backward'):491loss.backward()492493ifbatch\_idx.is\_interval(self.update\_batches):494withmonit.section('optimize'):495self.optimizer.step()496self.optimizer.zero\_grad()497498ifbatch\_idx.is\_interval(self.log\_save\_batches):499tracker.save()500501502meta\_config(SimpleTrainValidConfigs.update\_batches,503)

#

506@option(SimpleTrainValidConfigs.optimizer)507def\_default\_optimizer(c:SimpleTrainValidConfigs):508from.optimizerimportOptimizerConfigs509opt\_conf=OptimizerConfigs()510opt\_conf.parameters=c.model.parameters()511returnopt\_conf

labml.ai