
# Recurrent Highway Networks


[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/recurrent_highway_networks/__init__.py)


This is a PyTorch implementation of Recurrent Highway Networks.

```python
from typing import Optional

import torch
from torch import nn
```

## Recurrent Highway Network Cell

This implements equations (6)-(9).

$$s_d^t = h_d^t \odot g_d^t + s_{d-1}^t \odot c_d^t$$

where

$$\begin{aligned}
h_0^t &= \tanh(lin_{hx}(x) + lin_{hs}(s_D^{t-1})) \\
g_0^t &= \sigma(lin_{gx}(x) + lin_{gs}(s_D^{t-1})) \\
c_0^t &= \sigma(lin_{cx}(x) + lin_{cs}(s_D^{t-1}))
\end{aligned}$$

and for $0 < d < D$

$$\begin{aligned}
h_d^t &= \tanh(lin_{hs}^d(s_{d-1}^t)) \\
g_d^t &= \sigma(lin_{gs}^d(s_{d-1}^t)) \\
c_d^t &= \sigma(lin_{cs}^d(s_{d-1}^t))
\end{aligned}$$

$\odot$ stands for element-wise multiplication.

Here we have made a couple of changes to the notation from the paper. To avoid confusion with time, the gate is denoted $g$, which was $t$ in the paper. To avoid confusion with multiple layers, we use $d$ for depth and $D$ for total depth instead of $l$ and $L$ from the paper.

We have also replaced the weight matrices and bias vectors from the equations with linear transforms, because that is how the implementation will look.

We implement weight tying, as described in the paper: $c_d^t = 1 - g_d^t$.
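As a concrete illustration, here is a minimal sketch of a single weight-tied micro-step at some depth $d > 0$, using standalone linear layers (names and sizes are illustrative, not from the implementation below):

```python
import torch
from torch import nn

batch_size, hidden_size = 4, 8
s = torch.zeros(batch_size, hidden_size)      # s_{d-1}^t, the incoming state
lin_hs = nn.Linear(hidden_size, hidden_size)  # stand-in for lin_hs^d
lin_gs = nn.Linear(hidden_size, hidden_size)  # stand-in for lin_gs^d

h = torch.tanh(lin_hs(s))     # h_d^t
g = torch.sigmoid(lin_gs(s))  # g_d^t
c = 1 - g                     # weight tying: c_d^t = 1 - g_d^t
s = h * g + s * c             # s_d^t, the micro-step update

assert s.shape == (batch_size, hidden_size)
```

The actual cell below fuses the two linear layers into one for efficiency.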

```python
class RHNCell(nn.Module):
```

input_size is the feature length of the input and hidden_size is the feature length of the cell. depth is D.

```python
    def __init__(self, input_size: int, hidden_size: int, depth: int):
```

```python
        super().__init__()

        self.hidden_size = hidden_size
        self.depth = depth
```

We combine $lin_{hs}$ and $lin_{gs}$ into a single linear layer; we can then split the result to get the $lin_{hs}$ and $lin_{gs}$ components. These are the $lin_{hs}^d$ and $lin_{gs}^d$ transforms for $0 \leq d < D$.

```python
        self.hidden_lin = nn.ModuleList([nn.Linear(hidden_size, 2 * hidden_size) for _ in range(depth)])
```

Similarly we combine $lin_{hx}$ and $lin_{gx}$.

```python
        self.input_lin = nn.Linear(input_size, 2 * hidden_size, bias=False)
```
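The combine-then-split trick can be sketched on its own; this is illustrative, assuming a fused layer of width `2 * hidden_size` whose output is sliced into the $h$ and $g$ halves:

```python
import torch
from torch import nn

batch_size, hidden_size = 3, 8
fused = nn.Linear(hidden_size, 2 * hidden_size)  # lin_hs and lin_gs in one layer

s = torch.randn(batch_size, hidden_size)
hg = fused(s)

h_part = hg[:, :hidden_size]  # fed to tanh to produce h
g_part = hg[:, hidden_size:]  # fed to sigmoid to produce g

assert h_part.shape == g_part.shape == (batch_size, hidden_size)
```

One matrix multiplication instead of two; the slices are exactly what the `forward` method below consumes.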

`x` has shape `[batch_size, input_size]` and `s` has shape `[batch_size, hidden_size]`.

```python
    def forward(self, x: torch.Tensor, s: torch.Tensor):
```

Iterate $0 \leq d < D$

```python
        for d in range(self.depth):
```

We calculate the concatenation of linear transforms for h and g

```python
            if d == 0:
```

The input is used only when d is 0.

```python
                hg = self.input_lin(x) + self.hidden_lin[d](s)
            else:
                hg = self.hidden_lin[d](s)
```

Use the first half of `hg` to get $h_d^t$:

$$\begin{aligned}
h_0^t &= \tanh(lin_{hx}(x) + lin_{hs}(s_D^{t-1})) \\
h_d^t &= \tanh(lin_{hs}^d(s_{d-1}^t))
\end{aligned}$$

```python
            h = torch.tanh(hg[:, :self.hidden_size])
```

Use the second half of `hg` to get $g_d^t$:

$$\begin{aligned}
g_0^t &= \sigma(lin_{gx}(x) + lin_{gs}(s_D^{t-1})) \\
g_d^t &= \sigma(lin_{gs}^d(s_{d-1}^t))
\end{aligned}$$

```python
            g = torch.sigmoid(hg[:, self.hidden_size:])

            s = h * g + s * (1 - g)

        return s
```

## Multilayer Recurrent Highway Network

```python
class RHN(nn.Module):
```

Create a network of `n_layers` recurrent highway network layers, each with depth `depth`, $D$.

```python
    def __init__(self, input_size: int, hidden_size: int, depth: int, n_layers: int):
```

```python
        super().__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
```

Create cells for each layer. Note that only the first layer gets the input directly; the rest of the layers get their input from the layer below.

```python
        self.cells = nn.ModuleList([RHNCell(input_size, hidden_size, depth)] +
                                   [RHNCell(hidden_size, hidden_size, depth) for _ in range(n_layers - 1)])
```

`x` has shape `[seq_len, batch_size, input_size]` and `state` has shape `[n_layers, batch_size, hidden_size]`.

```python
    def forward(self, x: torch.Tensor, state: Optional[torch.Tensor] = None):
```

```python
        time_steps, batch_size = x.shape[:2]
```

Initialize the state if it is `None`

```python
        if state is None:
            s = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
        else:
```

Unstack the state to get the state of each layer (the reverse of `torch.stack`).

📝 You could just work with the tensor itself, but this is easier to debug

```python
            # Wrap in list() because torch.unbind returns a tuple,
            # and s[layer] is assigned to below
            s = list(torch.unbind(state))
```

List to collect the outputs of the final layer at each time step.

```python
        out = []
```

Run through the network for each time step

```python
        for t in range(time_steps):
```

Input to the first layer is the input itself

```python
            inp = x[t]
```

Loop through the layers

```python
            for layer in range(self.n_layers):
```

Update the state of the layer

```python
                s[layer] = self.cells[layer](inp, s[layer])
```

Input to the next layer is the state of this layer

```python
                inp = s[layer]
```

Collect the output of the final layer

```python
            out.append(s[-1])
```

Stack the outputs and states

```python
        out = torch.stack(out)
        s = torch.stack(s)

        return out, s
```
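As a sanity check on the packing convention, `torch.stack` and `torch.unbind` are inverses along dimension 0. A small sketch (shapes are illustrative) of how the per-layer states round-trip between the list used inside `forward` and the returned tensor:

```python
import torch

n_layers, batch_size, hidden_size = 2, 4, 8

# Per-layer states, as kept in the list inside forward()
s = [torch.randn(batch_size, hidden_size) for _ in range(n_layers)]

# Pack into a single tensor: [n_layers, batch_size, hidden_size]
state = torch.stack(s)
assert state.shape == (n_layers, batch_size, hidden_size)

# Unpack on the next call; list() makes per-layer assignment possible
s2 = list(torch.unbind(state))
assert all(torch.equal(a, b) for a, b in zip(s, s2))
```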
