# Pay Attention to MLPs (gMLP)

[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/gmlp/__init__.py)


This is a PyTorch implementation of the paper Pay Attention to MLPs.

The paper introduces a Multilayer Perceptron (MLP) based architecture with gating, which the authors name gMLP. The model consists of a stack of $L$ gMLP blocks.

The repository also includes training code for an autoregressive model built from gMLP blocks.

```python
from typing import Optional

import torch
from torch import nn
```

## gMLP Block

Each block does the following transformations to the input embeddings $X \in \mathbb{R}^{n \times d}$, where $n$ is the sequence length and $d$ is the dimensionality of the embeddings:

$$Z = \sigma(XU)$$
$$\tilde{Z} = s(Z)$$
$$Y = \tilde{Z}V$$

where $U$ and $V$ are learnable projection weights. $s(\cdot)$ is the Spatial Gating Unit defined below; its output dimensionality is half of $Z$. $\sigma$ is an activation function such as GELU.
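
To make the shapes concrete, here is a minimal sketch of these transformations with assumed sizes ($n = 8$, $d = 32$, and a $Z$ of dimensionality 64). The random matrices and the slicing stand in for the learned $U$, $V$ and the gating unit $s(\cdot)$; they only illustrate the dimensions, not the actual computation:

```python
import torch

n, d, d_ffn = 8, 32, 64                 # assumed sizes, for illustration only
x = torch.randn(n, d)                   # X: [n, d]
u = torch.randn(d, d_ffn)               # stand-in for the learned projection U
v = torch.randn(d_ffn // 2, d)          # stand-in for the learned projection V

z = torch.nn.functional.gelu(x @ u)     # Z = sigma(XU): [n, d_ffn]
z_tilde = z[:, : d_ffn // 2]            # stand-in for s(Z): [n, d_ffn // 2]
y = z_tilde @ v                         # Y = Z~ V: [n, d]
print(y.shape)                          # torch.Size([8, 32])
```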

```python
class GMLPBlock(nn.Module):
```

• `d_model` is the dimensionality ($d$) of $X$
• `d_ffn` is the dimensionality of $Z$
• `seq_len` is the length of the token sequence ($n$)

```python
    def __init__(self, d_model: int, d_ffn: int, seq_len: int):
```

```python
        super().__init__()
```

Normalization layer for Pre-Norm

```python
        self.norm = nn.LayerNorm([d_model])
```

Activation function $\sigma$

```python
        self.activation = nn.GELU()
```

Projection layer for $Z = \sigma(XU)$

```python
        self.proj1 = nn.Linear(d_model, d_ffn)
```

Spatial Gating Unit $s(\cdot)$

```python
        self.sgu = SpacialGatingUnit(d_ffn, seq_len)
```

Projection layer for $Y = \tilde{Z}V$

```python
        self.proj2 = nn.Linear(d_ffn // 2, d_model)
```

Embedding size (required by the Encoder). We use the encoder module from the transformer architecture and plug in the gMLP block as a replacement for the transformer layer.

```python
        self.size = d_model
```

• `x` is the input embedding tensor $X$ of shape `[seq_len, batch_size, d_model]`
• `mask` is a boolean mask of shape `[seq_len, seq_len, 1]` that controls the visibility of tokens among each other

```python
    def forward(self, *, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
```

Keep a copy for the shortcut connection

```python
        shortcut = x
```

Normalize $X$

```python
        x = self.norm(x)
```

Projection and activation $Z = \sigma(XU)$

```python
        z = self.activation(self.proj1(x))
```

Spatial Gating Unit $\tilde{Z} = s(Z)$

```python
        z = self.sgu(z, mask)
```

Final projection $Y = \tilde{Z}V$

```python
        z = self.proj2(z)
```

Add the shortcut connection

```python
        return z + shortcut
```
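
Since each block maps a `[seq_len, batch_size, d_model]` tensor back to the same shape, a stack of $L$ blocks can be assembled into a model. Below is an illustrative sketch of how that might look for an autoregressive language model; the class name, embedding layer, and output projection are assumptions for this sketch, not part of the annotated implementation or its training code (it reuses `GMLPBlock` and the imports from above):

```python
class GMLPModel(nn.Module):
    """Illustrative sketch only: a stack of gMLP blocks for autoregressive language modeling."""

    def __init__(self, n_tokens: int, d_model: int, d_ffn: int, seq_len: int, n_layers: int):
        super().__init__()
        self.embedding = nn.Embedding(n_tokens, d_model)
        self.blocks = nn.ModuleList([GMLPBlock(d_model, d_ffn, seq_len) for _ in range(n_layers)])
        self.norm = nn.LayerNorm([d_model])
        self.output = nn.Linear(d_model, n_tokens)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
        # `x` holds token indices of shape [seq_len, batch_size]
        x = self.embedding(x)
        for block in self.blocks:
            x = block(x=x, mask=mask)
        # Final normalization and projection to vocabulary logits
        return self.output(self.norm(x))
```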

## Spatial Gating Unit

$$s(Z) = Z_1 \odot f_{W,b}(Z_2)$$

where $f_{W,b}(Z) = WZ + b$ is a linear transformation along the sequence dimension, and $\odot$ is element-wise multiplication. $Z$ is split into two parts of equal size, $Z_1$ and $Z_2$, along the channel dimension (embedding dimension).
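
As a tiny worked example of this gating (assumed sizes: 4 tokens, batch of 1, $d_z = 6$; $W$ and $b$ are random stand-ins rather than learned parameters):

```python
import torch

z = torch.randn(4, 1, 6)                     # Z: [seq_len, batch, d_z]
z1, z2 = torch.chunk(z, 2, dim=-1)           # Z1, Z2: [4, 1, 3] each

w = torch.zeros(4, 4).uniform_(-0.01, 0.01)  # mixes information across the 4 tokens
b = torch.ones(4)

f = torch.einsum('ij,jbd->ibd', w, z2) + b[:, None, None]  # f_{W,b}(Z2): [4, 1, 3]
out = z1 * f                                               # s(Z): [4, 1, 3]
```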

```python
class SpacialGatingUnit(nn.Module):
```

• `d_z` is the dimensionality of $Z$
• `seq_len` is the sequence length

```python
    def __init__(self, d_z: int, seq_len: int):
```

```python
        super().__init__()
```

Normalization layer before applying $f_{W,b}(\cdot)$

```python
        self.norm = nn.LayerNorm([d_z // 2])
```

Weight $W$ in $f_{W,b}(\cdot)$

The paper notes that it is important to initialize the weights to small values and the bias to 1, so that during the initial training $s(\cdot)$ is close to identity (apart from the split).

```python
        self.weight = nn.Parameter(torch.zeros(seq_len, seq_len).uniform_(-0.01, 0.01), requires_grad=True)
```

Bias $b$ in $f_{W,b}(\cdot)$

The paper notes that it is important to initialize the bias to 1.

```python
        self.bias = nn.Parameter(torch.ones(seq_len), requires_grad=True)
```
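
The effect of this initialization can be checked numerically: with $W \approx 0$ and $b = 1$, $f_{W,b}(Z_2) \approx 1$, so the gate multiplies $Z_1$ by roughly one. A minimal sketch with assumed sizes:

```python
import torch

seq_len, batch, half = 4, 1, 3
z1 = torch.randn(seq_len, batch, half)
z2 = torch.randn(seq_len, batch, half)

w = torch.zeros(seq_len, seq_len).uniform_(-0.01, 0.01)  # near-zero weights
b = torch.ones(seq_len)                                   # bias initialized to 1

gate = torch.einsum('ij,jbd->ibd', w, z2) + b[:, None, None]
print((z1 * gate - z1).abs().max())  # small: s(Z) is close to Z1 at initialization
```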

• `z` is the input $Z$ of shape `[seq_len, batch_size, d_z]`
• `mask` is a boolean mask of shape `[seq_len, seq_len, 1]` that controls the visibility of tokens among each other. The last dimension of size 1 is the batch dimension, which we have in other transformer implementations and is kept for compatibility.

```python
    def forward(self, z: torch.Tensor, mask: Optional[torch.Tensor] = None):
```

Get the sequence length

```python
        seq_len = z.shape[0]
```

Split $Z$ into $Z_1$ and $Z_2$

```python
        z1, z2 = torch.chunk(z, 2, dim=-1)
```

Check the mask

```python
        if mask is not None:
```

`mask` has shape `[seq_len_q, seq_len_k, batch_size]`. The batch dimension should be of size 1 because this implementation supports only the same mask for all samples in the batch.

```python
            assert mask.shape[0] == 1 or mask.shape[0] == seq_len
            assert mask.shape[1] == seq_len
```

Here we only support the same mask for all samples

```python
            assert mask.shape[2] == 1
```

Remove the batch dimension

```python
            mask = mask[:, :, 0]
```

Normalize $Z_2$ before $f_{W,b}(\cdot)$

```python
        z2 = self.norm(z2)
```

Get the weight matrix; truncate if larger than `seq_len`

```python
        weight = self.weight[:seq_len, :seq_len]
```

Apply the mask to the weights.

If $W_{i,j}$ is 0, then $f_{W,b}(Z_2)_i$ will not get any information from token $j$.

```python
        if mask is not None:
            weight = weight * mask
```

$$f_{W,b}(Z_2) = W Z_2 + b$$

```python
        z2 = torch.einsum('ij,jbd->ibd', weight, z2) + self.bias[:seq_len, None, None]
```

$$Z_1 \odot f_{W,b}(Z_2)$$

```python
        return z1 * z2
```
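
As a quick usage sketch (sizes chosen arbitrarily, not from the original file): build a `GMLPBlock` from the classes above, construct a causal mask with the `[seq_len, seq_len, 1]` shape described earlier, and run a random batch through it:

```python
seq_len, batch_size, d_model, d_ffn = 16, 4, 64, 256

block = GMLPBlock(d_model, d_ffn, seq_len)

# Lower-triangular causal mask: token i can only see tokens j <= i.
# The trailing dimension of size 1 is the batch dimension kept for compatibility.
mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)).unsqueeze(-1)

x = torch.randn(seq_len, batch_size, d_model)
y = block(x=x, mask=mask)
print(y.shape)  # torch.Size([16, 4, 64])
```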
