This is a PyTorch implementation of the paper MLP-Mixer: An all-MLP Architecture for Vision.
This paper applies the model to vision tasks. The model is similar to a transformer with the attention layer replaced by an MLP that is applied across the patches (or across the tokens, in the case of an NLP task).
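For comparison with the transformer analogy, here is a minimal sketch of one Mixer layer as the paper describes it: a token-mixing MLP applied across patches, followed by a channel-mixing MLP applied across features, each preceded by LayerNorm and wrapped in a skip connection. The class name `MixerLayer`, the hidden sizes, and the `[batch_size, n_patches, channels]` layout are assumptions for illustration. This is not the code on this page, which only replaces the attention layer of an existing transformer.

```python
import torch
from torch import nn


class MixerLayer(nn.Module):
    """Hypothetical sketch of one Mixer layer; input shape [batch_size, n_patches, channels]."""

    def __init__(self, n_patches: int, channels: int, tokens_hidden: int, channels_hidden: int):
        super().__init__()
        self.norm1 = nn.LayerNorm(channels)
        # Token-mixing MLP: acts across patches, shared over all channels
        self.token_mlp = nn.Sequential(
            nn.Linear(n_patches, tokens_hidden), nn.GELU(), nn.Linear(tokens_hidden, n_patches)
        )
        self.norm2 = nn.LayerNorm(channels)
        # Channel-mixing MLP: acts across channels, shared over all patches
        self.channel_mlp = nn.Sequential(
            nn.Linear(channels, channels_hidden), nn.GELU(), nn.Linear(channels_hidden, channels)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Transpose to [batch_size, channels, n_patches] so the Linear layers mix patches
        y = self.norm1(x).transpose(1, 2)
        x = x + self.token_mlp(y).transpose(1, 2)
        # Channel mixing already works on the last (channel) dimension
        x = x + self.channel_mlp(self.norm2(x))
        return x


layer = MixerLayer(n_patches=16, channels=32, tokens_hidden=64, channels_hidden=128)
out = layer(torch.randn(2, 16, 32))
print(out.shape)  # torch.Size([2, 16, 32])
```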
Our implementation of MLP Mixer is a drop-in replacement for the self-attention layer in our transformer implementation. It takes just a couple of lines of code: the tensor is transposed so that the MLP is applied across the sequence dimension, then transposed back.
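To see why the transpose makes the MLP act across the sequence dimension, the short check below compares the transposed application against an explicit `torch.einsum` that sums over input tokens. It is a sketch with arbitrary sizes, using a single `nn.Linear` as a stand-in for the MLP.

```python
import torch
from torch import nn

torch.manual_seed(0)
seq_len, batch_size, d_model = 5, 2, 3

# Input in the transformer layout [seq_len, batch_size, d_model]
x = torch.randn(seq_len, batch_size, d_model)

# A Linear layer whose input/output size is seq_len mixes tokens
mix = nn.Linear(seq_len, seq_len)

# Transpose so seq_len is last, apply the layer, transpose back
y = mix(x.transpose(0, 2)).transpose(0, 2)

# Same computation written out: output token t is a weighted sum over all input tokens s
expected = torch.einsum('ts,sbd->tbd', mix.weight, x) + mix.bias[:, None, None]

print(y.shape)  # torch.Size([5, 2, 3])
print(torch.allclose(y, expected, atol=1e-5))  # True
```

The same weights are shared by every `(batch, feature)` column, just as a feature-wise MLP shares its weights across tokens.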
Although the paper applied MLP Mixer to vision tasks, we tried it on a masked language model. Here is the experiment code.
```python
from typing import Optional

import torch
from torch import nn
```
This module is a drop-in replacement for the self-attention layer. It transposes the input tensor before feeding it to the MLP and transposes the result back, so that the MLP is applied across the sequence dimension (across tokens or image patches) instead of the feature dimension.
```python
class MLPMixer(nn.Module):
```
`mlp` is the MLP module.

```python
    def __init__(self, mlp: nn.Module):
        super().__init__()
        self.mlp = mlp
```
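Since the wrapped `mlp` is applied to the last dimension after the transpose, its input and output sizes must equal the sequence length rather than `d_model`. A minimal sketch of such a module, assuming a two-layer GELU MLP in the style of the paper's token-mixing MLP (`token_mlp` and `d_hidden` are illustrative names, not from this page):

```python
import torch
from torch import nn

seq_len, batch_size, d_model, d_hidden = 8, 4, 16, 32

# Token-mixing MLP: in/out features are seq_len, not d_model
token_mlp = nn.Sequential(
    nn.Linear(seq_len, d_hidden),
    nn.GELU(),
    nn.Linear(d_hidden, seq_len),
)

x = torch.randn(seq_len, batch_size, d_model)
out = token_mlp(x.transpose(0, 2)).transpose(0, 2)
print(out.shape)  # torch.Size([8, 4, 16])

# An input with a different sequence length does not fit the first Linear layer
try:
    token_mlp(torch.randn(seq_len + 1, batch_size, d_model).transpose(0, 2))
except RuntimeError as err:
    print("shape mismatch:", type(err).__name__)
```

One consequence of this design is that the sequence length is fixed when the module is built; as the last lines show, a longer input no longer fits the first `nn.Linear`.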
The standard attention module can be fed different token embeddings for `query`, `key`, and `value`, along with a `mask`.
We follow the same function signature so that we can replace it directly.
For MLP mixing, `x = query = key = value` and masking is not possible. The shape of `query` (and of `key` and `value`) is `[seq_len, batch_size, d_model]`.
```python
    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: Optional[torch.Tensor] = None):
```
`query`, `key`, and `value` should all be the same tensor.

```python
        assert query is key and key is value
```
MLP Mixer doesn't support masking; that is, every token sees the embeddings of all other tokens.

```python
        assert mask is None
```
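A small sketch that makes the lack of masking concrete, assuming a randomly initialized token-mixing MLP and a hypothetical helper `mix` that performs the same transpose, MLP, transpose steps as `forward`: changing only the last token changes the output at the first position, so future tokens cannot be hidden the way a causal mask would hide them.

```python
import torch
from torch import nn

torch.manual_seed(0)
seq_len, batch_size, d_model = 6, 1, 4

# Token-mixing MLP applied across the sequence dimension
mlp = nn.Sequential(nn.Linear(seq_len, 12), nn.GELU(), nn.Linear(12, seq_len))


def mix(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for MLPMixer.forward: transpose, MLP, transpose back
    return mlp(x.transpose(0, 2)).transpose(0, 2)


x = torch.randn(seq_len, batch_size, d_model)
x_changed = x.clone()
x_changed[-1] += 1.0  # perturb only the last token

with torch.no_grad():
    diff = (mix(x) - mix(x_changed)).abs()

# The first token's output changes even though only the last token was modified
print(diff[0].max().item() > 0)  # True
```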
Assign to `x` for clarity.

```python
        x = query
```
Transpose so that the last dimension is the sequence dimension. The new shape is `[d_model, batch_size, seq_len]`.

```python
        x = x.transpose(0, 2)
```
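A quick sketch, with arbitrary sizes, of what `transpose(0, 2)` does to a tensor in the `[seq_len, batch_size, d_model]` layout. The result is a view with swapped strides rather than a copy; `nn.Linear` accepts such non-contiguous views, so no `.contiguous()` call is needed before the MLP.

```python
import torch

seq_len, batch_size, d_model = 10, 3, 7
x = torch.randn(seq_len, batch_size, d_model)

# Swap dim 0 (seq_len) and dim 2 (d_model); batch_size stays in the middle
x_t = x.transpose(0, 2)
print(x_t.shape)  # torch.Size([7, 3, 10])

# transpose returns a view over the same storage, not a copy
print(x_t.is_contiguous())  # False
print(x_t.data_ptr() == x.data_ptr())  # True

# Transposing the same pair of dimensions again restores the original tensor
print(torch.equal(x_t.transpose(0, 2), x))  # True
```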
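Apply the MLP across tokens. Because it now acts on the last dimension, the MLP's input and output sizes must equal `seq_len`.

```python
        x = self.mlp(x)
```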
Transpose back into the original form, `[seq_len, batch_size, d_model]`.

```python
        x = x.transpose(0, 2)
        return x
```
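To illustrate the drop-in idea, the sketch below plugs the module into a hypothetical pre-norm transformer layer (`ToyTransformerLayer`, with arbitrary sizes) whose attention slot is called as `self_attn(query=z, key=z, value=z, mask=mask)`. This layer is an assumption for illustration, not the repository's transformer implementation, and the `MLPMixer` here is a condensed copy of the module above.

```python
from typing import Optional

import torch
from torch import nn


class MLPMixer(nn.Module):
    # Condensed copy of the module on this page
    def __init__(self, mlp: nn.Module):
        super().__init__()
        self.mlp = mlp

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                mask: Optional[torch.Tensor] = None):
        assert query is key and key is value
        assert mask is None
        return self.mlp(query.transpose(0, 2)).transpose(0, 2)


class ToyTransformerLayer(nn.Module):
    """Hypothetical pre-norm layer that calls its attention slot like self-attention."""

    def __init__(self, d_model: int, self_attn: nn.Module, d_ff: int):
        super().__init__()
        self.self_attn = self_attn
        self.norm_attn = nn.LayerNorm(d_model)
        self.norm_ff = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        z = self.norm_attn(x)
        # The same tensor is passed as query, key, and value
        x = x + self.self_attn(query=z, key=z, value=z, mask=mask)
        x = x + self.ff(self.norm_ff(x))
        return x


seq_len, batch_size, d_model = 12, 2, 32
token_mlp = nn.Sequential(nn.Linear(seq_len, 64), nn.GELU(), nn.Linear(64, seq_len))
layer = ToyTransformerLayer(d_model, MLPMixer(token_mlp), d_ff=128)

x = torch.randn(seq_len, batch_size, d_model)
out = layer(x)
print(out.shape)  # torch.Size([12, 2, 32])

# Passing a mask trips the assertion, since MLP mixing cannot mask tokens
try:
    layer(x, mask=torch.ones(seq_len, seq_len, batch_size, dtype=torch.bool))
except AssertionError:
    print("mask rejected")
```

Because the surrounding layer only relies on the call signature and the output shape, swapping the attention module for `MLPMixer` requires no other change to it.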