This is a PyTorch implementation of Recurrent Highway Networks.
```python
from typing import Optional

import torch
from torch import nn
```
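This implements equations (6)-(9) of the paper:

$$s_d^t = h_d^t \odot g_d^t + s_{d-1}^t \odot c_d^t$$

At the first depth step, the input $x$ is combined with the final state of the previous time step:

$$
\begin{aligned}
h_0^t &= \tanh\big(lin_{hx}(x) + lin_{hs}^0(s_{D-1}^{t-1})\big) \\
g_0^t &= \sigma\big(lin_{gx}(x) + lin_{gs}^0(s_{D-1}^{t-1})\big) \\
c_0^t &= \sigma\big(lin_{cx}(x) + lin_{cs}^0(s_{D-1}^{t-1})\big)
\end{aligned}
$$

For $0 < d < D$, only the state of the previous depth step is used:

$$
\begin{aligned}
h_d^t &= \tanh\big(lin_{hs}^d(s_{d-1}^t)\big) \\
g_d^t &= \sigma\big(lin_{gs}^d(s_{d-1}^t)\big) \\
c_d^t &= \sigma\big(lin_{cs}^d(s_{d-1}^t)\big)
\end{aligned}
$$

Here $s_{-1}^t = s_{D-1}^{t-1}$, $\sigma$ is the sigmoid function and $\odot$ stands for element-wise multiplication.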
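To make the recurrence concrete, here is a minimal functional sketch of one time step of equations (6)-(9). It uses a separate carry gate $c$, so there is no weight tying. The function name `rhn_step` and the dict-of-layers layout for the transforms are illustrative assumptions, not part of this implementation.

```python
import torch
from torch import nn


def rhn_step(x, s_prev, lin_x, lin_s):
    """One RHN time step following equations (6)-(9), without weight tying.

    x:      [batch_size, input_size], the input at time t
    s_prev: [batch_size, hidden_size], the final state of the previous time step
    lin_x:  dict with keys 'h', 'g', 'c' of input transforms (used only at d = 0)
    lin_s:  list of length D; lin_s[d] is a dict with keys 'h', 'g', 'c'
    """
    s = s_prev
    for d, lins in enumerate(lin_s):
        # Pre-activations computed from the state of the previous depth step
        pre = {k: lins[k](s) for k in ("h", "g", "c")}
        if d == 0:
            # The input only enters at the first depth step
            pre = {k: pre[k] + lin_x[k](x) for k in pre}
        h = torch.tanh(pre["h"])
        g = torch.sigmoid(pre["g"])
        c = torch.sigmoid(pre["c"])
        # s_d = h_d * g_d + s_{d-1} * c_d (element-wise)
        s = h * g + s * c
    return s


torch.manual_seed(0)
input_size, hidden_size, depth, batch_size = 3, 4, 2, 5
lin_x = {k: nn.Linear(input_size, hidden_size, bias=False) for k in ("h", "g", "c")}
lin_s = [{k: nn.Linear(hidden_size, hidden_size) for k in ("h", "g", "c")} for _ in range(depth)]
x = torch.randn(batch_size, input_size)
s0 = torch.zeros(batch_size, hidden_size)
s1 = rhn_step(x, s0, lin_x, lin_s)
print(s1.shape)  # torch.Size([5, 4])
```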
Here we have made a few changes to the notation from the paper. To avoid confusion with time, the gate is represented with $g$, which was $t$ in the paper. To avoid confusion with multiple layers, we use $d$ for depth and $D$ for total depth instead of $l$ and $L$ from the paper, and we index depth from $0$ to $D - 1$.

We have also replaced the weight matrices and bias vectors from the equations with linear transforms, because that is what the implementation looks like.

We implement weight tying of the carry gate, as described in the paper: $c_d^t = 1 - g_d^t$.
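With the tied carry gate $c_d^t = 1 - g_d^t$, the update becomes a convex combination of the candidate $h_d^t$ and the previous state. It can be rewritten as $s_{d-1}^t + g_d^t \odot (h_d^t - s_{d-1}^t)$. Below is a quick numerical check of that identity. Random tensors stand in for $h$, $g$ and $s$; they are illustrative values only.

```python
import torch

torch.manual_seed(0)
h = torch.tanh(torch.randn(4, 8))      # candidate state, in (-1, 1)
g = torch.sigmoid(torch.randn(4, 8))   # transform gate, in (0, 1)
s = torch.randn(4, 8)                  # previous state s_{d-1}

# Tied update: carry gate c = 1 - g
s_tied = h * g + s * (1 - g)
# Equivalent interpolation form
s_interp = s + g * (h - s)
print(torch.allclose(s_tied, s_interp, atol=1e-6))  # True

# Each element of the result lies between the corresponding elements of h and s
lo = torch.minimum(h, s)
hi = torch.maximum(h, s)
print(bool(((s_tied >= lo - 1e-6) & (s_tied <= hi + 1e-6)).all()))  # True
```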
```python
class RHNCell(nn.Module):
```
`input_size` is the feature length of the input and `hidden_size` is the feature length of the cell. `depth` is $D$.
```python
    def __init__(self, input_size: int, hidden_size: int, depth: int):
        super().__init__()

        self.hidden_size = hidden_size
        self.depth = depth
```
We combine $lin_{hs}$ and $lin_{gs}$ into a single linear layer, and then split its output to get the $lin_{hs}$ and $lin_{gs}$ components. There is one such layer per depth step, giving $lin_{hs}^d$ and $lin_{gs}^d$ for $0 \le d < D$.
```python
        self.hidden_lin = nn.ModuleList([nn.Linear(hidden_size, 2 * hidden_size) for _ in range(depth)])
```
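Fusing the two transforms into one `nn.Linear(hidden_size, 2 * hidden_size)` is equivalent to using two separate `hidden_size -> hidden_size` layers. Their weights and biases would be the first and second halves of the fused parameters. The check below illustrates this equivalence; the names `lin_h` and `lin_g` are illustrative and do not appear in the implementation.

```python
import torch
from torch import nn

torch.manual_seed(0)
hidden_size = 4
fused = nn.Linear(hidden_size, 2 * hidden_size)

# Two separate layers holding the halves of the fused parameters
lin_h = nn.Linear(hidden_size, hidden_size)
lin_g = nn.Linear(hidden_size, hidden_size)
with torch.no_grad():
    lin_h.weight.copy_(fused.weight[:hidden_size])
    lin_h.bias.copy_(fused.bias[:hidden_size])
    lin_g.weight.copy_(fused.weight[hidden_size:])
    lin_g.bias.copy_(fused.bias[hidden_size:])

s = torch.randn(3, hidden_size)
hg = fused(s)
# First half -> h pre-activation, second half -> g pre-activation
print(torch.allclose(hg[:, :hidden_size], lin_h(s), atol=1e-6))  # True
print(torch.allclose(hg[:, hidden_size:], lin_g(s), atol=1e-6))  # True
```

A single fused matrix multiply is generally cheaper than two smaller ones, which is one practical reason to prefer this layout.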
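Similarly, we combine $lin_{hx}$ and $lin_{gx}$. This layer has no bias term. At $d = 0$ its output is added to that of `hidden_lin[0]`, which already has one.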
```python
        self.input_lin = nn.Linear(input_size, 2 * hidden_size, bias=False)
```
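The layers above determine the parameter count of one cell. Each of the $D$ hidden layers has a `2 * hidden_size` by `hidden_size` weight plus a `2 * hidden_size` bias. The input layer has only a weight (`bias=False`). The sketch below rebuilds just these layers with example sizes chosen for illustration, and checks the arithmetic.

```python
from torch import nn

input_size, hidden_size, depth = 10, 16, 3

hidden_lin = nn.ModuleList([nn.Linear(hidden_size, 2 * hidden_size) for _ in range(depth)])
input_lin = nn.Linear(input_size, 2 * hidden_size, bias=False)

counted = sum(p.numel() for p in hidden_lin.parameters()) + sum(p.numel() for p in input_lin.parameters())

# depth * (hidden weights + hidden biases) + input weights
expected = depth * (hidden_size * 2 * hidden_size + 2 * hidden_size) + input_size * 2 * hidden_size
print(counted, expected)  # 1952 1952
```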
`x` has shape `[batch_size, input_size]` and `s` has shape `[batch_size, hidden_size]`. The method returns the updated state, which has the same shape as `s`.
```python
    def forward(self, x: torch.Tensor, s: torch.Tensor):
```
Iterate 0≤d<D
```python
        for d in range(self.depth):
```
We calculate the concatenation of linear transforms for h and g
```python
            if d == 0:
```
The input is used only when d is 0.
```python
                hg = self.input_lin(x) + self.hidden_lin[d](s)
            else:
                hg = self.hidden_lin[d](s)
```
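At $d = 0$, the sum `input_lin(x) + hidden_lin[0](s)` equals a single linear layer applied to the concatenation $[x; s]$. That layer's weight matrices are placed side by side. The check below uses an illustrative `combined` layer that is not part of the implementation.

```python
import torch
from torch import nn

torch.manual_seed(0)
input_size, hidden_size, batch_size = 3, 4, 2
input_lin = nn.Linear(input_size, 2 * hidden_size, bias=False)
hidden_lin0 = nn.Linear(hidden_size, 2 * hidden_size)

# One layer over the concatenated [x, s], with the two weight matrices side by side
combined = nn.Linear(input_size + hidden_size, 2 * hidden_size)
with torch.no_grad():
    combined.weight.copy_(torch.cat([input_lin.weight, hidden_lin0.weight], dim=1))
    # Only the hidden layer has a bias
    combined.bias.copy_(hidden_lin0.bias)

x = torch.randn(batch_size, input_size)
s = torch.randn(batch_size, hidden_size)
hg_sum = input_lin(x) + hidden_lin0(s)
hg_cat = combined(torch.cat([x, s], dim=1))
print(torch.allclose(hg_sum, hg_cat, atol=1e-6))  # True
```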
Use the first half of `hg` to get $h_d^t$
$$
\begin{aligned}
h_0^t &= \tanh\big(lin_{hx}(x) + lin_{hs}^0(s_{D-1}^{t-1})\big) \\
h_d^t &= \tanh\big(lin_{hs}^d(s_{d-1}^t)\big)
\end{aligned}
$$
```python
            h = torch.tanh(hg[:, :self.hidden_size])
```
Use the second half of `hg` to get $g_d^t$
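$$
\begin{aligned}
g_0^t &= \sigma\big(lin_{gx}(x) + lin_{gs}^0(s_{D-1}^{t-1})\big) \\
g_d^t &= \sigma\big(lin_{gs}^d(s_{d-1}^t)\big)
\end{aligned}
$$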
```python
            g = torch.sigmoid(hg[:, self.hidden_size:])

            # Highway update with the tied carry gate c = 1 - g
            s = h * g + s * (1 - g)

        return s
```
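The last line of the loop is the highway update with the tied carry gate. Its extremes are easy to check. With $g = 0$ the state is carried through unchanged, and with $g = 1$ it is replaced by the candidate $h$. The sketch uses hand-picked gate values for illustration only; a sigmoid never reaches exactly $0$ or $1$. The helper `highway_update` is hypothetical.

```python
import torch


def highway_update(h, g, s):
    # s_d = h_d * g_d + s_{d-1} * (1 - g_d)
    return h * g + s * (1 - g)


h = torch.tensor([[0.5, -0.2, 0.9]])
s = torch.tensor([[1.0, 2.0, 3.0]])

carry = highway_update(h, torch.zeros_like(h), s)      # g = 0: keep s
transform = highway_update(h, torch.ones_like(h), s)   # g = 1: take h
half = highway_update(h, torch.full_like(h, 0.5), s)   # g = 0.5: average of h and s

print(torch.equal(carry, s))      # True
print(torch.equal(transform, h))  # True
print(torch.allclose(half, torch.tensor([[0.75, 0.9, 1.95]])))  # True
```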
```python
class RHN(nn.Module):
```
Create a network of `n_layers` recurrent highway network layers, each with depth `depth` ($D$).
```python
    def __init__(self, input_size: int, hidden_size: int, depth: int, n_layers: int):
        super().__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
```
Create a cell for each layer. Only the first layer gets the input directly. The remaining layers get their input from the layer below, so their input size is `hidden_size`.
```python
        self.cells = nn.ModuleList([RHNCell(input_size, hidden_size, depth)] +
                                   [RHNCell(hidden_size, hidden_size, depth) for _ in range(n_layers - 1)])
```
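`x` has shape `[seq_len, batch_size, input_size]` and `state`, if given, has shape `[n_layers, batch_size, hidden_size]`. The method returns two tensors. The first holds the outputs of the final layer, with shape `[seq_len, batch_size, hidden_size]`. The second holds the final state of every layer, with shape `[n_layers, batch_size, hidden_size]`.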
```python
    def forward(self, x: torch.Tensor, state: Optional[torch.Tensor] = None):
        time_steps, batch_size = x.shape[:2]
```
Initialize the state to zeros if it is `None`
```python
        if state is None:
            s = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
        else:
```
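`x.new_zeros(...)` creates a zero tensor with the same `dtype` and `device` as `x`. The initial state therefore matches the input, for example on a GPU or in double precision, without passing those options explicitly. A quick check with example sizes:

```python
import torch

n_layers, batch_size, hidden_size = 2, 3, 4
x = torch.randn(5, batch_size, 6, dtype=torch.float64)

s = [x.new_zeros(batch_size, hidden_size) for _ in range(n_layers)]
print(len(s), s[0].shape, s[0].dtype)  # 2 torch.Size([3, 4]) torch.float64
```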
Unstack the state to get the state of each layer.
📝 You could work with the stacked tensor directly, but a list of per-layer tensors is easier to debug.
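```python
            # `torch.unbind` returns a tuple; convert it to a list so that
            # the per-layer states can be reassigned in the loop below
            s = list(torch.unbind(state))
```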
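`torch.unbind` returns a tuple of tensors, one per layer. The time-step loop assigns `s[layer] = ...`, so the state must be held in a mutable list; assigning into a tuple raises a `TypeError`. A small demonstration of the unstack and restack round trip, with example sizes:

```python
import torch

n_layers, batch_size, hidden_size = 3, 2, 4
state = torch.randn(n_layers, batch_size, hidden_size)

parts = torch.unbind(state)     # tuple of n_layers tensors, each [batch_size, hidden_size]
print(type(parts).__name__)     # tuple
try:
    parts[0] = torch.zeros(batch_size, hidden_size)
except TypeError:
    print("tuples do not support item assignment")

s = list(torch.unbind(state))   # mutable list of per-layer states
s[0] = torch.zeros(batch_size, hidden_size)
restacked = torch.stack(s)      # back to [n_layers, batch_size, hidden_size]
print(restacked.shape)          # torch.Size([3, 2, 4])
```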
List to collect the outputs of the final layer at each time step.
```python
        out = []
```
Run through the network for each time step
```python
        for t in range(time_steps):
```
Input to the first layer is the input itself
```python
            inp = x[t]
```
Loop through the layers
```python
            for layer in range(self.n_layers):
```
Compute the new state of the layer
```python
                s[layer] = self.cells[layer](inp, s[layer])
```
Input to the next layer is the state of this layer
```python
                inp = s[layer]
```
Collect the output of the final layer
```python
            out.append(s[-1])
```
Stack the outputs and states
```python
        out = torch.stack(out)
        s = torch.stack(s)

        return out, s
```
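`forward` returns the final state in the same `[n_layers, batch_size, hidden_size]` layout that it accepts. A long sequence can therefore be processed in chunks by feeding the returned state back in, as is common with truncated backpropagation through time. The sketch below restates the two classes compactly so that it runs on its own, with the state unpacked into a list so that per-layer assignment works. It then checks that two chunks give the same result as a single pass. The sizes are arbitrary.

```python
from typing import Optional

import torch
from torch import nn


class RHNCell(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, depth: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.depth = depth
        self.hidden_lin = nn.ModuleList([nn.Linear(hidden_size, 2 * hidden_size) for _ in range(depth)])
        self.input_lin = nn.Linear(input_size, 2 * hidden_size, bias=False)

    def forward(self, x: torch.Tensor, s: torch.Tensor):
        for d in range(self.depth):
            # The input only enters at d = 0
            hg = self.hidden_lin[d](s) + (self.input_lin(x) if d == 0 else 0)
            h = torch.tanh(hg[:, :self.hidden_size])
            g = torch.sigmoid(hg[:, self.hidden_size:])
            s = h * g + s * (1 - g)
        return s


class RHN(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, depth: int, n_layers: int):
        super().__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.cells = nn.ModuleList([RHNCell(input_size, hidden_size, depth)] +
                                   [RHNCell(hidden_size, hidden_size, depth) for _ in range(n_layers - 1)])

    def forward(self, x: torch.Tensor, state: Optional[torch.Tensor] = None):
        time_steps, batch_size = x.shape[:2]
        if state is None:
            s = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
        else:
            s = list(torch.unbind(state))
        out = []
        for t in range(time_steps):
            inp = x[t]
            for layer in range(self.n_layers):
                s[layer] = self.cells[layer](inp, s[layer])
                inp = s[layer]
            out.append(s[-1])
        return torch.stack(out), torch.stack(s)


torch.manual_seed(0)
model = RHN(input_size=5, hidden_size=8, depth=3, n_layers=2)
x = torch.randn(10, 4, 5)  # [seq_len, batch_size, input_size]

with torch.no_grad():
    out_full, state_full = model(x)
    out_a, state_a = model(x[:6])
    out_b, state_b = model(x[6:], state_a)

print(out_full.shape, state_full.shape)  # torch.Size([10, 4, 8]) torch.Size([2, 4, 8])
print(torch.allclose(torch.cat([out_a, out_b]), out_full, atol=1e-6))  # True
print(torch.allclose(state_b, state_full, atol=1e-6))  # True
```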