
# Long Short-Term Memory (LSTM)


[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/lstm/__init__.py)


This is a PyTorch implementation of Long Short-Term Memory.

```python
from typing import Optional, Tuple

import torch
from torch import nn
```

## Long Short-Term Memory Cell

The LSTM cell computes $c$ and $h$. $c$ is like the long-term memory, and $h$ is like the short-term memory. We use the input $x$ and $h$ to update the long-term memory. In the update, some features of $c$ are cleared with a forget gate $f$, and some new features $i$ are added through a gate $g$.

The new short-term memory is the $\tanh$ of the long-term memory, multiplied by the output gate $o$.

Note that the cell doesn't look at the long-term memory $c$ when computing the update; it only modifies it. Also, $c$ never goes through a linear transformation. This is what solves vanishing and exploding gradients.
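
One way to see this: along the direct path through the cell state (holding the gates fixed), the Jacobian is just the forget gate applied element-wise,

$$\frac{\partial c_t}{\partial c_{t-1}} = \operatorname{diag}(\sigma(f_t))$$

so the gradient is never repeatedly multiplied by a weight matrix as it flows back through time.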

Here's the update rule.

$$\begin{aligned}
c_t &= \sigma(f_t) \odot c_{t-1} + \sigma(i_t) \odot \tanh(g_t) \\
h_t &= \sigma(o_t) \odot \tanh(c_t)
\end{aligned}$$

$\odot$ stands for element-wise multiplication.

Intermediate values and gates are computed as linear transformations of the hidden state and input.

$$\begin{aligned}
i_t &= \text{lin}_{xi}(x_t) + \text{lin}_{hi}(h_{t-1}) \\
f_t &= \text{lin}_{xf}(x_t) + \text{lin}_{hf}(h_{t-1}) \\
g_t &= \text{lin}_{xg}(x_t) + \text{lin}_{hg}(h_{t-1}) \\
o_t &= \text{lin}_{xo}(x_t) + \text{lin}_{ho}(h_{t-1})
\end{aligned}$$
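
As a concrete check of the update rule, here is a minimal sketch (toy sizes and hypothetical weight names, separate from the implementation below) that evaluates one step with plain tensors:

```python
import torch

input_size, hidden_size = 3, 4
x_t = torch.randn(input_size)      # input at step t
h_prev = torch.zeros(hidden_size)  # previous short-term memory h_{t-1}
c_prev = torch.zeros(hidden_size)  # previous long-term memory c_{t-1}

# One weight matrix per gate for the input path and the hidden path
w_x = {k: torch.randn(hidden_size, input_size) for k in 'ifgo'}
w_h = {k: torch.randn(hidden_size, hidden_size) for k in 'ifgo'}

# i_t, f_t, g_t, o_t = lin_x(x_t) + lin_h(h_{t-1})  (biases omitted here)
i, f, g, o = (w_x[k] @ x_t + w_h[k] @ h_prev for k in 'ifgo')

# c_t = sigma(f_t) * c_{t-1} + sigma(i_t) * tanh(g_t)
c_t = torch.sigmoid(f) * c_prev + torch.sigmoid(i) * torch.tanh(g)
# h_t = sigma(o_t) * tanh(c_t)
h_t = torch.sigmoid(o) * torch.tanh(c_t)
```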

```python
class LSTMCell(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, layer_norm: bool = False):
        super().__init__()
```

These are the linear layers that transform the input and hidden vectors. One of them doesn't need a bias, since we add the two transformations (a short sketch of why follows the two definitions below).

This combines the $\text{lin}_{hi}$, $\text{lin}_{hf}$, $\text{lin}_{hg}$, and $\text{lin}_{ho}$ transformations.

```python
        self.hidden_lin = nn.Linear(hidden_size, 4 * hidden_size)
```

This combines the $\text{lin}_{xi}$, $\text{lin}_{xf}$, $\text{lin}_{xg}$, and $\text{lin}_{xo}$ transformations.

```python
        self.input_lin = nn.Linear(input_size, 4 * hidden_size, bias=False)
```
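
Why one bias suffices: a small sketch (toy sizes, illustrative only). The sum of the two transformations is a single affine map of the concatenated vectors, so a second bias would just be absorbed into the first.

```python
import torch
from torch import nn

input_size, hidden_size = 3, 5
hidden_lin = nn.Linear(hidden_size, 4 * hidden_size)            # keeps its bias
input_lin = nn.Linear(input_size, 4 * hidden_size, bias=False)  # bias dropped

x, h = torch.randn(input_size), torch.randn(hidden_size)
summed = hidden_lin(h) + input_lin(x)

# An equivalent single linear layer on the concatenated [h; x]:
# its weight is the two weights side by side, its bias the one remaining bias.
combined = nn.Linear(hidden_size + input_size, 4 * hidden_size)
with torch.no_grad():
    combined.weight.copy_(torch.cat([hidden_lin.weight, input_lin.weight], dim=1))
    combined.bias.copy_(hidden_lin.bias)

assert torch.allclose(summed, combined(torch.cat([h, x])), atol=1e-6)
```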

Whether to apply layer normalizations.

Applying layer normalization gives better results: the $i$, $f$, $g$ and $o$ embeddings are normalized, and $c_t$ is normalized in $h_t = \sigma(o_t) \odot \tanh(\mathrm{LN}(c_t))$.

```python
        if layer_norm:
            self.layer_norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(4)])
            self.layer_norm_c = nn.LayerNorm(hidden_size)
        else:
            self.layer_norm = nn.ModuleList([nn.Identity() for _ in range(4)])
            self.layer_norm_c = nn.Identity()
```

```python
    def forward(self, x: torch.Tensor, h: torch.Tensor, c: torch.Tensor):
```

We compute the linear transformations for $i_t$, $f_t$, $g_t$ and $o_t$ using the same linear layers.

```python
        ifgo = self.hidden_lin(h) + self.input_lin(x)
```

Each layer produces an output of four times the `hidden_size`, and we split it into four equal chunks.

```python
        ifgo = ifgo.chunk(4, dim=-1)
```

Apply layer normalization (not in the original paper, but it gives better results).

```python
        ifgo = [self.layer_norm[i](ifgo[i]) for i in range(4)]
```

$i_t$, $f_t$, $g_t$, $o_t$

```python
        i, f, g, o = ifgo
```

$c_t = \sigma(f_t) \odot c_{t-1} + \sigma(i_t) \odot \tanh(g_t)$

```python
        c_next = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
```

$h_t = \sigma(o_t) \odot \tanh(c_t)$. Optionally, apply layer norm to $c_t$.

```python
        h_next = torch.sigmoid(o) * torch.tanh(self.layer_norm_c(c_next))

        return h_next, c_next
```
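
For reference, a quick usage sketch of the cell (sizes are illustrative): one call advances a batch by a single step.

```python
# Assumes LSTMCell from above is in scope
batch_size, input_size, hidden_size = 2, 10, 20
cell = LSTMCell(input_size, hidden_size, layer_norm=True)

x = torch.randn(batch_size, input_size)
h = torch.zeros(batch_size, hidden_size)
c = torch.zeros(batch_size, hidden_size)

h, c = cell(x, h, c)
assert h.shape == c.shape == (batch_size, hidden_size)
```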

## Multilayer LSTM

```python
class LSTM(nn.Module):
```

Create a network of `n_layers` of LSTM.

```python
    def __init__(self, input_size: int, hidden_size: int, n_layers: int):
```

```python
        super().__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
```

Create the cells for each layer. Note that only the first layer gets the input directly; the rest of the layers get their input from the layer below.

```python
        self.cells = nn.ModuleList([LSTMCell(input_size, hidden_size)] +
                                   [LSTMCell(hidden_size, hidden_size) for _ in range(n_layers - 1)])
```

`x` has shape `[n_steps, batch_size, input_size]` and `state` is a tuple of $h$ and $c$, each with shape `[n_layers, batch_size, hidden_size]`.

```python
    def forward(self, x: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
```

```python
        n_steps, batch_size = x.shape[:2]
```

Initialize the state with zeros if it is `None`.

```python
        if state is None:
            h = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
            c = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
        else:
            (h, c) = state
```

Reverse-stack the provided state tensors to get the states of each layer; this runs inside the `else` branch above.

📝 You can just work with the tensor itself, but this is easier to debug.

```python
            h, c = list(torch.unbind(h)), list(torch.unbind(c))
```


Array to collect the outputs of the final layer at each time step.

```python
        out = []
        for t in range(n_steps):
```


Input to the first layer is the input itself

```python
            inp = x[t]
```


Loop through the layers

```python
            for layer in range(self.n_layers):
```

Update the state of the layer

```python
                h[layer], c[layer] = self.cells[layer](inp, h[layer], c[layer])
```


Input to the next layer is the state of this layer

```python
                inp = h[layer]
```

Collect the output $h$ of the final layer.

```python
            out.append(h[-1])
```


Stack the outputs and states

```python
        out = torch.stack(out)
        h = torch.stack(h)
        c = torch.stack(c)

        return out, (h, c)
```
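
A short usage sketch (sizes are illustrative, not from the original): run a sequence through a stacked LSTM and continue from the returned state.

```python
# Assumes LSTM from above is in scope
n_steps, batch_size, input_size, hidden_size = 7, 2, 10, 20
lstm = LSTM(input_size, hidden_size, n_layers=3)

x = torch.randn(n_steps, batch_size, input_size)
out, (h, c) = lstm(x)  # state defaults to zeros
assert out.shape == (n_steps, batch_size, hidden_size)
assert h.shape == c.shape == (3, batch_size, hidden_size)

# The returned state can seed the next segment of the same sequence
out2, state2 = lstm(torch.randn(n_steps, batch_size, input_size), (h, c))
```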
