
Primer EZ Variations

We tried some variations to see which of the changes in Primer EZ have the most benefit.

```python
import torch
from torch import nn

from labml_nn.transformers import MultiHeadAttention
```

Spatial Depth Wise Shared Convolution

We share the same kernel across all channels.

```python
class SpatialDepthWiseSharedConvolution(nn.Module):
```

```python
    def __init__(self, kernel_size: int = 3):
        super().__init__()
        self.kernel_size = kernel_size
```

We use PyTorch's Conv1d module. We add padding to both sides and later crop the rightmost kernel_size - 1 results

```python
        self.conv = nn.Conv1d(in_channels=1, out_channels=1,
                              kernel_size=(kernel_size,), padding=(kernel_size - 1,))
```

x has shape [seq_len, batch_size, heads, d_k]

```python
    def forward(self, x: torch.Tensor):
```

Get the shape

```python
        seq_len, batch_size, heads, d_k = x.shape
```

Permute to [batch_size, heads, d_k, seq_len]

```python
        x = x.permute(1, 2, 3, 0)
```

Change the shape to [batch_size * heads * d_k, 1, seq_len] with a single channel, since the same kernel is shared across all channels

```python
        x = x.view(batch_size * heads * d_k, 1, seq_len)
```

1D convolution accepts input of the form [N, channels, sequence]

```python
        x = self.conv(x)
```

Crop the rightmost kernel_size - 1 results. Since we padded by kernel_size - 1 on both sides, the output is kernel_size - 1 longer than the input; dropping the rightmost outputs keeps the sequence length and makes each output position depend only on the current and earlier inputs

```python
        x = x[:, :, :-(self.kernel_size - 1)]
```
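To see why this gives a causal convolution, here is a minimal standalone sketch (separate from the class above, with an assumed kernel size of 3): changing the last input element leaves all earlier outputs unchanged.

```python
# Standalone causality check (illustrative; not part of the class above)
kernel_size = 3
conv = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=(kernel_size,),
                 padding=(kernel_size - 1,))

x = torch.randn(1, 1, 8)
y = conv(x)[:, :, :-(kernel_size - 1)]

# Perturb only the last input position
x2 = x.clone()
x2[:, :, -1] += 1.0
y2 = conv(x2)[:, :, :-(kernel_size - 1)]

# All outputs before the last position are unaffected
assert torch.allclose(y[:, :, :-1], y2[:, :, :-1])
```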

Reshape to [batch_size, heads, d_k, seq_len]

```python
        x = x.view(batch_size, heads, d_k, seq_len)
```

Permute to [seq_len, batch_size, heads, d_k]

```python
        x = x.permute(3, 0, 1, 2)
```

```python
        return x
```
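As a quick sanity check, here is a minimal usage sketch; the dimensions are illustrative assumptions, not values from the paper.

```python
conv = SpatialDepthWiseSharedConvolution(kernel_size=3)

seq_len, batch_size, heads, d_k = 10, 2, 4, 16
x = torch.randn(seq_len, batch_size, heads, d_k)

# The module preserves the [seq_len, batch_size, heads, d_k] shape
assert conv(x).shape == (seq_len, batch_size, heads, d_k)
```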

Multi-Depth-wise-Shared-Conv-Head Attention

We extend our original implementation of Multi-Head Attention and add the spatial depth-wise shared convolution to query, key and value projections.

```python
class MultiDSharedConvHeadAttention(MultiHeadAttention):
```

```python
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
        super().__init__(heads, d_model, dropout_prob)
```

Multi-Head Attention will create query, key and value projection modules self.query, self.key, and self.value.

We append a spatial depth-wise shared convolution layer to each of them and replace self.query, self.key, and self.value.

```python
        self.query = nn.Sequential(self.query, SpatialDepthWiseSharedConvolution())
        self.key = nn.Sequential(self.key, SpatialDepthWiseSharedConvolution())
        self.value = nn.Sequential(self.value, SpatialDepthWiseSharedConvolution())
```
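A minimal forward-pass sketch, assuming the keyword-only query/key/value interface of labml's MultiHeadAttention and illustrative dimensions:

```python
attn = MultiDSharedConvHeadAttention(heads=4, d_model=64)

x = torch.randn(10, 2, 64)  # [seq_len, batch_size, d_model]
out = attn(query=x, key=x, value=x)

# Self-attention preserves the [seq_len, batch_size, d_model] shape
assert out.shape == (10, 2, 64)
```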

Spatial Depth Wise Per Head Convolution

```python
class SpatialDepthWisePerHeadConvolution(nn.Module):
```

  • heads is the number of heads
  • d_k is the number of channels in each head
```python
    def __init__(self, heads: int, d_k: int, kernel_size: int = 3):
```

```python
        super().__init__()
        self.kernel_size = kernel_size
```

We use PyTorch's Conv1d module. We set the number of groups equal to the total number of channels, d_k * heads, so that it does a separate convolution (with a different kernel) for each channel of each head. We add padding to both sides and later crop the rightmost kernel_size - 1 results

```python
        self.conv = nn.Conv1d(in_channels=d_k * heads, out_channels=d_k * heads,
                              kernel_size=(kernel_size,), padding=(kernel_size - 1,),
                              groups=d_k * heads)
```

x has shape [seq_len, batch_size, heads, d_k]

```python
    def forward(self, x: torch.Tensor):
```

Get the shape

```python
        seq_len, batch_size, heads, d_k = x.shape
```

Permute to [batch_size, heads, d_k, seq_len]

```python
        x = x.permute(1, 2, 3, 0)
```

Change the shape to [batch_size, heads * d_k, seq_len]

```python
        x = x.view(batch_size, heads * d_k, seq_len)
```

1D convolution accepts input of the form [N, channels, sequence]

```python
        x = self.conv(x)
```

Crop the rightmost kernel_size - 1 results since we padded both sides

```python
        x = x[:, :, :-(self.kernel_size - 1)]
```

Reshape to [batch_size, heads, d_k, seq_len]

```python
        x = x.view(batch_size, heads, d_k, seq_len)
```

Permute to [seq_len, batch_size, heads, d_k]

```python
        x = x.permute(3, 0, 1, 2)
```

```python
        return x
```
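The difference between the shared and per-head variants shows up directly in their parameter counts; here is a small sketch with assumed sizes (heads=4, d_k=16, kernel_size=3):

```python
shared = SpatialDepthWiseSharedConvolution(kernel_size=3)
per_head = SpatialDepthWisePerHeadConvolution(heads=4, d_k=16, kernel_size=3)

# One shared kernel plus one bias: 3 + 1 = 4 parameters
print(sum(p.numel() for p in shared.parameters()))    # 4

# One kernel and bias per channel of each head: 4 * 16 * (3 + 1) = 256 parameters
print(sum(p.numel() for p in per_head.parameters()))  # 256
```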

Multi-per-Head-Depth-wise-Conv-Head Attention

We extend our original implementation of Multi-Head Attention and add the spatial per-head depth-wise convolution to query, key and value projections.

```python
class MultiDPHConvHeadAttention(MultiHeadAttention):
```

```python
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
        super().__init__(heads, d_model, dropout_prob)
```

Multi-Head Attention will create query, key and value projection modules self.query, self.key, and self.value.

We append a spatial per-head depth-wise convolution layer to each of them and replace self.query, self.key, and self.value.

```python
        self.query = nn.Sequential(self.query, SpatialDepthWisePerHeadConvolution(heads, self.d_k))
        self.key = nn.Sequential(self.key, SpatialDepthWisePerHeadConvolution(heads, self.d_k))
        self.value = nn.Sequential(self.value, SpatialDepthWisePerHeadConvolution(heads, self.d_k))
```
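A minimal forward-pass sketch, mirroring the one above (dimensions are illustrative assumptions):

```python
attn = MultiDPHConvHeadAttention(heads=4, d_model=64)

x = torch.randn(10, 2, 64)  # [seq_len, batch_size, d_model]
out = attn(query=x, key=x, value=x)

assert out.shape == (10, 2, 64)
```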
