docs/transformers/primer_ez/variations.html
We tried some variations to see which changes in Primer EZ have the most benefit.
import torch
from torch import nn

from labml_nn.transformers import MultiHeadAttention
We share the same kernel across all channels.
class SpatialDepthWiseSharedConvolution(nn.Module):

    def __init__(self, kernel_size: int = 3):
        super().__init__()
        self.kernel_size = kernel_size
We use PyTorch's Conv1d module. We add padding of kernel_size - 1 to both sides and later crop the rightmost kernel_size - 1 results.
        self.conv = nn.Conv1d(in_channels=1, out_channels=1,
                              kernel_size=(kernel_size,), padding=(kernel_size - 1,))
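To see why this padding-and-cropping scheme gives a causal convolution, here is a plain-Python sketch of the same arithmetic (no torch; `causal_conv1d` is a hypothetical helper name): pad kernel_size - 1 zeros, convolve, then drop the last kernel_size - 1 outputs, so position t only sees inputs at positions ≤ t.

```python
def causal_conv1d(x, kernel):
    # Mirror what nn.Conv1d does with padding=(kernel_size - 1,):
    # the left padding is what makes the result causal.
    k = len(kernel)
    padded = [0.0] * (k - 1) + list(x) + [0.0] * (k - 1)
    # The padded convolution yields len(x) + kernel_size - 1 positions.
    out = [sum(kernel[i] * padded[t + i] for i in range(k))
           for t in range(len(x) + k - 1)]
    # Crop the rightmost kernel_size - 1 results, restoring the input length.
    return out[:-(k - 1)]

x = [1.0, 2.0, 3.0, 4.0]
# With an all-ones kernel each output is the sum of the current
# and the previous kernel_size - 1 inputs.
y = causal_conv1d(x, [1.0, 1.0, 1.0])  # [1.0, 3.0, 6.0, 9.0]
```

Changing a future input leaves earlier outputs unchanged, which is the property the crop is there to guarantee.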
x has shape [seq_len, batch_size, heads, d_k]
    def forward(self, x: torch.Tensor):
Get the shape
        seq_len, batch_size, heads, d_k = x.shape
Permute to [batch_size, heads, d_k, seq_len]
        x = x.permute(1, 2, 3, 0)
Change the shape to [batch_size * heads * d_k, 1, seq_len], with a single channel per row for the convolution

        x = x.view(batch_size * heads * d_k, 1, seq_len)
1D convolution accepts input of the form [N, channels, sequence]
        x = self.conv(x)
Crop the rightmost kernel_size - 1 results since we padded both sides

        x = x[:, :, :-(self.kernel_size - 1)]
Reshape to [batch_size, heads, d_k, seq_len]
        x = x.view(batch_size, heads, d_k, seq_len)
Permute to [seq_len, batch_size, heads, d_k]
        x = x.permute(3, 0, 1, 2)

        return x
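The permute/view bookkeeping above can be checked with a small plain-Python sketch (hypothetical helper names, no torch): flatten a [seq_len, batch_size, heads, d_k] nested list into one row per (batch, head, channel), then invert the mapping.

```python
import itertools


def to_rows(x, seq_len, batch_size, heads, d_k):
    # Mirrors x.permute(1, 2, 3, 0) followed by
    # x.view(batch_size * heads * d_k, seq_len): one row per
    # (batch, head, channel) holding that channel's sequence.
    return [[x[t][b][h][c] for t in range(seq_len)]
            for b, h, c in itertools.product(
                range(batch_size), range(heads), range(d_k))]


def from_rows(rows, seq_len, batch_size, heads, d_k):
    # Mirrors x.view(batch_size, heads, d_k, seq_len)
    # followed by x.permute(3, 0, 1, 2).
    return [[[[rows[(b * heads + h) * d_k + c][t]
               for c in range(d_k)]
              for h in range(heads)]
             for b in range(batch_size)]
            for t in range(seq_len)]
```

The round trip `from_rows(to_rows(x, ...), ...)` recovers the original tensor, which is exactly what the forward pass relies on when it reshapes back after the convolution.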
We extend our original implementation of Multi-Head Attention and add the spatial depth-wise shared convolution to query, key and value projections.
class MultiDSharedConvHeadAttention(MultiHeadAttention):

    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
        super().__init__(heads, d_model, dropout_prob)
Multi-Head Attention will create query, key and value projection modules self.query, self.key, and self.value.
We append a spatial depth-wise shared convolution layer to each of them and replace self.query, self.key, and self.value.
        self.query = nn.Sequential(self.query, SpatialDepthWiseSharedConvolution())
        self.key = nn.Sequential(self.key, SpatialDepthWiseSharedConvolution())
        self.value = nn.Sequential(self.value, SpatialDepthWiseSharedConvolution())
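nn.Sequential applies its modules in order, so the replaced projections first project and then convolve along the sequence. A minimal plain-Python analogue of that composition (all names here are hypothetical stand-ins, not the real modules):

```python
def sequential(*fns):
    # Apply fns left to right, like nn.Sequential applies its modules.
    def composed(x):
        for fn in fns:
            x = fn(x)
        return x
    return composed


def project(x):
    # Stand-in for the linear projection.
    return [2 * v for v in x]


def convolve(x):
    # Stand-in for the causal convolution: running sum over the prefix.
    return [sum(x[:i + 1]) for i in range(len(x))]


query = sequential(project, convolve)  # projection runs before the convolution
```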
class SpatialDepthWisePerHeadConvolution(nn.Module):
heads is the number of heads
d_k is the number of channels in each head

    def __init__(self, heads: int, d_k: int, kernel_size: int = 3):
        super().__init__()
        self.kernel_size = kernel_size
We use PyTorch's Conv1d module. We set the number of groups equal to the total number of channels across heads (d_k * heads) so that it does a separate convolution (with a different kernel) for each channel of each head. We add padding to both sides and later crop the rightmost kernel_size - 1 results.
        self.conv = nn.Conv1d(in_channels=d_k * heads, out_channels=d_k * heads,
                              kernel_size=(kernel_size,), padding=(kernel_size - 1,),
                              groups=d_k * heads)
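When groups equals the number of channels, each input channel gets its own kernel and contributes only to the matching output channel. A plain-Python sketch of that per-channel, causal behaviour (hypothetical helper name, no torch):

```python
def depthwise_causal_conv(channels, kernels):
    # One kernel per channel; output channel i depends only on input
    # channel i, which is what groups=num_channels gives in nn.Conv1d.
    out = []
    for x, kernel in zip(channels, kernels):
        k = len(kernel)
        # Left-pad with kernel_size - 1 zeros and keep the first len(x)
        # positions: equivalent to padding both sides and cropping the
        # rightmost kernel_size - 1 results.
        padded = [0.0] * (k - 1) + list(x)
        out.append([sum(kernel[i] * padded[t + i] for i in range(k))
                    for t in range(len(x))])
    return out


signals = [[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]]
# First kernel passes the current input through; second delays it by two steps.
kernels = [[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]
```

The two channels come out transformed differently, each by its own kernel, with no cross-channel mixing.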
x has shape [seq_len, batch_size, heads, d_k]
    def forward(self, x: torch.Tensor):
Get the shape
        seq_len, batch_size, heads, d_k = x.shape
Permute to [batch_size, heads, d_k, seq_len]
        x = x.permute(1, 2, 3, 0)
Change the shape to [batch_size, heads * d_k, seq_len]

        x = x.view(batch_size, heads * d_k, seq_len)
1D convolution accepts input of the form [N, channels, sequence]
        x = self.conv(x)
Crop the rightmost kernel_size - 1 results since we padded both sides

        x = x[:, :, :-(self.kernel_size - 1)]
Reshape to [batch_size, heads, d_k, seq_len]
        x = x.view(batch_size, heads, d_k, seq_len)
Permute to [seq_len, batch_size, heads, d_k]
        x = x.permute(3, 0, 1, 2)

        return x
We extend our original implementation of Multi-Head Attention and add the spatial depth-wise per-head convolution to query, key and value projections.
class MultiDPHConvHeadAttention(MultiHeadAttention):

    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
        super().__init__(heads, d_model, dropout_prob)
Multi-Head Attention will create query, key and value projection modules self.query, self.key, and self.value.
We append a spatial per-head depth-wise convolution layer to each of them and replace self.query, self.key, and self.value.
        self.query = nn.Sequential(self.query, SpatialDepthWisePerHeadConvolution(heads, self.d_k))
        self.key = nn.Sequential(self.key, SpatialDepthWisePerHeadConvolution(heads, self.d_k))
        self.value = nn.Sequential(self.value, SpatialDepthWisePerHeadConvolution(heads, self.d_k))