[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/gmlp/__init__.py)
This is a PyTorch implementation of the paper [Pay Attention to MLPs](https://arxiv.org/abs/2105.08050).

This paper introduces a Multilayer Perceptron (MLP) based architecture with gating, which the authors name gMLP. It consists of a stack of $L$ gMLP blocks.

There is also training code for an autoregressive language model built from gMLP blocks.
from typing import Optional

import torch
from torch import nn
Each block applies the following transformations to the input embeddings $X \in \mathbb{R}^{n \times d}$, where $n$ is the sequence length and $d$ is the dimensionality of the embeddings:

$$
\begin{aligned}
Z &= \sigma(XU) \\
\tilde{Z} &= s(Z) \\
Y &= \tilde{Z} V
\end{aligned}
$$

where $U$ and $V$ are learnable projection weights, $s(\cdot)$ is the Spatial Gating Unit defined below (implemented as `SpacialGatingUnit`), and $\sigma$ is an activation function such as GeLU. The output dimensionality of $s(\cdot)$ is half that of $Z$.
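As a rough shape walkthrough, here is a minimal sketch of one block's computation with made-up sizes, and with the gating unit stubbed out as a plain element-wise product rather than the learned $s(\cdot)$ defined below:

```python
import torch
from torch import nn

# Illustrative sizes only
n, d, d_ffn = 4, 8, 32                     # seq_len, d_model, d_ffn

X = torch.randn(n, d)                      # X in R^{n x d}
U = nn.Linear(d, d_ffn)                    # projection U
V = nn.Linear(d_ffn // 2, d)               # projection V

Z = nn.functional.gelu(U(X))               # Z = sigma(X U), shape [n, d_ffn]
Z1, Z2 = torch.chunk(Z, 2, dim=-1)         # the gating unit splits Z in half
Z_tilde = Z1 * Z2                          # stand-in for s(Z); the real unit first applies
                                           # a learned linear map along the sequence to Z2
Y = V(Z_tilde)                             # Y = Z_tilde V, shape [n, d]

print(Z.shape, Z_tilde.shape, Y.shape)     # [4, 32], [4, 16], [4, 8]
```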
class GMLPBlock(nn.Module):
* `d_model` is the dimensionality ($d$) of $X$
* `d_ffn` is the dimensionality of $Z$
* `seq_len` is the length of the token sequence ($n$)

def __init__(self, d_model: int, d_ffn: int, seq_len: int):
super().__init__()
Normalization layer for Pre-Norm
self.norm = nn.LayerNorm([d_model])
Activation function $\sigma$
self.activation = nn.GELU()
Projection layer for $Z = \sigma(XU)$
self.proj1 = nn.Linear(d_model, d_ffn)
Spatial Gating Unit $s(\cdot)$
self.sgu = SpacialGatingUnit(d_ffn, seq_len)
Projection layer for $Y = \tilde{Z} V$
self.proj2 = nn.Linear(d_ffn // 2, d_model)
Embedding size (required by the Encoder). We use the encoder module from the transformer architecture and plug in the gMLP block as a replacement for the transformer layer.
self.size = d_model
* `x` is the input embedding tensor $X$ of shape `[seq_len, batch_size, d_model]`
* `mask` is a boolean mask of shape `[seq_len, seq_len, 1]` that controls the visibility of tokens among each other

def forward(self, *, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
Keep a copy for the shortcut connection
shortcut = x
Normalize $X$
x = self.norm(x)
Projection and activation $Z = \sigma(XU)$
z = self.activation(self.proj1(x))
Spatial Gating Unit $\tilde{Z} = s(Z)$
z = self.sgu(z, mask)
Final projection $Y = \tilde{Z} V$
z = self.proj2(z)
Add the shortcut connection
return z + shortcut
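For example, a quick forward pass through the block (illustrative sizes; shapes follow the `[seq_len, batch_size, d_model]` convention above):

```python
block = GMLPBlock(d_model=128, d_ffn=512, seq_len=64)
x = torch.randn(64, 2, 128)            # [seq_len, batch_size, d_model]
y = block(x=x)                         # no mask: all tokens are visible to each other
assert y.shape == (64, 2, 128)         # the block preserves the input shape
```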
The Spatial Gating Unit is defined as

$$s(Z) = Z_1 \odot f_{W,b}(Z_2)$$

where $f_{W,b}(Z) = W Z + b$ is a linear transformation along the sequence dimension, and $\odot$ is element-wise multiplication. $Z$ is split into two parts of equal size, $Z_1$ and $Z_2$, along the channel dimension (embedding dimension).
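Concretely, the split and the sequence-wise linear map amount to the following (a toy sketch with arbitrary sizes, ignoring the batch dimension):

```python
n, d_z = 4, 8
Z = torch.randn(n, d_z)
Z1, Z2 = torch.chunk(Z, 2, dim=-1)     # split along the channel dimension: [n, d_z // 2] each
W = torch.randn(n, n)                  # mixes information across sequence positions
b = torch.ones(n)
f = W @ Z2 + b[:, None]                # f_{W,b}(Z2) = W Z2 + b
s = Z1 * f                             # s(Z) = Z1 ⊙ f_{W,b}(Z2), shape [n, d_z // 2]
```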
class SpacialGatingUnit(nn.Module):
* `d_z` is the dimensionality of $Z$
* `seq_len` is the sequence length

def __init__(self, d_z: int, seq_len: int):
super().__init__()
Normalization layer before applying $f_{W,b}(\cdot)$
self.norm = nn.LayerNorm([d_z // 2])
Weight $W$ in $f_{W,b}(\cdot)$.
The paper notes that it's important to initialize the weights to small values and the bias to 1, so that during the initial training $s(\cdot)$ is close to the identity (apart from the split).
self.weight = nn.Parameter(torch.zeros(seq_len, seq_len).uniform_(-0.01, 0.01), requires_grad=True)
Bias $b$ in $f_{W,b}(\cdot)$.
The paper notes that it's important to initialize the bias to 1.
self.bias = nn.Parameter(torch.ones(seq_len), requires_grad=True)
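With this initialization, $W \approx 0$ and $b = 1$, so $f_{W,b}(Z_2) \approx 1$ and the unit starts out close to the identity on $Z_1$. A small sanity check (a sketch; the sizes are arbitrary):

```python
sgu = SpacialGatingUnit(d_z=32, seq_len=16)
z = torch.randn(16, 1, 32)                 # [seq_len, batch_size, d_z]
z1, _ = torch.chunk(z, 2, dim=-1)

out = sgu(z)                               # no mask
print((out - z1).abs().max().item())       # small; s(Z) ≈ Z1 at initialization
```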
* `z` is the input $Z$ of shape `[seq_len, batch_size, d_z]`
* `mask` is a boolean mask of shape `[seq_len, seq_len, 1]` that controls the visibility of tokens among each other. The last dimension of size 1 is the batch dimension, which we have in other transformer implementations and keep here for compatibility.

def forward(self, z: torch.Tensor, mask: Optional[torch.Tensor] = None):
Get sequence length
seq_len = z.shape[0]
Split $Z$ into $Z_1$ and $Z_2$
z1, z2 = torch.chunk(z, 2, dim=-1)
Check mask
if mask is not None:
`mask` has shape `[seq_len_q, seq_len_k, batch_size]`. The batch dimension should be of size 1 because this implementation supports only the same mask for all samples in the batch.
    assert mask.shape[0] == 1 or mask.shape[0] == seq_len
    assert mask.shape[1] == seq_len
Here we only support the same mask for all samples
    assert mask.shape[2] == 1
Remove the batch dimension
    mask = mask[:, :, 0]
Normalize $Z_2$ before $f_{W,b}(\cdot)$
z2 = self.norm(z2)
Get the weight matrix; truncate if larger than seq_len
weight = self.weight[:seq_len, :seq_len]
Apply mask to the weights.
If $W_{i,j}$ is 0 then $f_{W,b}(Z_2)_i$ will not get any information from token $j$.
if mask is not None:
    weight = weight * mask
$f_{W,b}(Z_2) = W Z_2 + b$
z2 = torch.einsum('ij,jbd->ibd', weight, z2) + self.bias[:seq_len, None, None]
$Z_1 \odot f_{W,b}(Z_2)$
return z1 * z2
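As a usage sketch (sizes are illustrative), a causal mask of the expected `[seq_len, seq_len, 1]` shape can be built with `torch.tril`, so that position $i$ only receives information from positions $j \le i$:

```python
seq_len, batch_size, d_z = 16, 2, 64
sgu = SpacialGatingUnit(d_z=d_z, seq_len=seq_len)

# Lower-triangular (causal) mask; the trailing dimension of size 1 is the batch dimension
mask = torch.tril(torch.ones(seq_len, seq_len)).bool().unsqueeze(-1)

z = torch.randn(seq_len, batch_size, d_z)
out = sgu(z, mask)
assert out.shape == (seq_len, batch_size, d_z // 2)   # the SGU halves the channel dimension
```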