
Patches Are All You Need? (ConvMixer)


This is a PyTorch implementation of the paper "Patches Are All You Need?"

ConvMixer is similar to MLP-Mixer. MLP-Mixer separates the mixing of spatial and channel dimensions by applying an MLP across the spatial dimension and then an MLP across the channel dimension (the spatial MLP replaces ViT attention, and the channel MLP is the FFN of ViT).

ConvMixer uses a 1×1 convolution for channel mixing and a depth-wise convolution for spatial mixing. Since it uses a convolution instead of a full MLP across the spatial dimension, it mixes only nearby patches, in contrast to ViT or MLP-Mixer. Also, MLP-Mixer uses a two-layer MLP for each mixing step, whereas ConvMixer uses a single layer for each.
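The locality of the two kinds of mixing can be checked on toy sizes. The snippet below is a minimal sketch (`d_model = 4`, a 3×3 kernel and an 8×8 grid of patches are arbitrary choices for illustration): it perturbs either a distant patch or a single channel and checks which outputs change, and it compares the parameter counts of the two convolutions.

```python
import torch
from torch import nn

torch.manual_seed(0)
d_model, kernel_size, grid = 4, 3, 8  # toy sizes, chosen for illustration

# Spatial mixing: one k x k kernel per channel (groups == channels)
depth_wise = nn.Conv2d(d_model, d_model, kernel_size, groups=d_model,
                       padding=(kernel_size - 1) // 2)
# Channel mixing: a 1x1 convolution, i.e. a linear map over channels at each position
point_wise = nn.Conv2d(d_model, d_model, kernel_size=1)

x = torch.randn(1, d_model, grid, grid)  # [batch_size, d_model, height, width]

x_far = x.clone()
x_far[:, :, 7, 7] += 10.0  # perturb one patch far from position (0, 0)

x_chan = x.clone()
x_chan[:, 0] += 10.0  # perturb channel 0 at every position

with torch.no_grad():
    # Depth-wise: the output at (0, 0) only sees its k x k neighborhood...
    dw_local = torch.allclose(depth_wise(x)[..., 0, 0], depth_wise(x_far)[..., 0, 0])
    # ...and each output channel only sees its own input channel
    dw_no_channel_mix = torch.allclose(depth_wise(x)[:, 1:], depth_wise(x_chan)[:, 1:])
    # Point-wise: the output at (0, 0) only sees the patch at (0, 0)...
    pw_local = torch.allclose(point_wise(x)[..., 0, 0], point_wise(x_far)[..., 0, 0])
    # ...but each output channel depends on every input channel
    pw_mixes_channels = not torch.allclose(point_wise(x)[:, 1], point_wise(x_chan)[:, 1])

# Parameter counts: d_model * k * k + d_model versus d_model * d_model + d_model
n_dw = sum(p.numel() for p in depth_wise.parameters())
n_pw = sum(p.numel() for p in point_wise.parameters())
print(dw_local, dw_no_channel_mix, pw_local, pw_mixes_channels)
print(n_dw, n_pw)  # 40 20
```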

The paper recommends removing the residual connection around the channel mixing (point-wise convolution) and keeping only a residual connection around the spatial mixing (depth-wise convolution). It also uses batch normalization instead of layer normalization.

There is also an experiment that trains ConvMixer on CIFAR-10.
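For reference, the same architecture can be written compactly as a single `nn.Sequential`. This is a sketch, not the experiment's code: `Residual`, `conv_mixer` and the hyperparameters (`d_model=64`, two layers, a 5×5 kernel, 2×2 patches, 10 classes for CIFAR-10-sized 32×32 inputs) are illustrative assumptions. It keeps a residual connection only around the spatial mixing and uses batch normalization throughout; `padding="same"` needs PyTorch 1.9 or later.

```python
import torch
from torch import nn


class Residual(nn.Module):
    """Adds the input back to the output of the wrapped module."""

    def __init__(self, fn: nn.Module):
        super().__init__()
        self.fn = fn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fn(x) + x


def conv_mixer(d_model: int, n_layers: int, kernel_size: int = 5,
               patch_size: int = 2, in_channels: int = 3,
               n_classes: int = 10) -> nn.Sequential:
    return nn.Sequential(
        # Patch embeddings: convolution with kernel = stride = patch size, GELU, BatchNorm
        nn.Conv2d(in_channels, d_model, kernel_size=patch_size, stride=patch_size),
        nn.GELU(),
        nn.BatchNorm2d(d_model),
        *[nn.Sequential(
            # Spatial mixing (depth-wise), wrapped in a residual connection
            Residual(nn.Sequential(
                nn.Conv2d(d_model, d_model, kernel_size, groups=d_model, padding="same"),
                nn.GELU(),
                nn.BatchNorm2d(d_model),
            )),
            # Channel mixing (point-wise), without a residual connection
            nn.Conv2d(d_model, d_model, kernel_size=1),
            nn.GELU(),
            nn.BatchNorm2d(d_model),
        ) for _ in range(n_layers)],
        # Classification head: average pool over patches, then a linear layer
        nn.AdaptiveAvgPool2d((1, 1)),
        nn.Flatten(),
        nn.Linear(d_model, n_classes),
    )


model = conv_mixer(d_model=64, n_layers=2)
images = torch.randn(4, 3, 32, 32)  # a CIFAR-10-sized batch of random images
logits = model(images)
print(logits.shape)  # torch.Size([4, 10])
```

The class-based version that follows splits the same computation into reusable, separately documented modules.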

import torch
from torch import nn

from labml_nn.utils import clone_module_list

ConvMixer layer

This is a single ConvMixer layer. The model will have a series of these.

class ConvMixerLayer(nn.Module):

  • d_model is the number of channels in patch embeddings, h
  • kernel_size is the size of the kernel of spatial convolution, k

    def __init__(self, d_model: int, kernel_size: int):

        super().__init__()

Depth-wise convolution is a separate convolution for each channel. We do this with a convolution layer whose number of groups equals the number of channels, so that each channel is its own group.

        self.depth_wise_conv = nn.Conv2d(d_model, d_model,
                                         kernel_size=kernel_size,
                                         groups=d_model,
                                         padding=(kernel_size - 1) // 2)
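A minimal check, with toy sizes chosen for illustration, that a convolution with `groups=d_model` computes the same thing as a separate single-channel convolution per channel. It also shows the effect of `padding=(kernel_size - 1) // 2`: the output keeps the input's height and width for odd kernel sizes, but shrinks by one for even kernel sizes, in which case the residual addition in `forward` would fail.

```python
import torch
from torch import nn

torch.manual_seed(0)
d_model, kernel_size = 3, 5  # toy sizes, chosen for illustration
padding = (kernel_size - 1) // 2

# One group per channel, as in ConvMixerLayer
grouped = nn.Conv2d(d_model, d_model, kernel_size=kernel_size,
                    groups=d_model, padding=padding)
print(grouped.weight.shape)  # torch.Size([3, 1, 5, 5]): one 5x5 kernel per channel

x = torch.randn(2, d_model, 8, 8)

with torch.no_grad():
    y = grouped(x)
    # Reference: a separate single-channel convolution for each channel
    outputs = []
    for c in range(d_model):
        conv_c = nn.Conv2d(1, 1, kernel_size=kernel_size, padding=padding)
        conv_c.weight.copy_(grouped.weight[c:c + 1])
        conv_c.bias.copy_(grouped.bias[c:c + 1])
        outputs.append(conv_c(x[:, c:c + 1]))
    y_ref = torch.cat(outputs, dim=1)

    # With an even kernel, (k - 1) // 2 padding shrinks the output by one
    even = nn.Conv2d(d_model, d_model, kernel_size=4, groups=d_model,
                     padding=(4 - 1) // 2)
    y_even = even(x)

print(y.shape)  # torch.Size([2, 3, 8, 8])
print(torch.allclose(y, y_ref, atol=1e-5))  # True
print(y_even.shape)  # torch.Size([2, 3, 7, 7])
```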

Activation after depth-wise convolution

        self.act1 = nn.GELU()

Normalization after depth-wise convolution

        self.norm1 = nn.BatchNorm2d(d_model)

Point-wise convolution is a 1×1 convolution, i.e. a linear transformation of the patch embeddings.

        self.point_wise_conv = nn.Conv2d(d_model, d_model, kernel_size=1)
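The equivalence between a 1×1 convolution and a linear transformation of each patch embedding can be sketched by copying the convolution's weights into an `nn.Linear` and applying it with channels moved to the last dimension (the sizes are arbitrary):

```python
import torch
from torch import nn

torch.manual_seed(0)
d_model = 6  # arbitrary size for illustration

point_wise = nn.Conv2d(d_model, d_model, kernel_size=1)
# A linear layer with the same weights: [d_model, d_model, 1, 1] -> [d_model, d_model]
linear = nn.Linear(d_model, d_model)
with torch.no_grad():
    linear.weight.copy_(point_wise.weight[:, :, 0, 0])
    linear.bias.copy_(point_wise.bias)

x = torch.randn(2, d_model, 4, 4)  # [batch_size, d_model, height, width]

with torch.no_grad():
    y_conv = point_wise(x)
    # Move channels last, apply the linear layer to every patch embedding, move back
    y_linear = linear(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

print(y_conv.shape)  # torch.Size([2, 6, 4, 4])
print(torch.allclose(y_conv, y_linear, atol=1e-5))  # True
```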

Activation after point-wise convolution

        self.act2 = nn.GELU()

Normalization after point-wise convolution

        self.norm2 = nn.BatchNorm2d(d_model)

    def forward(self, x: torch.Tensor):

For the residual connection around the depth-wise convolution

        residual = x

Depth-wise convolution, activation and normalization

        x = self.depth_wise_conv(x)
        x = self.act1(x)
        x = self.norm1(x)

Add residual connection

        x += residual

Point-wise convolution, activation and normalization

        x = self.point_wise_conv(x)
        x = self.act2(x)
        x = self.norm2(x)

        return x

Get patch embeddings

This splits the image into patches of size p×p and gives an embedding for each patch.

class PatchEmbeddings(nn.Module):

  • d_model is the number of channels in patch embeddings, h
  • patch_size is the size of the patch, p
  • in_channels is the number of channels in the input image (3 for RGB)

    def __init__(self, d_model: int, patch_size: int, in_channels: int):

        super().__init__()

We create a convolution layer with a kernel size and stride length equal to the patch size. This is equivalent to splitting the image into non-overlapping patches and doing a linear transformation on each patch.

        self.conv = nn.Conv2d(in_channels, d_model, kernel_size=patch_size, stride=patch_size)
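The equivalence can be checked with `torch.nn.functional.unfold`, which extracts the non-overlapping p×p patches as flattened vectors; multiplying them by the convolution's reshaped weight gives the same embeddings. A sketch with illustrative sizes (a 32×32 RGB input, `patch_size=4`, `d_model=8`):

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
in_channels, d_model, patch_size = 3, 8, 4  # illustrative sizes

conv = nn.Conv2d(in_channels, d_model, kernel_size=patch_size, stride=patch_size)
x = torch.randn(2, in_channels, 32, 32)  # [batch_size, channels, height, width]

with torch.no_grad():
    y_conv = conv(x)  # one embedding per 4x4 patch

    # Split into non-overlapping patches: [batch_size, in_channels * p * p, n_patches]
    patches = F.unfold(x, kernel_size=patch_size, stride=patch_size)
    # Linear transformation of each flattened patch with the convolution's weights
    weight = conv.weight.reshape(d_model, -1)  # [d_model, in_channels * p * p]
    y_linear = weight @ patches + conv.bias[:, None]  # [batch_size, d_model, n_patches]
    # Arrange the patches back into a grid
    y_linear = y_linear.reshape(2, d_model, 32 // patch_size, 32 // patch_size)

print(y_conv.shape)  # torch.Size([2, 8, 8, 8])
print(torch.allclose(y_conv, y_linear, atol=1e-5))  # True
```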

Activation function

        self.act = nn.GELU()

Batch normalization

        self.norm = nn.BatchNorm2d(d_model)

  • x is the input image of shape [batch_size, channels, height, width]

    def forward(self, x: torch.Tensor):

Apply convolution layer

        x = self.conv(x)

Activation and normalization

        x = self.act(x)
        x = self.norm(x)

        return x

Classification Head

It applies average pooling (taking the mean of all patch embeddings) followed by a final linear transformation that gives the logits of the image classes.

class ClassificationHead(nn.Module):

  • d_model is the number of channels in patch embeddings, h
  • n_classes is the number of classes in the classification task

    def __init__(self, d_model: int, n_classes: int):

        super().__init__()

Average Pool

        self.pool = nn.AdaptiveAvgPool2d((1, 1))

Linear layer

        self.linear = nn.Linear(d_model, n_classes)

    def forward(self, x: torch.Tensor):

Average pooling

        x = self.pool(x)

After pooling, x has shape [batch_size, d_model, 1, 1]; index out the last two dimensions to get the embedding of shape [batch_size, d_model]

        x = x[:, :, 0, 0]
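A small sketch, with arbitrary sizes, of what the pooling and indexing steps compute: `nn.AdaptiveAvgPool2d((1, 1))` followed by `x[:, :, 0, 0]` is the mean of the patch embeddings over the two spatial dimensions.

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(2, 16, 8, 8)  # [batch_size, d_model, height, width], arbitrary sizes

pooled = nn.AdaptiveAvgPool2d((1, 1))(x)  # [2, 16, 1, 1]
embedding = pooled[:, :, 0, 0]  # [2, 16]

print(embedding.shape)  # torch.Size([2, 16])
# Same as averaging the patch embeddings over height and width
print(torch.allclose(embedding, x.mean(dim=(2, 3)), atol=1e-6))  # True
```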

Linear layer

        x = self.linear(x)

        return x

ConvMixer

This combines the patch embeddings block, a number of ConvMixer layers and a classification head.

class ConvMixer(nn.Module):

    def __init__(self, conv_mixer_layer: ConvMixerLayer, n_layers: int,
                 patch_emb: PatchEmbeddings,
                 classification: ClassificationHead):

        super().__init__()

Patch embeddings

        self.patch_emb = patch_emb

Classification head

        self.classification = classification

Make copies of the ConvMixer layer

        self.conv_mixer_layers = clone_module_list(conv_mixer_layer, n_layers)
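The helper `clone_module_list` comes from `labml_nn.utils`. Assuming it returns independent deep copies wrapped in an `nn.ModuleList`, a hypothetical stand-in (`clone_module_list_sketch`) looks like this. The copies do not share parameters, but they all start from the same initial weights as the layer that was passed in.

```python
import copy

import torch
from torch import nn


def clone_module_list_sketch(module: nn.Module, n: int) -> nn.ModuleList:
    # Hypothetical stand-in for labml_nn.utils.clone_module_list:
    # n independent deep copies, registered as submodules
    return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])


layer = nn.Conv2d(8, 8, kernel_size=1)
layers = clone_module_list_sketch(layer, 3)

print(len(layers))  # 3
# Each copy owns its parameters; they start equal but are not shared
print(layers[0].weight is layers[1].weight)  # False
print(torch.equal(layers[0].weight, layers[1].weight))  # True
```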

  • x is the input image of shape [batch_size, channels, height, width]

    def forward(self, x: torch.Tensor):

Get patch embeddings. This gives a tensor of shape [batch_size, d_model, height / patch_size, width / patch_size].

        x = self.patch_emb(x)
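The grid size follows the usual convolution output formula with kernel size and stride both equal to p and no padding, floor((H - p) / p) + 1, which equals H // p. A sketch with illustrative sizes (`n_patches_per_side` is a hypothetical helper); when the image size is not divisible by the patch size, the leftover pixels at the border are dropped:

```python
import torch
from torch import nn


def n_patches_per_side(size: int, patch_size: int) -> int:
    # Conv2d output size with kernel = stride = patch_size and no padding
    return (size - patch_size) // patch_size + 1


print(n_patches_per_side(32, 2))  # 16
print(n_patches_per_side(33, 2))  # 16: the last row/column of pixels is dropped

conv = nn.Conv2d(3, 8, kernel_size=2, stride=2)
out = conv(torch.randn(1, 3, 33, 33))
print(out.shape)  # torch.Size([1, 8, 16, 16])
```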

Pass through ConvMixer layers

        for layer in self.conv_mixer_layers:
            x = layer(x)

Classification head, to get logits

        x = self.classification(x)

        return x
