[View code on GitHub](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/vit/__init__.py)
This is a PyTorch implementation of the paper An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale.
A vision transformer (ViT) applies a pure transformer to images without any convolution layers. It splits the image into patches and applies a transformer to the patch embeddings. Patch embeddings are generated by applying a simple linear transformation to the flattened pixel values of each patch. A standard transformer encoder is then fed the patch embeddings along with a classification token, [CLS]. The encoding of the [CLS] token is used to classify the image with an MLP.
When feeding the transformer with the patches, learned positional embeddings are added to the patch embeddings, because the patch embeddings carry no information about where each patch came from. The positional embeddings are one vector per patch location, trained with gradient descent along with the other parameters.
ViTs perform well when they are pre-trained on large datasets. The paper suggests pre-training them with an MLP classification head and then using a single linear layer when fine-tuning. The paper reports state-of-the-art results with a ViT pre-trained on a 300-million-image dataset. It also uses higher-resolution images during inference while keeping the patch size the same. The positional embeddings for the new patch locations are calculated by interpolating the learned positional embeddings.
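The interpolation of learned positional embeddings for a higher-resolution input could look roughly like the sketch below. It is not code from this implementation: the function name `interpolate_pos_emb`, the square patch grid, and the choice of bilinear interpolation are all assumptions made for illustration. It also assumes the first embedding belongs to [CLS] and should be kept unchanged.

```python
import torch
import torch.nn.functional as F


def interpolate_pos_emb(pos_emb: torch.Tensor, old_grid: int, new_grid: int) -> torch.Tensor:
    """Resize positional embeddings from an old_grid x old_grid patch layout
    to a new_grid x new_grid layout (hypothetical helper).

    pos_emb has shape [1 + old_grid * old_grid, 1, d_model]; index 0 is [CLS].
    """
    # Keep the [CLS] embedding as it is and resize only the patch embeddings
    cls_pe, patch_pe = pos_emb[:1], pos_emb[1:]
    d_model = patch_pe.shape[-1]
    # [old * old, 1, d_model] -> [1, d_model, old, old], the layout F.interpolate expects
    grid = patch_pe.reshape(old_grid, old_grid, d_model).permute(2, 0, 1).unsqueeze(0)
    # Assumption: bilinear interpolation over the 2D patch grid
    grid = F.interpolate(grid, size=(new_grid, new_grid), mode='bilinear', align_corners=False)
    # [1, d_model, new, new] -> [new * new, 1, d_model]
    patch_pe = grid.squeeze(0).permute(1, 2, 0).reshape(new_grid * new_grid, 1, d_model)
    return torch.cat([cls_pe, patch_pe], dim=0)


# A 4x4 patch grid resized to 8x8, for example after doubling the image resolution
pos_emb = torch.randn(1 + 4 * 4, 1, 8)
resized = interpolate_pos_emb(pos_emb, old_grid=4, new_grid=8)
```

Here `resized` has one embedding for [CLS] followed by 64 embeddings for the larger grid, so it can replace the first 65 rows of a positional embedding table.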
Here's an experiment that trains a ViT on CIFAR-10. It doesn't do very well because it's trained on a small dataset, but it's a simple experiment that anyone can run to play with ViTs.
import torch
from torch import nn

from labml_nn.transformers import TransformerLayer
from labml_nn.utils import clone_module_list
The paper splits the image into patches of equal size and applies a linear transformation to the flattened pixels of each patch.
We implement the same thing through a convolution layer, because it's simpler to implement.
class PatchEmbeddings(nn.Module):
* d_model is the transformer embeddings size
* patch_size is the size of the patch
* in_channels is the number of channels in the input image (3 for RGB)
    def __init__(self, d_model: int, patch_size: int, in_channels: int):
        super().__init__()
We create a convolution layer with a kernel size and stride both equal to the patch size. This is equivalent to splitting the image into patches and applying the same linear transformation to each patch.
        self.conv = nn.Conv2d(in_channels, d_model, patch_size, stride=patch_size)
x is the input image of shape [batch_size, channels, height, width]
    def forward(self, x: torch.Tensor):
Apply convolution layer
        x = self.conv(x)
Get the shape.
        bs, c, h, w = x.shape
Rearrange to shape [patches, batch_size, d_model]
        x = x.permute(2, 3, 0, 1)
        x = x.view(h * w, bs, c)
Return the patch embeddings
        return x
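The claim that a strided convolution matches a linear transformation on flattened patches can be checked numerically. The sketch below is an illustration rather than part of the implementation. The sizes are arbitrary, and it copies the convolution weights into an explicit `nn.Linear` layer and uses `nn.functional.unfold` to extract the patches.

```python
import torch
from torch import nn

torch.manual_seed(0)
batch_size, in_channels, d_model, patch_size = 2, 3, 16, 4
x = torch.randn(batch_size, in_channels, 8, 8)

# Patch embeddings through a strided convolution, as in the layer above
conv = nn.Conv2d(in_channels, d_model, patch_size, stride=patch_size)
conv_out = conv(x)  # [batch_size, d_model, 2, 2]
bs, c, h, w = conv_out.shape
conv_emb = conv_out.permute(2, 3, 0, 1).reshape(h * w, bs, c)

# The same weights viewed as a linear layer over flattened patches
linear = nn.Linear(in_channels * patch_size * patch_size, d_model)
with torch.no_grad():
    linear.weight.copy_(conv.weight.reshape(d_model, -1))
    linear.bias.copy_(conv.bias)

# unfold extracts non-overlapping patches: [batch_size, channels * p * p, patches]
patches = nn.functional.unfold(x, kernel_size=patch_size, stride=patch_size)
# Rearrange to [patches, batch_size, channels * p * p] and apply the linear layer
linear_emb = linear(patches.permute(2, 0, 1))

# The two results should agree up to floating point error
max_diff = (conv_emb - linear_emb).abs().max().item()
```

Both paths produce a tensor of shape [patches, batch_size, d_model]. The convolution version avoids the explicit patch extraction step, which is why it is simpler to write.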
This adds learned positional embeddings to patch embeddings.
class LearnedPositionalEmbeddings(nn.Module):
* d_model is the transformer embeddings size
* max_len is the maximum number of patches
    def __init__(self, d_model: int, max_len: int = 5_000):
        super().__init__()
Positional embeddings for each location
        self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)
x is the patch embeddings of shape [patches, batch_size, d_model]
    def forward(self, x: torch.Tensor):
Get the positional embeddings for the given patches
        pe = self.positional_encodings[:x.shape[0]]
Add to patch embeddings and return
        return x + pe
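To see how the slice and the addition behave, here is a small sketch with plain tensors. The sizes are hypothetical: a 32x32 image with 4x4 patches, plus one extra position for [CLS]. The positional table is filled with the position index instead of learned values, so the effect of the addition is easy to read.

```python
import torch

d_model, max_len = 8, 5_000
# Stand-in for the learned parameter: row i holds the value i in every dimension
positional_encodings = torch.arange(max_len, dtype=torch.float).view(max_len, 1, 1).repeat(1, 1, d_model)

# Hypothetical sizes: (32 // 4) ** 2 = 64 patches, plus one [CLS] token
seq_len, batch_size = (32 // 4) ** 2 + 1, 3
x = torch.zeros(seq_len, batch_size, d_model)

# Take only as many positions as there are tokens: [seq_len, 1, d_model]
pe = positional_encodings[:x.shape[0]]
# The middle dimension of size 1 broadcasts over the batch
out = x + pe
```

Every sample in the batch receives the same embedding at a given position, and positions beyond the sequence length are simply left unused.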
This is the two-layer MLP head that classifies the image based on the [CLS] token embedding.
class ClassificationHead(nn.Module):
* d_model is the transformer embedding size
* n_hidden is the size of the hidden layer
* n_classes is the number of classes in the classification task
    def __init__(self, d_model: int, n_hidden: int, n_classes: int):
        super().__init__()
First layer
        self.linear1 = nn.Linear(d_model, n_hidden)
Activation
        self.act = nn.ReLU()
Second layer
        self.linear2 = nn.Linear(n_hidden, n_classes)
x is the transformer encoding for [CLS] token
    def forward(self, x: torch.Tensor):
First layer and activation
        x = self.act(self.linear1(x))
Second layer
        x = self.linear2(x)
        return x
This combines the patch embeddings, positional embeddings, transformer and the classification head.
class VisionTransformer(nn.Module):
* transformer_layer is a copy of a single transformer layer. We make copies of it to make the transformer with n_layers.
* n_layers is the number of transformer layers.
* patch_emb is the patch embeddings layer.
* pos_emb is the positional embeddings layer.
* classification is the classification head.
    def __init__(self, transformer_layer: TransformerLayer, n_layers: int,
                 patch_emb: PatchEmbeddings, pos_emb: LearnedPositionalEmbeddings,
                 classification: ClassificationHead):
        super().__init__()
Patch embeddings and positional embeddings
        self.patch_emb = patch_emb
        self.pos_emb = pos_emb
Classification head
        self.classification = classification
Make copies of the transformer layer
        self.transformer_layers = clone_module_list(transformer_layer, n_layers)
[CLS] token embedding
        self.cls_token_emb = nn.Parameter(torch.randn(1, 1, transformer_layer.size), requires_grad=True)
Final normalization layer
        self.ln = nn.LayerNorm([transformer_layer.size])
x is the input image of shape [batch_size, channels, height, width]
    def forward(self, x: torch.Tensor):
Get patch embeddings. This gives a tensor of shape [patches, batch_size, d_model]
        x = self.patch_emb(x)
Concatenate the [CLS] token embeddings before feeding the transformer
        cls_token_emb = self.cls_token_emb.expand(-1, x.shape[1], -1)
        x = torch.cat([cls_token_emb, x])
Add positional embeddings
        x = self.pos_emb(x)
Pass through transformer layers with no attention masking
        for layer in self.transformer_layers:
            x = layer(x=x, mask=None)
Get the transformer output of the [CLS] token (which is the first in the sequence).
        x = x[0]
Layer normalization
        x = self.ln(x)
Classification head, to get logits
        x = self.classification(x)
        return x
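To trace the tensor shapes through the whole forward pass, the following self-contained sketch mirrors the steps above in a single module. It is an illustration under stated assumptions, not the implementation itself: `TinyViT` is a hypothetical name, PyTorch's built-in `nn.TransformerEncoderLayer` stands in for the `TransformerLayer` used above, and all sizes are arbitrary small values chosen for CIFAR-10-sized inputs.

```python
import torch
from torch import nn


class TinyViT(nn.Module):
    """Compact stand-in that follows the same steps as VisionTransformer."""

    def __init__(self, d_model: int = 32, patch_size: int = 4, in_channels: int = 3,
                 n_layers: int = 2, n_heads: int = 4, n_hidden: int = 64,
                 n_classes: int = 10, max_len: int = 5_000):
        super().__init__()
        # Patch embeddings through a strided convolution
        self.patch_emb = nn.Conv2d(in_channels, d_model, patch_size, stride=patch_size)
        # Learned positional embeddings and [CLS] token embedding
        self.pos_emb = nn.Parameter(torch.zeros(max_len, 1, d_model))
        self.cls_token_emb = nn.Parameter(torch.randn(1, 1, d_model))
        # Assumption: the built-in encoder layer replaces labml_nn's TransformerLayer
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model, n_heads, dim_feedforward=4 * d_model, dropout=0.0)
            for _ in range(n_layers)
        ])
        self.ln = nn.LayerNorm([d_model])
        # Two-layer MLP classification head
        self.head = nn.Sequential(nn.Linear(d_model, n_hidden), nn.ReLU(),
                                  nn.Linear(n_hidden, n_classes))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.patch_emb(x)                              # [batch, d_model, h, w]
        bs, c, h, w = x.shape
        x = x.permute(2, 3, 0, 1).reshape(h * w, bs, c)    # [patches, batch, d_model]
        cls = self.cls_token_emb.expand(-1, bs, -1)        # [1, batch, d_model]
        x = torch.cat([cls, x])                            # [patches + 1, batch, d_model]
        x = x + self.pos_emb[:x.shape[0]]
        for layer in self.layers:
            x = layer(x)                                   # no attention mask
        return self.head(self.ln(x[0]))                    # logits from the [CLS] output


model = TinyViT()
images = torch.randn(8, 3, 32, 32)  # a batch of CIFAR-10-sized images
logits = model(images)              # [8, 10]
```

With 32x32 inputs and 4x4 patches, the transformer sees 64 patch tokens plus the [CLS] token, and only the [CLS] output reaches the classification head.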