
PonderNet: Learning to Ponder


[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/adaptive_computation/ponder_net/__init__.py)


This is a PyTorch implementation of the paper PonderNet: Learning to Ponder.

PonderNet adapts the amount of computation to the input: it changes the number of steps a recurrent network takes depending on the input, and it learns this end-to-end with gradient descent.

PonderNet has a step function of the form

$$\hat{y}_n, h_{n+1}, \lambda_n = s(x, h_n)$$

where $x$ is the input, $h_n$ is the state, $\hat{y}_n$ is the prediction at step $n$, and $\lambda_n$ is the probability of halting (stopping) at the current step.

$s$ can be any neural network (e.g. LSTM, MLP, GRU, attention layer).

The unconditioned probability of halting at step $n$ is then

$$p_n = \lambda_n \prod_{j=1}^{n-1} (1 - \lambda_j)$$

That is, the probability of not halting at any of the previous steps and then halting at step $n$.
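As a minimal sketch of this formula (plain Python rather than PyTorch; `halting_distribution` is a hypothetical helper name), the halting probabilities can be accumulated with a running product of $(1 - \lambda_j)$. It assumes, as the model code does for the last step, that $\lambda_N = 1$, which makes $p_1 \dots p_N$ sum to one.

```python
from typing import List


def halting_distribution(lambdas: List[float]) -> List[float]:
    """p_n = lambda_n * prod_{j=1}^{n-1} (1 - lambda_j) for n = 1..N.

    The last lambda is replaced by 1 (lambda_N = 1), so the result sums to one.
    """
    lambdas = list(lambdas[:-1]) + [1.0]
    p = []
    un_halted_prob = 1.0  # prod_{j=1}^{n-1} (1 - lambda_j); the empty product is 1
    for lambda_n in lambdas:
        p.append(un_halted_prob * lambda_n)
        un_halted_prob *= 1.0 - lambda_n
    return p


p = halting_distribution([0.1, 0.5, 0.3, 0.9])
print(round(sum(p), 9))  # 1.0
```

Setting $\lambda_N = 1$ puts all remaining probability mass on the last step; without it, $\prod_{j=1}^{N} (1 - \lambda_j)$ of the mass would be left unassigned.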

During inference, we halt by sampling from a Bernoulli distribution with the halting probability $\lambda_n$ at each step, and use the prediction at the halting step, $\hat{y}_n$, as the final output.
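A minimal plain-Python sketch of this sampling procedure, assuming the per-step $\lambda_n$ and $\hat{y}_n$ are already available as lists (`ponder_inference` is a hypothetical name). A real model would compute them step by step so that it can skip the remaining steps once it halts.

```python
import random
from typing import List, Optional, Tuple


def ponder_inference(lambdas: List[float], predictions: List[float],
                     rng: Optional[random.Random] = None) -> Tuple[int, float]:
    """Sample a halting step n (1-based) and return (n, y_hat_n).

    lambdas[n - 1] and predictions[n - 1] stand in for lambda_n and y_hat_n.
    The last step always halts (lambda_N = 1), so its lambda is ignored.
    """
    rng = rng or random.Random()
    for n in range(1, len(lambdas)):
        # Bernoulli(lambda_n) draw: halt with probability lambda_n
        if rng.random() < lambdas[n - 1]:
            return n, predictions[n - 1]
    # No earlier halt: stop at the last step N
    return len(lambdas), predictions[-1]


rng = random.Random(0)
print(ponder_inference([0.1, 0.5, 0.3, 0.9], [0.2, -1.3, 0.8, 2.1], rng))
```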

During training, we get the predictions from all the steps and compute the loss for each of them, then take the weighted average of these losses, weighting each by the probability $p_n$ of halting at that step.
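For a single sample, this training objective reduces to $\sum_n p_n \mathcal{L}(y, \hat{y}_n)$. The sketch below uses a hypothetical helper, `expected_step_loss`, with illustrative (made-up) per-step losses; it only shows the arithmetic of the weighted average, not how the losses are produced.

```python
from typing import List


def expected_step_loss(p: List[float], step_losses: List[float]) -> float:
    """Training loss for one sample: sum_n p_n * L(y, y_hat_n)."""
    # p_1..p_N form a distribution over halting steps, so this is a weighted average
    if abs(sum(p) - 1.0) > 1e-6:
        raise ValueError("p_1..p_N should sum to 1")
    return sum(p_n * loss_n for p_n, loss_n in zip(p, step_losses))


p = [0.1, 0.45, 0.135, 0.315]       # illustrative halting distribution over N = 4 steps
step_losses = [0.9, 0.4, 0.2, 0.1]  # illustrative per-step losses L(y, y_hat_n)
print(expected_step_loss(p, step_losses))
```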

The step function is applied up to a maximum number of steps, denoted by $N$.

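The overall loss of PonderNet is

$$
\begin{aligned}
L &= L_{Rec} + \beta L_{Reg} \\
L_{Rec} &= \sum_{n=1}^N p_n \mathcal{L}(y, \hat{y}_n) \\
L_{Reg} &= \mathop{KL} \Big(p_n \Vert p_G(\lambda_p)\Big)
\end{aligned}
$$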

$\mathcal{L}$ is the standard loss function between the target $y$ and the prediction $\hat{y}_n$.

KL is the Kullback–Leibler divergence.

$p_G$ is the geometric distribution parameterized by $\lambda_p$. $\lambda_p$ has nothing to do with $\lambda_n$; we are just sticking to the same notation as the paper. $Pr_{p_G(\lambda_p)}(X = k) = (1 - \lambda_p)^k \lambda_p$.

The regularization loss biases the network towards taking $\frac{1}{\lambda_p}$ steps and incentivizes non-zero probabilities for all steps; i.e., it promotes exploration.
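A minimal plain-Python sketch of the geometric prior and the regularizer, with hypothetical helper names `geometric_prior` and `kl_to_prior`. It checks numerically that the prior's mean halting step, counting steps from 1, is $\frac{1}{\lambda_p}$, and computes $\mathop{KL}(p_n \Vert p_G(\lambda_p))$ in the direction written in the formula, with the prior truncated to the first $N$ steps (an assumption, matching how the regularization loss slices the pre-computed prior).

```python
import math
from typing import List


def geometric_prior(lambda_p: float, max_steps: int) -> List[float]:
    """Entry k is (1 - lambda_p)^k * lambda_p, the prior for halting at step k + 1."""
    return [(1 - lambda_p) ** k * lambda_p for k in range(max_steps)]


def kl_to_prior(p: List[float], p_g: List[float]) -> float:
    """KL(p || p_G) over the first len(p) steps; terms with p_n = 0 contribute 0."""
    return sum(p_n * math.log(p_n / g) for p_n, g in zip(p, p_g) if p_n > 0)


lambda_p = 0.2
p_g = geometric_prior(lambda_p, max_steps=1_000)

# Mean halting step under the prior, counting steps from 1: close to 1 / lambda_p = 5
# (the tail beyond 1,000 steps is negligible)
mean_steps = sum((k + 1) * prob for k, prob in enumerate(p_g))
print(round(mean_steps, 6))  # 5.0

# With N = 3 steps, piling mass onto the last step costs more than spreading it out
print(kl_to_prior([0.2, 0.3, 0.5], p_g[:3]) < kl_to_prior([0.05, 0.05, 0.9], p_g[:3]))  # True
```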

Here is the training code, experiment.py, to train a PonderNet on the Parity Task.

```python
from typing import Tuple

import torch
from torch import nn
```


PonderNet with GRU for Parity Task

This is a simple model that uses a GRU Cell as the step function.

This model is for the Parity Task, where the input is a vector of n_elems elements. Each element of the vector is either 0, 1, or -1, and the output is the parity: a binary value that is true if the number of 1s is odd and false otherwise.

The prediction of the model is the logit (log-odds) of the parity being 1.
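A hypothetical generator for Parity Task examples, to make the input and target format concrete. `parity_sample` and its sampling scheme (a random number of positions, at least one, set to 1 or -1, with the rest left at 0) are assumptions for illustration, not taken from experiment.py.

```python
import random
from typing import List, Optional, Tuple


def parity_sample(n_elems: int,
                  rng: Optional[random.Random] = None) -> Tuple[List[int], int]:
    """Hypothetical generator for one Parity Task example.

    Assumption: a random number of positions (at least one) are set to 1 or -1,
    and the rest stay 0. The target is 1 if the count of 1s is odd, else 0.
    """
    rng = rng or random.Random()
    x = [0] * n_elems
    n_non_zero = rng.randint(1, n_elems)
    # Pick distinct positions and give each a random sign
    for i in rng.sample(range(n_elems), n_non_zero):
        x[i] = rng.choice([-1, 1])
    # Parity of the number of 1s (the -1s do not count)
    target = x.count(1) % 2
    return x, target


x, y = parity_sample(8, random.Random(42))
print(x, y)
```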

```python
class ParityPonderGRU(nn.Module):
```

  • n_elems is the number of elements in the input vector
  • n_hidden is the state vector size of the GRU
  • max_steps is the maximum number of steps $N$

```python
    def __init__(self, n_elems: int, n_hidden: int, max_steps: int):
```

```python
        super().__init__()

        self.max_steps = max_steps
        self.n_hidden = n_hidden
```

GRU $h_{n+1} = s_h(x, h_n)$

```python
        self.gru = nn.GRUCell(n_elems, n_hidden)
```

$\hat{y}_n = s_y(h_n)$. We could use a layer that takes the concatenation of $h$ and $x$ as input, but we went with this for simplicity.

```python
        self.output_layer = nn.Linear(n_hidden, 1)
```

$\lambda_n = s_\lambda(h_n)$

```python
        self.lambda_layer = nn.Linear(n_hidden, 1)
        self.lambda_prob = nn.Sigmoid()
```

A flag to set during inference so that the computation actually stops once all samples have halted

```python
        self.is_halt = False
```

  • x is the input of shape [batch_size, n_elems]

This outputs a tuple of four tensors:

  1. $p_1 \dots p_N$ in a tensor of shape [N, batch_size]
  2. $\hat{y}_1 \dots \hat{y}_N$ in a tensor of shape [N, batch_size], the logits of the parity being 1
  3. $p_m$ of shape [batch_size]
  4. $\hat{y}_m$ of shape [batch_size], where the computation was halted at step $m$

```python
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
```

```python
        batch_size = x.shape[0]
```

We get the initial state $h_1 = s_h(x)$

```python
        h = x.new_zeros((x.shape[0], self.n_hidden))
        h = self.gru(x, h)
```

Lists to store $p_1 \dots p_N$ and $\hat{y}_1 \dots \hat{y}_N$

```python
        p = []
        y = []
```

$\prod_{j=1}^{n-1} (1 - \lambda_j)$

```python
        un_halted_prob = h.new_ones((batch_size,))
```

A vector to keep track of which samples have halted computation

```python
        halted = h.new_zeros((batch_size,))
```

$p_m$ and $\hat{y}_m$ where the computation was halted at step $m$

```python
        p_m = h.new_zeros((batch_size,))
        y_m = h.new_zeros((batch_size,))
```

Iterate for $N$ steps

```python
        for n in range(1, self.max_steps + 1):
```

The halting probability $\lambda_N = 1$ for the last step

```python
            if n == self.max_steps:
                lambda_n = h.new_ones(h.shape[0])
```

$\lambda_n = s_\lambda(h_n)$

```python
            else:
                lambda_n = self.lambda_prob(self.lambda_layer(h))[:, 0]
```

$\hat{y}_n = s_y(h_n)$

```python
            y_n = self.output_layer(h)[:, 0]
```

$p_n = \lambda_n \prod_{j=1}^{n-1} (1 - \lambda_j)$

```python
            p_n = un_halted_prob * lambda_n
```

Update $\prod_{j=1}^{n-1} (1 - \lambda_j)$

```python
            un_halted_prob = un_halted_prob * (1 - lambda_n)
```

Halt based on the halting probability $\lambda_n$

```python
            halt = torch.bernoulli(lambda_n) * (1 - halted)
```

Collect $p_n$ and $\hat{y}_n$

```python
            p.append(p_n)
            y.append(y_n)
```

Update $p_m$ and $\hat{y}_m$ based on what was halted at the current step $n$

```python
            p_m = p_m * (1 - halt) + p_n * halt
            y_m = y_m * (1 - halt) + y_n * halt
```

Update halted samples

```python
            halted = halted + halt
```

Get the next state $h_{n+1} = s_h(x, h_n)$

```python
            h = self.gru(x, h)
```

Stop the computation if all samples have halted

```python
            if self.is_halt and halted.sum() == batch_size:
                break
```

```python
        return torch.stack(p), torch.stack(y), p_m, y_m
```
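The per-sample bookkeeping in forward (the halt, halted, p_m and y_m updates) can be replayed in plain Python for a tiny batch. In this sketch, `track_halting` is a hypothetical helper, and the Bernoulli draws are passed in as fixed 0/1 values so the trace is deterministic.

```python
from typing import List, Tuple


def track_halting(halt_draws: List[List[int]], p: List[List[float]],
                  y: List[List[float]]) -> Tuple[List[float], List[float], List[int]]:
    """Replay the halt / halted / p_m / y_m updates of forward for a tiny batch.

    halt_draws[i][b] is a fixed 0/1 stand-in for the Bernoulli(lambda_n) draw of
    sample b at step n = i + 1; p[i][b] and y[i][b] stand in for p_n and y_hat_n.
    """
    batch_size = len(halt_draws[0])
    halted = [0] * batch_size
    p_m = [0.0] * batch_size
    y_m = [0.0] * batch_size
    for draws, p_n, y_n in zip(halt_draws, p, y):
        for b in range(batch_size):
            # halt = bernoulli(lambda_n) * (1 - halted): halted samples cannot halt again
            halt = draws[b] * (1 - halted[b])
            # Keep the previous value unless this sample halts at this step
            p_m[b] = p_m[b] * (1 - halt) + p_n[b] * halt
            y_m[b] = y_m[b] * (1 - halt) + y_n[b] * halt
            halted[b] += halt
    return p_m, y_m, halted


# Sample 0 draws "halt" at steps 1 and 2; sample 1 only at step 2
p_m, y_m, halted = track_halting([[1, 0], [1, 1]],
                                 [[0.3, 0.6], [0.7, 0.4]],
                                 [[1.0, 2.0], [3.0, 4.0]])
print(p_m, y_m, halted)  # [0.3, 0.4] [1.0, 4.0] [1, 1]
```

Because halt is multiplied by 1 - halted, a sample that halted earlier ignores later draws, so p_m and y_m keep the values from the first step at which it halted.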

Reconstruction loss

$$L_{Rec} = \sum_{n=1}^N p_n \mathcal{L}(y, \hat{y}_n)$$

$\mathcal{L}$ is the standard loss function between the target $y$ and the prediction $\hat{y}_n$.

```python
class ReconstructionLoss(nn.Module):
```

  • loss_func is the loss function $\mathcal{L}$

```python
    def __init__(self, loss_func: nn.Module):
```

```python
        super().__init__()
        self.loss_func = loss_func
```

  • p is $p_1 \dots p_N$ in a tensor of shape [N, batch_size]
  • y_hat is $\hat{y}_1 \dots \hat{y}_N$ in a tensor of shape [N, batch_size, ...]
  • y is the target of shape [batch_size, ...]

```python
    def forward(self, p: torch.Tensor, y_hat: torch.Tensor, y: torch.Tensor):
```

The total $\sum_{n=1}^N p_n \mathcal{L}(y, \hat{y}_n)$

```python
        total_loss = p.new_tensor(0.)
```

Iterate up to $N$

```python
        for n in range(p.shape[0]):
```

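$p_n \mathcal{L}(y, \hat{y}_n)$ for each sample, averaged over the batch. For this per-sample weighting, loss_func should return unreduced per-sample losses (e.g. a PyTorch loss constructed with reduction='none').

```python
            loss = (p[n] * self.loss_func(y_hat[n], y)).mean()
```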

Add to total loss

```python
            total_loss = total_loss + loss
```

```python
        return total_loss
```

Regularization loss

$$L_{Reg} = \mathop{KL} \Big(p_n \Vert p_G(\lambda_p)\Big)$$

KL is the Kullback–Leibler divergence.

$p_G$ is the geometric distribution parameterized by $\lambda_p$. $\lambda_p$ has nothing to do with $\lambda_n$; we are just sticking to the same notation as the paper. $Pr_{p_G(\lambda_p)}(X = k) = (1 - \lambda_p)^k \lambda_p$.

The regularization loss biases the network towards taking $\frac{1}{\lambda_p}$ steps and incentivizes non-zero probabilities for all steps; i.e., it promotes exploration.

```python
class RegularizationLoss(nn.Module):
```

  • lambda_p is $\lambda_p$, the success probability of the geometric distribution
  • max_steps is the highest $N$; we use this to pre-compute $p_G(\lambda_p)$

```python
    def __init__(self, lambda_p: float, max_steps: int = 1_000):
```

```python
        super().__init__()
```

Empty vector to calculate $p_G(\lambda_p)$

```python
        p_g = torch.zeros((max_steps,))
```

$(1 - \lambda_p)^k$

```python
        not_halted = 1.
```

Iterate up to max_steps

```python
        for k in range(max_steps):
```

$Pr_{p_G(\lambda_p)}(X = k) = (1 - \lambda_p)^k \lambda_p$

```python
            p_g[k] = not_halted * lambda_p
```

Update $(1 - \lambda_p)^k$

```python
            not_halted = not_halted * (1 - lambda_p)
```

Save $Pr_{p_G(\lambda_p)}$

```python
        self.p_g = nn.Parameter(p_g, requires_grad=False)
```

KL-divergence loss

```python
        self.kl_div = nn.KLDivLoss(reduction='batchmean')
```

  • p is $p_1 \dots p_N$ in a tensor of shape [N, batch_size]

```python
    def forward(self, p: torch.Tensor):
```

Transpose p to [batch_size, N]

```python
        p = p.transpose(0, 1)
```

Get $Pr_{p_G(\lambda_p)}$ up to $N$ and expand it across the batch dimension

```python
        p_g = self.p_g[None, :p.shape[1]].expand_as(p)
```

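Calculate the KL-divergence. The PyTorch KL-divergence implementation expects its input as log probabilities. Because nn.KLDivLoss(input, target) computes $\text{target} \cdot (\log \text{target} - \text{input})$ pointwise, passing p.log() as the input and $p_G$ as the target evaluates $\mathop{KL}\big(p_G(\lambda_p) \Vert p_n\big)$, the reverse of the direction written in $L_{Reg}$.

```python
        return self.kl_div(p.log(), p_g)
```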
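To see which direction of the divergence the call `self.kl_div(p.log(), p_g)` evaluates, here is a plain-Python sketch of the pointwise formula documented for nn.KLDivLoss, $\text{target} \cdot (\log \text{target} - \text{input})$, for a single sample (where reduction='batchmean' reduces to a plain sum). The helpers `kl_div_loss_like` and `kl` are hypothetical, and the two distributions are illustrative values.

```python
import math
from typing import List


def kl_div_loss_like(log_input: List[float], target: List[float]) -> float:
    """Summed nn.KLDivLoss pointwise formula: target * (log(target) - input).

    Terms with target == 0 are skipped, i.e. treated as contributing 0.
    """
    return sum(t * (math.log(t) - x) for x, t in zip(log_input, target) if t > 0)


def kl(p: List[float], q: List[float]) -> float:
    """KL(p || q) = sum_i p_i * log(p_i / q_i)."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


p = [0.7, 0.2, 0.1]      # illustrative halting distribution p_n for one sample
p_g = [0.5, 0.25, 0.25]  # illustrative prior over the same three steps

# Mirrors self.kl_div(p.log(), p_g): input = log p, target = p_G
loss = kl_div_loss_like([math.log(v) for v in p], p_g)
print(math.isclose(loss, kl(p_g, p)))  # True
print(math.isclose(loss, kl(p, p_g)))  # False
```

Passing the log of the prior as the input and p as the target would instead give $\mathop{KL}(p_n \Vert p_G(\lambda_p))$.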

labml.ai