
Primer: Searching for Efficient Transformers for Language Modeling


[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/primer_ez/__init__.py)


This is a PyTorch implementation of the paper Primer: Searching for Efficient Transformers for Language Modeling.

The authors do an evolutionary search for transformer architectures. They name the architecture found using the search Primer (PRIMitives searched transformER). Primer EZ is the architecture with the two most robust modifications in Primer compared to the original transformer. Primer EZ trains a lot faster than the vanilla transformer.

Squared ReLU

The most effective modification found by the search is using a squared ReLU instead of ReLU in the position-wise feedforward module.

y = max(x, 0)^2
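As a quick standalone illustration (not part of the original implementation), squared ReLU is simply ReLU followed by an element-wise square:

```python
import torch
import torch.nn.functional as F

x = torch.tensor([-2.0, -0.5, 0.0, 1.0, 3.0])
# y = max(x, 0)^2: negatives are zeroed by ReLU, positives are squared
y = F.relu(x) ** 2
print(y)  # tensor([0., 0., 0., 1., 9.])
```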

Multi-DConv-Head Attention (MDHA)

The next most effective modification is a depth-wise 3×1 convolution after the multi-head projection for queries, keys, and values. The convolution is along the sequence dimension and per channel (depth-wise). To be clear, if the number of channels in each head is `d_k`, the convolution will have 1×3 kernels for each of the `d_k` channels.
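To make the depth-wise structure concrete, here is a small check (toy numbers, not from the paper) that a `Conv1d` with `groups` equal to the number of channels holds exactly one kernel of size 3 per channel:

```python
import torch
from torch import nn

d_k = 8  # channels per head (toy value)
conv = nn.Conv1d(in_channels=d_k, out_channels=d_k, kernel_size=3,
                 padding=2, groups=d_k)

# Depth-wise: one 1x3 kernel per channel, so the weight shape is [d_k, 1, 3]
print(conv.weight.shape)  # torch.Size([8, 1, 3])
# Parameter count is d_k * 3 weights plus d_k biases
print(sum(p.numel() for p in conv.parameters()))  # 32
```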

Here is the experiment code for Primer EZ.

```python
import torch
from torch import nn

from labml_nn.transformers import MultiHeadAttention
```


Squared ReLU activation

y = max(x, 0)^2

Squared ReLU is used as the activation function in the position-wise feedforward module.

```python
class SquaredReLU(nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor):
        # Apply ReLU
        x = self.relu(x)
        # Square it
        return x * x
```


Spatial Depth Wise Convolution

  • `d_k` is the number of channels in each head

```python
class SpatialDepthWiseConvolution(nn.Module):
    def __init__(self, d_k: int, kernel_size: int = 3):
        super().__init__()
        self.kernel_size = kernel_size
```

We use PyTorch's `Conv1d` module. We set the number of groups equal to the number of channels so that it does a separate convolution (with different kernels) for each channel. We add padding to both sides and later crop the rightmost `kernel_size - 1` results.

```python
        self.conv = nn.Conv1d(in_channels=d_k, out_channels=d_k,
                              kernel_size=(kernel_size,), padding=(kernel_size - 1,),
                              groups=d_k)
```

`x` has shape `[seq_len, batch_size, heads, d_k]`

```python
    def forward(self, x: torch.Tensor):
        # Get the shape
        seq_len, batch_size, heads, d_k = x.shape
        # Permute to [batch_size, heads, d_k, seq_len]
        x = x.permute(1, 2, 3, 0)
        # Change the shape to [batch_size * heads, d_k, seq_len]
        # (reshape rather than view, since x is not contiguous after the permute)
        x = x.reshape(batch_size * heads, d_k, seq_len)
        # 1D convolution accepts input of the form [N, channels, sequence]
        x = self.conv(x)
        # Crop the rightmost kernel_size - 1 results since we padded both sides
        x = x[:, :, :-(self.kernel_size - 1)]
        # Reshape to [batch_size, heads, d_k, seq_len]
        x = x.view(batch_size, heads, d_k, seq_len)
        # Permute to [seq_len, batch_size, heads, d_k]
        x = x.permute(3, 0, 1, 2)

        return x
```
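As a sanity check of the permute/pad/crop arithmetic above, this standalone snippet (toy shapes, applying a raw `Conv1d` directly rather than the class) confirms the output shape matches the input shape:

```python
import torch
from torch import nn

seq_len, batch_size, heads, d_k, kernel_size = 10, 2, 4, 8, 3
conv = nn.Conv1d(in_channels=d_k, out_channels=d_k, kernel_size=kernel_size,
                 padding=kernel_size - 1, groups=d_k)

x = torch.randn(seq_len, batch_size, heads, d_k)
out = x.permute(1, 2, 3, 0).reshape(batch_size * heads, d_k, seq_len)
# Convolved length is seq_len + kernel_size - 1; crop back down to seq_len
out = conv(out)[:, :, :-(kernel_size - 1)]
out = out.view(batch_size, heads, d_k, seq_len).permute(3, 0, 1, 2)
# The convolution preserves the input shape
print(out.shape)  # torch.Size([10, 2, 4, 8])
```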


Multi-DConv-Head Attention (MDHA)

We extend our original implementation of Multi-Head Attention and add the spatial depth-wise convolution to query, key and value projections.

```python
class MultiDConvHeadAttention(MultiHeadAttention):
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
        super().__init__(heads, d_model, dropout_prob)
```

`MultiHeadAttention` will create query, key and value projection modules `self.query`, `self.key`, and `self.value`.

We combine a spatial depth-wise convolution layer with each of them and replace `self.query`, `self.key`, and `self.value`.

📝 We feel this cleaner implementation is easier to understand since it clearly shows the difference between this and vanilla transformer multi-head attention.

```python
        self.query = nn.Sequential(self.query, SpatialDepthWiseConvolution(self.d_k))
        self.key = nn.Sequential(self.key, SpatialDepthWiseConvolution(self.d_k))
        self.value = nn.Sequential(self.value, SpatialDepthWiseConvolution(self.d_k))
```
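The wrapping pattern can be sketched standalone. In the real code the projection is labml's `PrepareForMultiHeadAttention` module; here `ToyProjection` is a hypothetical stand-in, and the depth-wise convolution is restated inline so the snippet runs on its own:

```python
import torch
from torch import nn

class SpatialDepthWiseConvolution(nn.Module):
    """Inline restatement of the module above, so this sketch is self-contained."""
    def __init__(self, d_k: int, kernel_size: int = 3):
        super().__init__()
        self.kernel_size = kernel_size
        self.conv = nn.Conv1d(d_k, d_k, kernel_size,
                              padding=kernel_size - 1, groups=d_k)

    def forward(self, x: torch.Tensor):
        seq_len, batch_size, heads, d_k = x.shape
        x = x.permute(1, 2, 3, 0).reshape(batch_size * heads, d_k, seq_len)
        x = self.conv(x)[:, :, :-(self.kernel_size - 1)]
        return x.view(batch_size, heads, d_k, seq_len).permute(3, 0, 1, 2)

class ToyProjection(nn.Module):
    """Hypothetical stand-in for the per-head projection:
    maps [seq_len, batch, d_model] to [seq_len, batch, heads, d_k]."""
    def __init__(self, d_model: int, heads: int, d_k: int):
        super().__init__()
        self.linear = nn.Linear(d_model, heads * d_k)
        self.heads, self.d_k = heads, d_k

    def forward(self, x: torch.Tensor):
        seq_len, batch_size, _ = x.shape
        return self.linear(x).view(seq_len, batch_size, self.heads, self.d_k)

heads, d_k, d_model = 4, 8, 32
# nn.Sequential applies the depth-wise convolution right after the projection
query = nn.Sequential(ToyProjection(d_model, heads, d_k),
                      SpatialDepthWiseConvolution(d_k))
out = query(torch.randn(10, 2, d_model))
print(out.shape)  # torch.Size([10, 2, 4, 8])
```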
