
efficient.py



```python
import math

import torch
from torch import nn

from labml_nn.transformers import MultiHeadAttention
```


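Spatial Depth-Wise Convolution

This version computes the convolution by summing shifted, scaled copies of the input instead of calling Conv1d, but it is actually slower.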

```python
class SpatialDepthWiseConvolution(nn.Module):
```


  • d_k is the number of channels in each head
```python
    def __init__(self, d_k: int, kernel_size: int = 3):
```


```python
        super().__init__()
        self.kernel_size = kernel_size
```


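This implementation does not use PyTorch's Conv1d module. Instead, it stores the kernels directly as a parameter of shape [kernel_size, d_k]. That is one kernel of length kernel_size for each of the d_k channels, so every channel is convolved separately with its own kernel (a depth-wise convolution). The kernels are initialized uniformly in [-1/sqrt(kernel_size), 1/sqrt(kernel_size)]. The forward pass makes the convolution causal by adding shifted copies of the input, so no padding or cropping is needed.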

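```python
        # Uniform initialization bound 1 / sqrt(kernel_size)
        rng = 1 / math.sqrt(kernel_size)
        # One kernel of length kernel_size per channel: shape [kernel_size, d_k]
        self.kernels = nn.Parameter(torch.zeros((kernel_size, d_k)).uniform_(-rng, rng))
```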
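Together with the forward pass that follows, these kernels define a causal convolution over sequence positions t, applied to each channel c separately. Write k for kernel_size and w for self.kernels, so w[0] weights the current position. Terms with t - i < 0 are simply not added, which is equivalent to zero-padding on the left:

$$
y_{t,c} = \sum_{i=0}^{k-1} w_{i,c}\, x_{t-i,c}, \qquad x_{t-i,c} = 0 \ \text{for}\ t - i < 0, \qquad w_{i,c} \sim \mathcal{U}\!\left(-\tfrac{1}{\sqrt{k}},\ \tfrac{1}{\sqrt{k}}\right)
$$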


x has shape [seq_len, batch_size, heads, d_k]

```python
    def forward(self, x: torch.Tensor):
```


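```python
        # Current position: res[t] = x[t] * kernels[0]
        res = x * self.kernels[0].view(1, 1, 1, -1)

        # Earlier positions: res[t] += x[t - i] * kernels[i] for every t >= i
        for i in range(1, len(self.kernels)):
            res[i:] += x[:-i] * self.kernels[i].view(1, 1, 1, -1)

        return res
```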
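One way to check what the loop computes is to compare it against PyTorch's grouped torch.nn.functional.conv1d. The sketch below copies the class above so it runs on its own. The helper conv1d_reference is hypothetical and written only for this check. It folds batch and heads together and flips each kernel, because conv1d computes a cross-correlation. It also pads only on the left, so no position sees the future.

```python
import math

import torch
import torch.nn.functional as F
from torch import nn


class SpatialDepthWiseConvolution(nn.Module):
    # Copy of the shift-and-multiply layer above, so this snippet runs on its own
    def __init__(self, d_k: int, kernel_size: int = 3):
        super().__init__()
        self.kernel_size = kernel_size
        rng = 1 / math.sqrt(kernel_size)
        self.kernels = nn.Parameter(torch.zeros((kernel_size, d_k)).uniform_(-rng, rng))

    def forward(self, x: torch.Tensor):
        # x: [seq_len, batch_size, heads, d_k]
        res = x * self.kernels[0].view(1, 1, 1, -1)
        for i in range(1, len(self.kernels)):
            res[i:] += x[:-i] * self.kernels[i].view(1, 1, 1, -1)
        return res


def conv1d_reference(x: torch.Tensor, kernels: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: the same causal depth-wise convolution via F.conv1d
    seq_len, batch_size, heads, d_k = x.shape
    kernel_size = kernels.shape[0]
    # F.conv1d expects [N, C, L]; fold batch and heads into N, use d_k as channels
    inp = x.permute(1, 2, 3, 0).reshape(batch_size * heads, d_k, seq_len)
    # conv1d is a cross-correlation, so flip each kernel along time.
    # Weight shape for groups=d_k is [d_k, 1, kernel_size]
    weight = kernels.flip(0).t().unsqueeze(1)
    # Left padding only: output position t sees inputs t - kernel_size + 1 ... t
    inp = F.pad(inp, (kernel_size - 1, 0))
    out = F.conv1d(inp, weight, groups=d_k)
    # Back to [seq_len, batch_size, heads, d_k]
    return out.reshape(batch_size, heads, d_k, seq_len).permute(3, 0, 1, 2)


torch.manual_seed(0)
layer = SpatialDepthWiseConvolution(d_k=8, kernel_size=3)
x = torch.randn(10, 2, 4, 8)  # [seq_len, batch_size, heads, d_k]
with torch.no_grad():
    y = layer(x)
    y_ref = conv1d_reference(x, layer.kernels)
print(torch.allclose(y, y_ref, atol=1e-6))  # True
```

The permute and reshape exist only to match conv1d's [N, C, L] layout. The shift-and-multiply loop works directly on the [seq_len, batch_size, heads, d_k] tensor that the attention projections produce.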


Multi-DConv-Head Attention (MDHA)

We extend our original implementation of Multi-Head Attention and add a spatial depth-wise convolution after each of the query, key and value projections.

```python
class MultiDConvHeadAttention(MultiHeadAttention):
```


```python
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
        super().__init__(heads, d_model, dropout_prob)
```


Multi-Head Attention creates the query, key and value projection modules self.query, self.key and self.value.

We append a spatial depth-wise convolution layer to each of them and replace self.query, self.key and self.value with the combined modules.

```python
        self.query = nn.Sequential(self.query, SpatialDepthWiseConvolution(self.d_k))
        self.key = nn.Sequential(self.key, SpatialDepthWiseConvolution(self.d_k))
        self.value = nn.Sequential(self.value, SpatialDepthWiseConvolution(self.d_k))
```
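For readers without labml_nn installed, here is a minimal, self-contained sketch of the same idea. It approximates the class above and is not the labml implementation. HeadProjection and TinyMDHA are hypothetical names. The causal mask, the scaled dot-product scoring and the absence of dropout are assumptions of this sketch. As in MultiDConvHeadAttention, each query, key and value projection is wrapped in nn.Sequential with a SpatialDepthWiseConvolution.

```python
import math

import torch
from torch import nn


class SpatialDepthWiseConvolution(nn.Module):
    # Copy of the shift-and-multiply layer above, so this snippet runs on its own
    def __init__(self, d_k: int, kernel_size: int = 3):
        super().__init__()
        rng = 1 / math.sqrt(kernel_size)
        self.kernels = nn.Parameter(torch.zeros((kernel_size, d_k)).uniform_(-rng, rng))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        res = x * self.kernels[0].view(1, 1, 1, -1)
        for i in range(1, len(self.kernels)):
            res[i:] += x[:-i] * self.kernels[i].view(1, 1, 1, -1)
        return res


class HeadProjection(nn.Module):
    # Hypothetical: [seq_len, batch_size, d_model] -> [seq_len, batch_size, heads, d_k]
    def __init__(self, d_model: int, heads: int, d_k: int):
        super().__init__()
        self.linear = nn.Linear(d_model, heads * d_k)
        self.heads, self.d_k = heads, d_k

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        seq_len, batch_size, _ = x.shape
        return self.linear(x).view(seq_len, batch_size, self.heads, self.d_k)


class TinyMDHA(nn.Module):
    # Hypothetical minimal Multi-DConv-Head Attention (causal mask, no dropout)
    def __init__(self, heads: int, d_model: int):
        super().__init__()
        self.d_k = d_model // heads

        def proj_then_conv() -> nn.Sequential:
            # Per-head projection followed by a depth-wise convolution over time
            return nn.Sequential(HeadProjection(d_model, heads, self.d_k),
                                 SpatialDepthWiseConvolution(self.d_k))

        self.query, self.key, self.value = proj_then_conv(), proj_then_conv(), proj_then_conv()
        self.output = nn.Linear(heads * self.d_k, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        seq_len, batch_size, _ = x.shape
        q, k, v = self.query(x), self.key(x), self.value(x)
        # scores[i, j, b, h]: how much query position i attends to key position j
        scores = torch.einsum("ibhd,jbhd->ijbh", q, k) / math.sqrt(self.d_k)
        # Causal mask: position i may only attend to positions j <= i
        allowed = torch.ones(seq_len, seq_len, dtype=torch.bool).tril()
        scores = scores.masked_fill(~allowed[:, :, None, None], float("-inf"))
        attn = scores.softmax(dim=1)  # normalize over key positions j
        out = torch.einsum("ijbh,jbhd->ibhd", attn, v)
        return self.output(out.reshape(seq_len, batch_size, -1))


torch.manual_seed(0)
mdha = TinyMDHA(heads=4, d_model=32)
x = torch.randn(6, 2, 32)  # [seq_len, batch_size, d_model]
with torch.no_grad():
    y = mdha(x)
print(y.shape)  # torch.Size([6, 2, 32])
```

Wrapping each projection in nn.Sequential leaves the attention computation itself untouched. Only the modules that produce q, k and v differ from plain multi-head attention.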

labml.ai