
HyperNetworks - HyperLSTM

We have implemented HyperLSTM, introduced in the paper HyperNetworks, with annotations, using PyTorch. This blog post by David Ha gives a good explanation of HyperNetworks.

We have an experiment that trains a HyperLSTM to predict text on the Shakespeare dataset. Here's the link to the code: experiment.py

HyperNetworks use a smaller network to generate weights of a larger network. There are two variants: static hyper-networks and dynamic hyper-networks. Static HyperNetworks have smaller networks that generate weights (kernels) of a convolutional network. Dynamic HyperNetworks generate parameters of a recurrent neural network for each step. This is an implementation of the latter.
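
For intuition, here is a minimal sketch of the static variant (our own illustrative example, not code from the paper), where a small generator network produces the kernel of a convolution from a learned layer embedding:

import torch
from torch import nn
import torch.nn.functional as F

class StaticHyperConv2d(nn.Module):
    """Generate the kernel of a 3x3 convolution from a small learned layer embedding."""
    def __init__(self, in_channels: int, out_channels: int, z_dim: int = 16):
        super().__init__()
        # Learned embedding for this layer
        self.z = nn.Parameter(torch.randn(z_dim))
        # Small network that generates the convolution weights
        self.generator = nn.Linear(z_dim, out_channels * in_channels * 3 * 3)
        self.kernel_shape = (out_channels, in_channels, 3, 3)

    def forward(self, x: torch.Tensor):
        # Generate the kernel and apply it as an ordinary convolution
        kernel = self.generator(self.z).view(self.kernel_shape)
        return F.conv2d(x, kernel, padding=1)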

Dynamic HyperNetworks

In an RNN, the parameters stay the same for every step. Dynamic HyperNetworks generate different parameters for each step. HyperLSTM has the structure of an LSTM, but the parameters of each step are changed by a smaller LSTM network.

In the basic form, a Dynamic HyperNetwork has a smaller recurrent network that generates a feature vector corresponding to each parameter tensor of the larger recurrent network. Let's say the larger network has some parameter $W_h$. The smaller network generates a feature vector $z_h$, and we dynamically compute $W_h$ as a linear transformation of $z_h$; for instance, $W_h = \langle W_{hz}, z_h \rangle$, where $W_{hz}$ is a 3-d tensor parameter and $\langle \cdot, \cdot \rangle$ is a tensor-vector multiplication. $z_h$ is usually a linear transformation of the output of the smaller recurrent network.
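
As a concrete illustration (a small sketch with made-up sizes, not part of the implementation below), this tensor-vector multiplication can be written with einsum:

import torch

n_h, n_z = 8, 4                    # made-up sizes
w_hz = torch.randn(n_h, n_h, n_z)  # the 3-d tensor parameter W_hz
z_h = torch.randn(n_z)             # feature vector from the smaller network

# W_h = <W_hz, z_h>: contract the last axis of W_hz with z_h
w_h = torch.einsum('ijk,k->ij', w_hz, z_h)  # shape [n_h, n_h]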

Weight scaling instead of computing

Large recurrent networks have large dynamically computed parameters. These are calculated using a linear transformation of the feature vector $z$, and this transformation requires an even larger weight tensor. That is, when $W_h$ has shape $N_h \times N_h$, $W_{hz}$ will be $N_h \times N_h \times N_z$.

To overcome this, we compute the weight parameters of the recurrent network by dynamically scaling each row of a matrix of the same size.

$$d(z) = W_{hz} z_h$$

$$W_h = \begin{pmatrix} d_0(z) W_{hd_0} \\ d_1(z) W_{hd_1} \\ \vdots \\ d_{N_h}(z) W_{hd_{N_h}} \end{pmatrix}$$

where $W_{hd}$ is an $N_h \times N_h$ parameter matrix and $W_{hd_i}$ is its $i$-th row.

We can further optimize this when computing $W_h h$, as $d(z) \odot (W_{hd} h)$, where $\odot$ stands for element-wise multiplication.
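
Here is a small check of this identity (our own sketch with made-up sizes, not part of the implementation below): scaling the rows of $W_{hd}$ by $d(z)$ and then multiplying by $h$ gives the same result as the element-wise form $d(z) \odot (W_{hd} h)$.

import torch

n_h, n_z = 8, 4
w_hd = torch.randn(n_h, n_h)  # shared parameter matrix W_hd
w_hz = torch.randn(n_h, n_z)  # maps z_h to the row-scaling vector d(z)
z_h = torch.randn(n_z)
h = torch.randn(n_h)

d = w_hz @ z_h                # d(z), one scale per row
w_h = d.unsqueeze(-1) * w_hd  # explicit dynamic weight: row i is d_i(z) * (row i of W_hd)
full = w_h @ h                # W_h h computed with the full matrix
cheap = d * (w_hd @ h)        # d(z) ⊙ (W_hd h), without building W_h

assert torch.allclose(full, cheap, atol=1e-6)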

from typing import Optional, Tuple

import torch
from torch import nn

from labml_nn.lstm import LSTMCell

HyperLSTM Cell

For HyperLSTM, the smaller network and the larger network both have the LSTM structure. This is defined in Appendix A.2.2 of the paper.

class HyperLSTMCell(nn.Module):

input_size is the size of the input $x_t$, hidden_size is the size of the LSTM, and hyper_size is the size of the smaller LSTM that alters the weights of the larger outer LSTM. n_z is the size of the feature vectors used to alter the LSTM weights.

We use the output of the smaller LSTM to compute $z_h^{i,f,g,o}$, $z_x^{i,f,g,o}$ and $z_b^{i,f,g,o}$ using linear transformations. We calculate $d_h^{i,f,g,o}(z_h^{i,f,g,o})$, $d_x^{i,f,g,o}(z_x^{i,f,g,o})$, and $d_b^{i,f,g,o}(z_b^{i,f,g,o})$ from these, using linear transformations again. These are then used to scale the rows of the weight and bias tensors of the main LSTM.

📝 Since the computation of $z$ and $d$ consists of two sequential linear transformations, these can be combined into a single linear transformation. However, we've implemented them separately so that the code matches the description in the paper.
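
As a quick check of that note (our own sketch for a single gate, with made-up sizes), two stacked linear layers with no nonlinearity in between collapse into one linear layer whose weight is the product of the two weights:

import torch
from torch import nn

hyper_size, n_z, hidden_size = 32, 8, 64          # made-up sizes
z_lin = nn.Linear(hyper_size, n_z)                # computes z from the hyper LSTM output
d_lin = nn.Linear(n_z, hidden_size, bias=False)   # computes d from z

combined = nn.Linear(hyper_size, hidden_size)
with torch.no_grad():
    combined.weight.copy_(d_lin.weight @ z_lin.weight)
    combined.bias.copy_(d_lin.weight @ z_lin.bias)

h_hat = torch.randn(3, hyper_size)
assert torch.allclose(d_lin(z_lin(h_hat)), combined(h_hat), atol=1e-6)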

    def __init__(self, input_size: int, hidden_size: int, hyper_size: int, n_z: int):

        super().__init__()

The input to the hyperLSTM is $\hat{x}_t = \begin{pmatrix} h_{t-1} \\ x_t \end{pmatrix}$ where $x_t$ is the input and $h_{t-1}$ is the output of the outer LSTM at the previous step. So the input size is hidden_size + input_size.

The output of the hyperLSTM is $\hat{h}_t$ and $\hat{c}_t$.

        self.hyper = LSTMCell(hidden_size + input_size, hyper_size, layer_norm=True)

$z_h^{i,f,g,o} = lin_h^{i,f,g,o}(\hat{h}_t)$

🤔 In the paper it was specified as $z_h^{i,f,g,o} = lin_h^{i,f,g,o}(\hat{h}_{t-1})$. I feel that it's a typo.

        self.z_h = nn.Linear(hyper_size, 4 * n_z)

$z_x^{i,f,g,o} = lin_x^{i,f,g,o}(\hat{h}_t)$

        self.z_x = nn.Linear(hyper_size, 4 * n_z)

$z_b^{i,f,g,o} = lin_b^{i,f,g,o}(\hat{h}_t)$

        self.z_b = nn.Linear(hyper_size, 4 * n_z, bias=False)

$d_h^{i,f,g,o}(z_h^{i,f,g,o}) = lin_{dh}^{i,f,g,o}(z_h^{i,f,g,o})$

        d_h = [nn.Linear(n_z, hidden_size, bias=False) for _ in range(4)]
        self.d_h = nn.ModuleList(d_h)

$d_x^{i,f,g,o}(z_x^{i,f,g,o}) = lin_{dx}^{i,f,g,o}(z_x^{i,f,g,o})$

        d_x = [nn.Linear(n_z, hidden_size, bias=False) for _ in range(4)]
        self.d_x = nn.ModuleList(d_x)

$d_b^{i,f,g,o}(z_b^{i,f,g,o}) = lin_{db}^{i,f,g,o}(z_b^{i,f,g,o})$

        d_b = [nn.Linear(n_z, hidden_size) for _ in range(4)]
        self.d_b = nn.ModuleList(d_b)

The weight matrices $W_h^{i,f,g,o}$

        self.w_h = nn.ParameterList([nn.Parameter(torch.zeros(hidden_size, hidden_size)) for _ in range(4)])

The weight matrices $W_x^{i,f,g,o}$

        self.w_x = nn.ParameterList([nn.Parameter(torch.zeros(hidden_size, input_size)) for _ in range(4)])

Layer normalization

        self.layer_norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(4)])
        self.layer_norm_c = nn.LayerNorm(hidden_size)

    def forward(self, x: torch.Tensor,
                h: torch.Tensor, c: torch.Tensor,
                h_hat: torch.Tensor, c_hat: torch.Tensor):

$\hat{x}_t = \begin{pmatrix} h_{t-1} \\ x_t \end{pmatrix}$

        x_hat = torch.cat((h, x), dim=-1)

$\hat{h}_t, \hat{c}_t = lstm(\hat{x}_t, \hat{h}_{t-1}, \hat{c}_{t-1})$

        h_hat, c_hat = self.hyper(x_hat, h_hat, c_hat)

$z_h^{i,f,g,o} = lin_h^{i,f,g,o}(\hat{h}_t)$

        z_h = self.z_h(h_hat).chunk(4, dim=-1)

$z_x^{i,f,g,o} = lin_x^{i,f,g,o}(\hat{h}_t)$

        z_x = self.z_x(h_hat).chunk(4, dim=-1)

$z_b^{i,f,g,o} = lin_b^{i,f,g,o}(\hat{h}_t)$

        z_b = self.z_b(h_hat).chunk(4, dim=-1)

We calculate i, f, g and o in a loop

        ifgo = []
        for i in range(4):

$d_h^{i,f,g,o}(z_h^{i,f,g,o}) = lin_{dh}^{i,f,g,o}(z_h^{i,f,g,o})$

            d_h = self.d_h[i](z_h[i])

$d_x^{i,f,g,o}(z_x^{i,f,g,o}) = lin_{dx}^{i,f,g,o}(z_x^{i,f,g,o})$

            d_x = self.d_x[i](z_x[i])

$$i, f, g, o = LN\Big( d_h^{i,f,g,o}(z_h) \odot (W_h^{i,f,g,o} h_{t-1}) + d_x^{i,f,g,o}(z_x) \odot (W_x^{i,f,g,o} x_t) + d_b^{i,f,g,o}(z_b) \Big)$$

            y = d_h * torch.einsum('ij,bj->bi', self.w_h[i], h) + \
                d_x * torch.einsum('ij,bj->bi', self.w_x[i], x) + \
                self.d_b[i](z_b[i])

            ifgo.append(self.layer_norm[i](y))

$i_t, f_t, g_t, o_t$

        i, f, g, o = ifgo

$c_t = \sigma(f_t) \odot c_{t-1} + \sigma(i_t) \odot \tanh(g_t)$

        c_next = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)

$h_t = \sigma(o_t) \odot \tanh(LN(c_t))$

        h_next = torch.sigmoid(o) * torch.tanh(self.layer_norm_c(c_next))

        return h_next, c_next, h_hat, c_hat

HyperLSTM module

class HyperLSTM(nn.Module):

Create a network of n_layers of HyperLSTM.

    def __init__(self, input_size: int, hidden_size: int, hyper_size: int, n_z: int, n_layers: int):

        super().__init__()

Store sizes to initialize state

        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.hyper_size = hyper_size

Create cells for each layer. Note that only the first layer gets the input directly; the rest of the layers get their input from the layer below.

        self.cells = nn.ModuleList([HyperLSTMCell(input_size, hidden_size, hyper_size, n_z)] +
                                   [HyperLSTMCell(hidden_size, hidden_size, hyper_size, n_z) for _ in
                                    range(n_layers - 1)])

  • x has shape [n_steps, batch_size, input_size] and
  • state is a tuple of $h$, $c$, $\hat{h}$, $\hat{c}$. $h$, $c$ have shape [batch_size, hidden_size] and $\hat{h}$, $\hat{c}$ have shape [batch_size, hyper_size].

    def forward(self, x: torch.Tensor,
                state: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = None):

        n_steps, batch_size = x.shape[:2]

Initialize the state with zeros if None

        if state is None:
            h = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
            c = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
            h_hat = [x.new_zeros(batch_size, self.hyper_size) for _ in range(self.n_layers)]
            c_hat = [x.new_zeros(batch_size, self.hyper_size) for _ in range(self.n_layers)]

        else:
            (h, c, h_hat, c_hat) = state

Reverse stack the tensors to get the states of each layer

📝 You can just work with the tensor itself but this is easier to debug

            h, c = list(torch.unbind(h)), list(torch.unbind(c))
            h_hat, c_hat = list(torch.unbind(h_hat)), list(torch.unbind(c_hat))

Collect the outputs of the final layer at each step

        out = []
        for t in range(n_steps):

Input to the first layer is the input itself

            inp = x[t]

Loop through the layers

            for layer in range(self.n_layers):

Get the state of the layer

                h[layer], c[layer], h_hat[layer], c_hat[layer] = \
                    self.cells[layer](inp, h[layer], c[layer], h_hat[layer], c_hat[layer])

Input to the next layer is the state of this layer

                inp = h[layer]

Collect the output h of the final layer

            out.append(h[-1])

Stack the outputs and states

        out = torch.stack(out)
        h = torch.stack(h)
        c = torch.stack(c)
        h_hat = torch.stack(h_hat)
        c_hat = torch.stack(c_hat)

        return out, (h, c, h_hat, c_hat)
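
Here is a minimal usage sketch (our own example with made-up sizes; we assume the module path labml_nn.hypernetworks.hyper_lstm):

import torch
from labml_nn.hypernetworks.hyper_lstm import HyperLSTM

model = HyperLSTM(input_size=16, hidden_size=64, hyper_size=32, n_z=8, n_layers=2)
x = torch.randn(10, 4, 16)            # [n_steps, batch_size, input_size]
out, (h, c, h_hat, c_hat) = model(x)  # state is initialized to zeros when None
print(out.shape)  # torch.Size([10, 4, 64])
print(h.shape)    # torch.Size([2, 4, 64]) -- stacked states, one per layer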
