[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/lstm/__init__.py)
This is a PyTorch implementation of Long Short-Term Memory.
from typing import Optional, Tuple

import torch
from torch import nn
The LSTM cell computes $c$ and $h$. $c$ is like the long-term memory, and $h$ is like the short-term memory. We use the input $x$ and $h$ to update the long-term memory. In the update, some features of $c$ are cleared with a forget gate $f$, and new features, computed by $g$ and scaled by an input gate $i$, are added.

The new short-term memory is the $\tanh$ of the long-term memory multiplied by the output gate $o$.

Note that the cell doesn't look at the long-term memory $c$ when computing the update; it only modifies it. Also, $c$ never goes through a linear transformation. This is what avoids vanishing and exploding gradients.
Here's the update rule.
$$\begin{aligned}
c_t &= \sigma(f_t) \odot c_{t-1} + \sigma(i_t) \odot \tanh(g_t) \\
h_t &= \sigma(o_t) \odot \tanh(c_t)
\end{aligned}$$
⊙ stands for element-wise multiplication.
Intermediate values and gates are computed as linear transformations of the hidden state and input.
$$\begin{aligned}
i_t &= lin_{xi}(x_t) + lin_{hi}(h_{t-1}) \\
f_t &= lin_{xf}(x_t) + lin_{hf}(h_{t-1}) \\
g_t &= lin_{xg}(x_t) + lin_{hg}(h_{t-1}) \\
o_t &= lin_{xo}(x_t) + lin_{ho}(h_{t-1})
\end{aligned}$$
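To see why keeping $c$ out of any linear transformation helps, here is a simplified sketch (it ignores the indirect dependence of the gates on $c_{t-1}$ through $h_{t-1}$). Differentiating the cell update gives

$$\frac{\partial c_t}{\partial c_{t-1}} \approx \operatorname{diag}\big(\sigma(f_t)\big)$$

so the gradient flowing back over $k$ steps is scaled element-wise by the product of forget gates $\prod_{j} \sigma(f_{t-j})$, rather than being repeatedly multiplied by a recurrent weight matrix. The forget gates control how much gradient survives, instead of the weight spectrum forcing exponential growth or decay.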
class LSTMCell(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, layer_norm: bool = False):
        super().__init__()
These are the linear layers that transform the input and hidden vectors. Only one of them needs a bias, since the two transformations are added together.
This combines the $lin_{hi}$, $lin_{hf}$, $lin_{hg}$, and $lin_{ho}$ transformations.
        self.hidden_lin = nn.Linear(hidden_size, 4 * hidden_size)
This combines the $lin_{xi}$, $lin_{xf}$, $lin_{xg}$, and $lin_{xo}$ transformations.
        self.input_lin = nn.Linear(input_size, 4 * hidden_size, bias=False)
Whether to apply layer normalizations.
Applying layer normalization gives better results. The $i$, $f$, $g$ and $o$ embeddings are normalized, and $c_t$ is normalized in $h_t = \sigma(o_t) \odot \tanh(LN(c_t))$.
        if layer_norm:
            self.layer_norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(4)])
            self.layer_norm_c = nn.LayerNorm(hidden_size)
        else:
            self.layer_norm = nn.ModuleList([nn.Identity() for _ in range(4)])
            self.layer_norm_c = nn.Identity()
    def forward(self, x: torch.Tensor, h: torch.Tensor, c: torch.Tensor):
We compute the linear transformations for $i_t$, $f_t$, $g_t$ and $o_t$ using the same linear layers.
        ifgo = self.hidden_lin(h) + self.input_lin(x)
Each linear layer produces an output of 4 times the hidden_size, and we split it into four chunks of hidden_size each.
        ifgo = ifgo.chunk(4, dim=-1)
Apply layer normalization (not in original paper, but gives better results)
        ifgo = [self.layer_norm[i](ifgo[i]) for i in range(4)]
$i_t$, $f_t$, $g_t$, $o_t$
        i, f, g, o = ifgo
$$c_t = \sigma(f_t) \odot c_{t-1} + \sigma(i_t) \odot \tanh(g_t)$$
        c_next = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
$$h_t = \sigma(o_t) \odot \tanh(c_t)$$

Optionally, apply layer normalization to $c_t$.
        h_next = torch.sigmoid(o) * torch.tanh(self.layer_norm_c(c_next))

        return h_next, c_next
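As a quick sanity check, here is a minimal usage sketch of the cell. The sizes and the `layer_norm=True` setting are illustrative assumptions, not part of the original code.

```python
# Minimal usage sketch (illustrative shapes; not part of the original code)
cell = LSTMCell(input_size=10, hidden_size=20, layer_norm=True)

batch_size = 4
x = torch.randn(batch_size, 10)   # one time step of input
h = torch.zeros(batch_size, 20)   # initial short-term state
c = torch.zeros(batch_size, 20)   # initial long-term state

h, c = cell(x, h, c)              # one update step
assert h.shape == (batch_size, 20) and c.shape == (batch_size, 20)
```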
class LSTM(nn.Module):
Create a multilayer network of n_layers stacked LSTM cells.
    def __init__(self, input_size: int, hidden_size: int, n_layers: int):
        super().__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
Create a cell for each layer. Note that only the first layer gets the input directly; the rest of the layers get their input from the layer below.
        self.cells = nn.ModuleList([LSTMCell(input_size, hidden_size)] +
                                   [LSTMCell(hidden_size, hidden_size) for _ in range(n_layers - 1)])
x has shape [n_steps, batch_size, input_size] and state is a tuple of h and c, each with a shape of [n_layers, batch_size, hidden_size].
    def forward(self, x: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
        n_steps, batch_size = x.shape[:2]
Initialize the state if None
        if state is None:
            h = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
            c = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
        else:
            (h, c) = state
Un-stack the tensors (the reverse of stacking) to get the states of each layer.

📝 You could work with the stacked tensors directly, but this is easier to debug.
            h, c = list(torch.unbind(h)), list(torch.unbind(c))
List to collect the outputs of the final layer at each time step.
        out = []
        for t in range(n_steps):
Input to the first layer is the input itself
            inp = x[t]
Loop through the layers
            for layer in range(self.n_layers):
Run the cell of this layer to compute its next state
                h[layer], c[layer] = self.cells[layer](inp, h[layer], c[layer])
Input to the next layer is the state of this layer
                inp = h[layer]
Collect the output h of the final layer
            out.append(h[-1])
Stack the outputs and states
        out = torch.stack(out)
        h = torch.stack(h)
        c = torch.stack(c)

        return out, (h, c)
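Finally, a minimal usage sketch of the full module. The hyperparameters and sequence lengths are illustrative assumptions, not part of the original code.

```python
# Minimal usage sketch (illustrative hyperparameters; not part of the original code)
lstm = LSTM(input_size=10, hidden_size=20, n_layers=2)

n_steps, batch_size = 5, 4
x = torch.randn(n_steps, batch_size, 10)

out, (h, c) = lstm(x)                            # state defaults to zeros
assert out.shape == (n_steps, batch_size, 20)    # top-layer output at every step
assert h.shape == (2, batch_size, 20)            # final h for each layer
assert c.shape == (2, batch_size, 20)            # final c for each layer

# The returned state can be fed back in to continue the sequence
out2, state = lstm(torch.randn(3, batch_size, 10), (h, c))
```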