docs/conv_mixer/index.html
[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/conv_mixer/__init__.py)
This is a PyTorch implementation of the paper *Patches Are All You Need?*.
ConvMixer is similar to MLP-Mixer. MLP-Mixer separates the mixing of spatial and channel dimensions by applying an MLP across the spatial dimension and then an MLP across the channel dimension (the spatial MLP replaces the ViT attention and the channel MLP is the FFN of ViT).
ConvMixer uses a 1×1 convolution for channel mixing and a depth-wise convolution for spatial mixing. Since it is a convolution instead of a full MLP across the space, it mixes only nearby patches, in contrast to ViT or MLP-Mixer. Also, MLP-Mixer uses two-layer MLPs for each mixing step, while ConvMixer uses a single layer for each.
The paper recommends removing the residual connection across the channel mixing (point-wise convolution) and keeping only a residual connection over the spatial mixing (depth-wise convolution). They also use batch normalization instead of layer normalization.
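Putting this together, one ConvMixer layer (as implemented below) updates its input x in two steps, where σ is GELU and BN is batch normalization:

x ← BN(σ(DepthwiseConv(x))) + x

x ← BN(σ(PointwiseConv(x)))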
Here's an experiment that trains ConvMixer on CIFAR-10.
import torch
from torch import nn

from labml_nn.utils import clone_module_list
This is a single ConvMixer layer. The model will have a series of these.
class ConvMixerLayer(nn.Module):
* `d_model` is the number of channels in the patch embeddings, *h*
* `kernel_size` is the size of the kernel of the spatial convolution, *k*

def __init__(self, d_model: int, kernel_size: int):
super().__init__()
A depth-wise convolution is a separate convolution for each channel. We do this with a convolution layer where the number of groups is equal to the number of channels, so that each channel forms its own group.
self.depth_wise_conv = nn.Conv2d(d_model, d_model,
                                 kernel_size=kernel_size,
                                 groups=d_model,
                                 padding=(kernel_size - 1) // 2)
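As a rough illustration (not part of this implementation), setting `groups` equal to the number of channels gives each channel its own k×k filter, so the layer has far fewer weights than a full convolution:

```python
import torch.nn as nn

d_model, kernel_size = 8, 3
# Hypothetical sizes, chosen only for this illustration
depth_wise = nn.Conv2d(d_model, d_model, kernel_size, groups=d_model, padding=1)
full = nn.Conv2d(d_model, d_model, kernel_size, padding=1)

print(depth_wise.weight.shape)  # torch.Size([8, 1, 3, 3]) -> 8 * 3 * 3 = 72 weights
print(full.weight.shape)        # torch.Size([8, 8, 3, 3]) -> 8 * 8 * 3 * 3 = 576 weights
```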
Activation after depth-wise convolution
self.act1 = nn.GELU()
Normalization after depth-wise convolution
self.norm1 = nn.BatchNorm2d(d_model)
A point-wise convolution is a 1×1 convolution, i.e. a linear transformation of the patch embeddings.
self.point_wise_conv = nn.Conv2d(d_model, d_model, kernel_size=1)
Activation after point-wise convolution
self.act2 = nn.GELU()
Normalization after point-wise convolution
self.norm2 = nn.BatchNorm2d(d_model)
def forward(self, x: torch.Tensor):
For the residual connection around the depth-wise convolution
residual = x
Depth-wise convolution, activation and normalization
x = self.depth_wise_conv(x)
x = self.act1(x)
x = self.norm1(x)
Add residual connection
x += residual
Point-wise convolution, activation and normalization
x = self.point_wise_conv(x)
x = self.act2(x)
x = self.norm2(x)
return x
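As a quick sanity check (illustrative only, with arbitrary sizes), a `ConvMixerLayer` preserves the shape of its input, since the padding of `(kernel_size - 1) // 2` keeps the spatial resolution for odd kernel sizes:

```python
import torch

layer = ConvMixerLayer(d_model=64, kernel_size=7)
x = torch.randn(2, 64, 16, 16)    # [batch_size, d_model, height', width']
assert layer(x).shape == x.shape  # same shape in and out
```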
This splits the image into patches of size p×p and gives an embedding for each patch.
class PatchEmbeddings(nn.Module):
* `d_model` is the number of channels in the patch embeddings, *h*
* `patch_size` is the size of the patch, *p*
* `in_channels` is the number of channels in the input image (3 for RGB)

def __init__(self, d_model: int, patch_size: int, in_channels: int):
super().__init__()
We create a convolution layer with a kernel size and stride length equal to the patch size. This is equivalent to splitting the image into patches and doing a linear transformation on each patch (see the sketch after this class).
self.conv = nn.Conv2d(in_channels, d_model, kernel_size=patch_size, stride=patch_size)
Activation function
self.act = nn.GELU()
Batch normalization
self.norm = nn.BatchNorm2d(d_model)
`x` is the input image of shape `[batch_size, channels, height, width]`

def forward(self, x: torch.Tensor):
Apply convolution layer
x = self.conv(x)
Activation and normalization
x = self.act(x)
x = self.norm(x)
return x
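Here is the sketch of the equivalence mentioned above (illustrative only, with hypothetical sizes): the strided convolution gives the same result as cutting the image into non-overlapping p×p patches and applying the convolution weights as a linear map to each flattened patch.

```python
import torch

d_model, patch_size, in_channels = 16, 4, 3
emb = PatchEmbeddings(d_model, patch_size, in_channels)
x = torch.randn(1, in_channels, 32, 32)

# Convolution view (before activation and normalization)
conv_out = emb.conv(x)  # [1, d_model, 8, 8]

# Patch view: split into 4x4 patches, flatten each, apply the conv weights as a linear map
patches = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)  # [1, 3, 8, 8, 4, 4]
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(1, 8, 8, -1)                 # [1, 8, 8, 3*4*4]
weight = emb.conv.weight.reshape(d_model, -1)                                    # [d_model, 3*4*4]
linear_out = (patches @ weight.t() + emb.conv.bias).permute(0, 3, 1, 2)          # [1, d_model, 8, 8]

assert torch.allclose(conv_out, linear_out, atol=1e-5)
```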
They do average pooling (taking the mean of all patch embeddings) and a final linear transformation to predict the log-probabilities of the image classes.
class ClassificationHead(nn.Module):
* `d_model` is the number of channels in the patch embeddings, *h*
* `n_classes` is the number of classes in the classification task

def __init__(self, d_model: int, n_classes: int):
super().__init__()
Average Pool
self.pool = nn.AdaptiveAvgPool2d((1, 1))
Linear layer
self.linear = nn.Linear(d_model, n_classes)
def forward(self, x: torch.Tensor):
Average pooling
x = self.pool(x)
Get the embedding; `x` will have shape `[batch_size, d_model, 1, 1]`
x = x[:, :, 0, 0]
Linear layer
x = self.linear(x)
return x
This combines the patch embeddings block, a number of ConvMixer layers and a classification head.
class ConvMixer(nn.Module):
* `conv_mixer_layer` is a copy of a single ConvMixer layer. We make copies of it to build a ConvMixer with `n_layers` layers.
* `n_layers` is the number of ConvMixer layers (or depth), *d*
* `patch_emb` is the patch embeddings layer
* `classification` is the classification head

def __init__(self, conv_mixer_layer: ConvMixerLayer, n_layers: int,
             patch_emb: PatchEmbeddings,
             classification: ClassificationHead):
super().__init__()
Patch embeddings
self.patch_emb = patch_emb
Classification head
self.classification = classification
Make copies of the ConvMixer layer
self.conv_mixer_layers = clone_module_list(conv_mixer_layer, n_layers)
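`clone_module_list` comes from `labml_nn.utils`; roughly, it builds an `nn.ModuleList` of independent deep copies of the given module, along the lines of this sketch (the actual helper may differ in detail):

```python
import copy
from torch import nn

def clone_module_list(module: nn.Module, n: int) -> nn.ModuleList:
    # n independent copies, each with its own parameters
    return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])
```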
`x` is the input image of shape `[batch_size, channels, height, width]`

def forward(self, x: torch.Tensor):
Get patch embeddings. This gives a tensor of shape `[batch_size, d_model, height / patch_size, width / patch_size]`.
x = self.patch_emb(x)
Pass through ConvMixer layers
for layer in self.conv_mixer_layers:
    x = layer(x)
Classification head, to get logits
x = self.classification(x)
return x
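For reference, the pieces combine roughly like this (the hyperparameters below are illustrative and not necessarily those of the CIFAR-10 experiment linked above):

```python
import torch

# Illustrative hyperparameters, chosen only for this example
d_model, n_layers, kernel_size, patch_size, n_classes = 256, 8, 7, 2, 10

model = ConvMixer(ConvMixerLayer(d_model, kernel_size), n_layers,
                  PatchEmbeddings(d_model, patch_size, in_channels=3),
                  ClassificationHead(d_model, n_classes))

x = torch.randn(4, 3, 32, 32)  # a batch of CIFAR-10 sized images
logits = model(x)              # [4, 10]
```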