PonderNet: Parity Task Experiment
This experiment trains a PonderNet on the Parity Task.
```python
from typing import Any

import torch
from torch import nn
from torch.utils.data import DataLoader

from labml import tracker, experiment
from labml_nn.helpers.metrics import AccuracyDirect
from labml_nn.helpers.trainer import SimpleTrainValidConfigs, BatchIndex
from labml_nn.adaptive_computation.parity import ParityDataset
from labml_nn.adaptive_computation.ponder_net import ParityPonderGRU, ReconstructionLoss, RegularizationLoss
```
Configurations with a simple training loop
```python
class Configs(SimpleTrainValidConfigs):
```
Number of epochs
```python
    epochs: int = 100
```
Number of batches per epoch
```python
    n_batches: int = 500
```
Batch size
```python
    batch_size: int = 128
```
Model
```python
    model: ParityPonderGRU
```
$L_{\text{Rec}}$
```python
    loss_rec: ReconstructionLoss
```
$L_{\text{Reg}}$
```python
    loss_reg: RegularizationLoss
```
The number of elements in the input vector. We keep it low for demonstration; otherwise, training takes a lot of time. Although the parity task seems simple, figuring out the pattern by looking at samples is quite hard.
```python
    n_elems: int = 8
```
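To make the task concrete, here is a pure-Python sketch of how a parity sample can be generated. It is not the library's `ParityDataset` (whose exact interface may differ); it assumes the common convention that entries are in $\{-1, 0, +1\}$ and the label is the parity of the number of $+1$ entries.

```python
import random


def make_parity_sample(n_elems: int = 8):
    # Hypothetical helper, not the library's ParityDataset.
    # Start with a zero vector and pick how many entries are non-zero.
    x = [0.0] * n_elems
    n_non_zero = random.randint(1, n_elems)
    # Fill those positions with random +1 / -1 values
    for i in range(n_non_zero):
        x[i] = random.choice([-1.0, 1.0])
    # Shuffle so the non-zero entries can appear anywhere
    random.shuffle(x)
    # Label: parity of the number of +1 entries (assumed convention)
    y = sum(1 for v in x if v == 1.0) % 2
    return x, y
```

The task is hard for a model precisely because flipping any single $+1$ entry flips the label, so no small subset of inputs determines the output.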
Number of units in the hidden layer (state)
```python
    n_hidden: int = 64
```
Maximum number of steps $N$
```python
    max_steps: int = 20
```
$\lambda_p$ for the geometric distribution $p_G(\lambda_p)$
```python
    lambda_p: float = 0.2
```
Regularization loss $L_{\text{Reg}}$ coefficient $\beta$
```python
    beta: float = 0.01
```
Gradient clipping by norm
```python
    grad_norm_clip: float = 1.0
```
Training and validation loaders
```python
    train_loader: DataLoader
    valid_loader: DataLoader
```
Accuracy calculator
```python
    accuracy = AccuracyDirect()
```
```python
    def init(self):
```
Print indicators to screen
```python
        tracker.set_scalar('loss.*', True)
        tracker.set_scalar('loss_reg.*', True)
        tracker.set_scalar('accuracy.*', True)
        tracker.set_scalar('steps.*', True)
```
Register the metrics as state modules so they are accumulated over each epoch, for both training and validation
```python
        self.state_modules = [self.accuracy]
```
Initialize the model
```python
        self.model = ParityPonderGRU(self.n_elems, self.n_hidden, self.max_steps).to(self.device)
```
$L_{\text{Rec}}$
```python
        self.loss_rec = ReconstructionLoss(nn.BCEWithLogitsLoss(reduction='none')).to(self.device)
```
$L_{\text{Reg}}$
```python
        self.loss_reg = RegularizationLoss(self.lambda_p, self.max_steps).to(self.device)
```
Training and validation loaders
```python
        self.train_loader = DataLoader(ParityDataset(self.batch_size * self.n_batches, self.n_elems),
                                       batch_size=self.batch_size)
        self.valid_loader = DataLoader(ParityDataset(self.batch_size * 32, self.n_elems),
                                       batch_size=self.batch_size)
```
This method gets called by the trainer for each batch
```python
    def step(self, batch: Any, batch_idx: BatchIndex):
```
Set the model mode
```python
        self.model.train(self.mode.is_train)
```
Get the input and labels and move them to the model's device
```python
        data, target = batch[0].to(self.device), batch[1].to(self.device)
```
Increment step in training mode
```python
        if self.mode.is_train:
            tracker.add_global_step(len(data))
```
Run the model
```python
        p, y_hat, p_sampled, y_hat_sampled = self.model(data)
```
Calculate the reconstruction loss
```python
        loss_rec = self.loss_rec(p, y_hat, target.to(torch.float))
        tracker.add("loss.", loss_rec)
```
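Conceptually, the reconstruction loss is the expected per-step loss under the halting distribution, $L_{\text{Rec}} = \sum_{n=1}^{N} p_n \, \mathcal{L}(\hat{y}_n, y)$. A pure-Python sketch for a single sample (the helper names are illustrative, not the library's API):

```python
import math


def bce_with_logits(logit: float, target: float) -> float:
    # Per-element binary cross-entropy on a raw logit,
    # matching what nn.BCEWithLogitsLoss computes elementwise
    sig = 1.0 / (1.0 + math.exp(-logit))
    eps = 1e-12
    return -(target * math.log(sig + eps) + (1.0 - target) * math.log(1.0 - sig + eps))


def reconstruction_loss(p, y_hat, y):
    # L_Rec = sum_n p_n * loss(y_hat_n, y): the per-step losses
    # weighted by the probability of halting at that step
    return sum(p_n * bce_with_logits(logit_n, y) for p_n, logit_n in zip(p, y_hat))
```

Because every step's prediction is weighted by its halting probability, early steps receive a training signal even when the model usually halts later.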
Calculate the regularization loss
```python
        loss_reg = self.loss_reg(p)
        tracker.add("loss_reg.", loss_reg)
```
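The regularization term is a KL divergence that pulls the halting distribution toward a geometric prior $p_G(\lambda_p)$ truncated at $N$ steps. A minimal sketch, assuming the truncated (unnormalized) form $p_G(n) = \lambda_p (1 - \lambda_p)^{n-1}$:

```python
import math


def geometric_prior(lambda_p: float, max_steps: int):
    # p_G(n) = lambda_p * (1 - lambda_p)^(n - 1) for n = 1..N (truncated, not renormalized)
    return [lambda_p * (1.0 - lambda_p) ** (n - 1) for n in range(1, max_steps + 1)]


def kl_divergence(p, q):
    # KL(p || q) = sum_n p_n * log(p_n / q_n), skipping zero-probability terms
    return sum(p_n * math.log(p_n / q_n) for p_n, q_n in zip(p, q) if p_n > 0)
```

This prior both encourages halting early (in expectation after about $1/\lambda_p$ steps) and keeps non-zero probability on every step, so the model keeps exploring longer computations.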
$L = L_{\text{Rec}} + \beta L_{\text{Reg}}$
```python
        loss = loss_rec + self.beta * loss_reg
```
Calculate the expected number of steps taken
```python
        steps = torch.arange(1, p.shape[0] + 1, device=p.device)
        expected_steps = (p * steps[:, None]).sum(dim=0)
        tracker.add("steps.", expected_steps)
```
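The expected number of steps is $\sum_{n=1}^{N} n \, p_n$ for a halting distribution over steps $1 \dots N$. A toy pure-Python check mirroring the tensor computation above:

```python
def expected_steps(p):
    # E[steps] = sum_n n * p_n, with steps indexed from 1
    return sum(n * p_n for n, p_n in enumerate(p, start=1))
```

For example, a model that halts at step 1 or step 2 with equal probability takes 1.5 steps in expectation; tracking this per batch shows how much computation the network learns to spend.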
Call accuracy metric
```python
        self.accuracy(y_hat_sampled > 0, target)

        # The remaining steps run only in training mode
        if self.mode.is_train:
```
Compute gradients
```python
            loss.backward()
```
Clip gradients
```python
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
```
Take an optimizer step
```python
            self.optimizer.step()
```
Clear gradients
```python
            self.optimizer.zero_grad()
```
Save the tracked metrics

```python
        tracker.save()
```
Run the experiment
```python
def main():
    experiment.create(name='ponder_net')

    conf = Configs()
    experiment.configs(conf, {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 0.0003,
    })

    with experiment.start():
        conf.run()
```
```python
if __name__ == '__main__':
    main()
```