PonderNet Parity Task Experiment

This trains a PonderNet on the parity task.

from typing import Any

import torch
from torch import nn
from torch.utils.data import DataLoader

from labml import tracker, experiment
from labml_nn.helpers.metrics import AccuracyDirect
from labml_nn.helpers.trainer import SimpleTrainValidConfigs, BatchIndex
from labml_nn.adaptive_computation.parity import ParityDataset
from labml_nn.adaptive_computation.ponder_net import ParityPonderGRU, ReconstructionLoss, RegularizationLoss

Configurations with a simple training loop

class Configs(SimpleTrainValidConfigs):

Number of epochs

    epochs: int = 100

Number of batches per epoch

    n_batches: int = 500

Batch size

    batch_size: int = 128

Model

    model: ParityPonderGRU

$L_{Rec}$

    loss_rec: ReconstructionLoss

$L_{Reg}$

    loss_reg: RegularizationLoss

The number of elements in the input vector. We keep it low for demonstration; otherwise, training takes a lot of time. Although the parity task seems simple, figuring out the pattern by looking at samples is quite hard.

    n_elems: int = 8
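
To make the task concrete, here is a rough sketch of how a parity sample could be generated. `make_parity_sample` is a hypothetical helper, not the actual `ParityDataset` code; it assumes the common formulation where a random subset of elements is set to $\pm 1$ and the label is the parity of the number of $+1$s.

import torch

def make_parity_sample(n_elems: int = 8):
    # Hypothetical sketch of a parity sample, not the actual ParityDataset code
    x = torch.zeros(n_elems)
    # Choose how many elements are non-zero (at least one)
    n_non_zero = torch.randint(1, n_elems + 1, (1,)).item()
    # Set those elements randomly to +1 or -1
    x[:n_non_zero] = torch.randint(0, 2, (n_non_zero,)).float() * 2. - 1.
    # Shuffle the positions of the non-zero elements
    x = x[torch.randperm(n_elems)]
    # Label is 1 if the number of +1s is odd, 0 otherwise
    y = (x == 1.).sum() % 2
    return x, y

The model only ever sees such $(x, y)$ pairs, which is why the underlying rule is hard to discover from samples.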

Number of units in the hidden layer (state)

    n_hidden: int = 64

Maximum number of steps $N$

    max_steps: int = 20

$\lambda_p$ for the geometric distribution $p_G(\lambda_p)$

    lambda_p: float = 0.2
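
Under a geometric distribution with parameter $\lambda_p$, the probability of halting at step $n$ is $\lambda_p (1 - \lambda_p)^{n-1}$, and the expected number of steps is $1 / \lambda_p$ (5 for $\lambda_p = 0.2$). A small illustrative sketch of that prior over the first `max_steps` steps; the `geometric_prior` helper is made up for illustration only.

import torch

def geometric_prior(lambda_p: float = 0.2, max_steps: int = 20) -> torch.Tensor:
    # P(halt at step n) = lambda_p * (1 - lambda_p)^(n - 1), for n = 1..max_steps
    n = torch.arange(1, max_steps + 1, dtype=torch.float)
    return lambda_p * (1 - lambda_p) ** (n - 1)

print(geometric_prior()[:3])  # tensor([0.2000, 0.1600, 0.1280])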

Regularization loss $L_{Reg}$ coefficient $\beta$

    beta: float = 0.01

Gradient clipping by norm

    grad_norm_clip: float = 1.0

Training and validation loaders

    train_loader: DataLoader
    valid_loader: DataLoader

Accuracy calculator

    accuracy = AccuracyDirect()

Initialization

    def init(self):

Print indicators to screen

        tracker.set_scalar('loss.*', True)
        tracker.set_scalar('loss_reg.*', True)
        tracker.set_scalar('accuracy.*', True)
        tracker.set_scalar('steps.*', True)

We need to register the metrics as state modules so they are calculated over each epoch, for both training and validation

        self.state_modules = [self.accuracy]

Initialize the model

        self.model = ParityPonderGRU(self.n_elems, self.n_hidden, self.max_steps).to(self.device)

$L_{Rec}$

        self.loss_rec = ReconstructionLoss(nn.BCEWithLogitsLoss(reduction='none')).to(self.device)
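
The reconstruction loss is the prediction loss averaged over the halting distribution, $L_{Rec} = \sum_{n=1}^{N} p_n \, \mathcal{L}(y, \hat{y}_n)$, which is why the BCE loss above is built with `reduction='none'`. A rough sketch of that weighting follows; it is not the actual `ReconstructionLoss` code, and it assumes `p` and `y_hat` have shape `[N, batch_size]`.

import torch
from torch import nn

def reconstruction_loss_sketch(p: torch.Tensor, y_hat: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # p, y_hat: [N, batch_size]; y: [batch_size] with values 0. or 1.
    bce = nn.BCEWithLogitsLoss(reduction='none')
    total = p.new_tensor(0.)
    for n in range(p.shape[0]):
        # Weight the per-sample loss at step n by the probability of halting at that step
        total = total + (p[n] * bce(y_hat[n], y)).mean()
    return total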

$L_{Reg}$

        self.loss_reg = RegularizationLoss(self.lambda_p, self.max_steps).to(self.device)
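
$L_{Reg}$ is the KL divergence between the predicted halting distribution and the geometric prior, $L_{Reg} = \mathrm{KL}\left(p \,\|\, p_G(\lambda_p)\right)$. It biases the expected number of steps towards $1/\lambda_p$ and keeps some probability on later steps. A minimal illustrative sketch is below; `RegularizationLoss` receives $\lambda_p$ and `max_steps` at construction, presumably so it can precompute the prior, whereas this hypothetical version recomputes it on every call.

import torch

def regularization_loss_sketch(p: torch.Tensor, lambda_p: float = 0.2) -> torch.Tensor:
    # p: [N, batch_size] halting probabilities predicted by the model
    n_steps = p.shape[0]
    steps = torch.arange(1, n_steps + 1, dtype=p.dtype, device=p.device)
    # Geometric prior over the same N steps
    p_g = (lambda_p * (1 - lambda_p) ** (steps - 1))[:, None].expand_as(p)
    # KL(p || p_g), averaged over the batch
    return (p * ((p + 1e-8).log() - p_g.log())).sum(dim=0).mean()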

Training and validation loaders

        self.train_loader = DataLoader(ParityDataset(self.batch_size * self.n_batches, self.n_elems),
                                       batch_size=self.batch_size)
        self.valid_loader = DataLoader(ParityDataset(self.batch_size * 32, self.n_elems),
                                       batch_size=self.batch_size)

This method gets called by the trainer for each batch

    def step(self, batch: Any, batch_idx: BatchIndex):

Set the model mode

        self.model.train(self.mode.is_train)

Get the input and labels and move them to the model's device

        data, target = batch[0].to(self.device), batch[1].to(self.device)

Increment the global step (number of samples processed) in training mode

        if self.mode.is_train:
            tracker.add_global_step(len(data))

Run the model. It returns the probability of halting at each step $p_n$, the prediction at each step $\hat{y}_n$, and the halting probability and prediction at a step sampled from the halting distribution (used below for the accuracy metric).

        p, y_hat, p_sampled, y_hat_sampled = self.model(data)
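
Inside the model, each step produces a conditional halting probability $\lambda_n$, and the overall probability of halting at step $n$ is $p_n = \lambda_n \prod_{j=1}^{n-1} (1 - \lambda_j)$, with the final step absorbing the remaining mass so the $p_n$ sum to one. A tiny sketch with made-up $\lambda_n$ values shows how the `p` returned above is built up.

# Hypothetical per-step halting probabilities lambda_n for a single sample
lambdas = [0.1, 0.3, 0.5, 1.0]  # the last step is forced to halt

un_halted = 1.
p = []
for lam in lambdas:
    # Probability of reaching step n without halting, times the probability of halting now
    p.append(un_halted * lam)
    un_halted *= (1. - lam)

print(p)       # approximately [0.1, 0.27, 0.315, 0.315]
print(sum(p))  # 1.0, up to floating point error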

Calculate the reconstruction loss

        loss_rec = self.loss_rec(p, y_hat, target.to(torch.float))
        tracker.add("loss.", loss_rec)

Calculate the regularization loss

        loss_reg = self.loss_reg(p)
        tracker.add("loss_reg.", loss_reg)

$L = L_{Rec} + \beta L_{Reg}$

        loss = loss_rec + self.beta * loss_reg

Calculate the expected number of steps taken, $E[n] = \sum_{n=1}^{N} n \, p_n$

        steps = torch.arange(1, p.shape[0] + 1, device=p.device)
        expected_steps = (p * steps[:, None]).sum(dim=0)
        tracker.add("steps.", expected_steps)

Call accuracy metric

        self.accuracy(y_hat_sampled > 0, target)

        if self.mode.is_train:

Compute gradients

            loss.backward()

Clip gradients

            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)

Take an optimizer step

            self.optimizer.step()

Clear gradients

            self.optimizer.zero_grad()

Save the tracked metrics

        tracker.save()

Run the experiment

def main():

Create the experiment, set the configurations, and run the training loop

    experiment.create(name='ponder_net')

    conf = Configs()
    experiment.configs(conf, {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 0.0003,
    })

    with experiment.start():
        conf.run()

if __name__ == '__main__':
    main()
