FNet: Mixing Tokens with Fourier Transforms

This is a PyTorch implementation of the paper FNet: Mixing Tokens with Fourier Transforms.

This paper replaces the self-attention layer with two Fourier transforms to mix tokens. This is about 7× more efficient than self-attention. In exchange, BERT with Fourier mixing reaches about 92% of the accuracy of standard BERT on the GLUE benchmark.

Mixing tokens with two Fourier transforms

We apply a Fourier transform along the hidden (embedding) dimension and then along the sequence dimension.

$$\mathcal{R}\Big(\mathcal{F}_\text{seq}\big(\mathcal{F}_\text{hidden}(x)\big)\Big)$$

where $x$ is the embedding input, $\mathcal{F}$ denotes the Fourier transform, and $\mathcal{R}$ takes the real part of a complex number.
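Because the discrete Fourier transform is separable, the two 1D transforms are equivalent to a single 2D FFT over the sequence and hidden dimensions. A small check on random input, assuming the `[seq_len, batch_size, d_model]` layout used by the module below:

```python
import torch

torch.manual_seed(0)
# Random input with shape [seq_len, batch_size, d_model]
x = torch.randn(8, 2, 16)

# Two 1D FFTs: first along the hidden dimension (2), then along the sequence dimension (0)
two_step = torch.real(torch.fft.fft(torch.fft.fft(x, dim=2), dim=0))

# One 2D FFT over the same two dimensions
one_step = torch.real(torch.fft.fft2(x, dim=(0, 2)))

print(torch.allclose(two_step, one_step, atol=1e-4))
```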

This is very simple to implement in PyTorch: just one line of code. The paper suggests using a precomputed DFT matrix and matrix multiplication to compute the Fourier transform.
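A minimal sketch of the matrix-multiplication route, assuming real input in the `[seq_len, batch_size, d_model]` layout. Writing the $N \times N$ DFT matrix as $W = C - iS$ with $C_{kj} = \cos(2\pi kj/N)$ and $S_{kj} = \sin(2\pi kj/N)$, the real part needs only real matrices when $x$ is real: $\mathcal{R}\big(\mathcal{F}_\text{seq}(\mathcal{F}_\text{hidden}(x))\big) = C_\text{seq}\, x\, C_\text{hidden} - S_\text{seq}\, x\, S_\text{hidden}$, with the seq matrices acting along the sequence dimension and the hidden matrices along the embedding dimension. The helpers `dft_cos_sin` and `fnet_mix_matmul` are hypothetical, and this real-only split is an assumption of the sketch rather than something taken from the paper.

```python
import math
from typing import Tuple

import torch


def dft_cos_sin(n: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return (C, S) such that the n x n DFT matrix is W = C - iS.

    W[k, j] = exp(-2*pi*i*k*j / n), so C[k, j] = cos(2*pi*k*j / n) and S[k, j] = sin(2*pi*k*j / n).
    """
    idx = torch.arange(n)
    # Reduce k*j modulo n first so the angles stay small and float32 keeps its precision
    angle = 2 * math.pi * (torch.outer(idx, idx) % n).to(torch.float32) / n
    return torch.cos(angle), torch.sin(angle)


def fnet_mix_matmul(x: torch.Tensor,
                    seq_mats: Tuple[torch.Tensor, torch.Tensor],
                    hidden_mats: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
    """Real part of F_seq(F_hidden(x)) for real x of shape [seq_len, batch_size, d_model]."""
    c_seq, s_seq = seq_mats
    c_hid, s_hid = hidden_mats
    # Re((C_seq - i S_seq) x (C_hid - i S_hid)) = C_seq x C_hid - S_seq x S_hid when x is real;
    # the seq matrices act on dim 0, the (symmetric) hidden matrices on dim 2
    cos_term = torch.einsum("st,tbd,de->sbe", c_seq, x, c_hid)
    sin_term = torch.einsum("st,tbd,de->sbe", s_seq, x, s_hid)
    return cos_term - sin_term


torch.manual_seed(0)
seq_len, batch_size, d_model = 8, 2, 16
x = torch.randn(seq_len, batch_size, d_model)

# Precompute the matrices once (for example when the layer is constructed)
seq_mats = dft_cos_sin(seq_len)
hidden_mats = dft_cos_sin(d_model)

via_matmul = fnet_mix_matmul(x, seq_mats, hidden_mats)
via_fft = torch.real(torch.fft.fft(torch.fft.fft(x, dim=2), dim=0))
print(torch.allclose(via_matmul, via_fft, atol=1e-3))
```

The matrix route costs $O(N^2)$ per transform versus $O(N \log N)$ for the FFT, so which one is faster depends on the hardware and the sequence length.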

Here is the training code for using an FNet-based model to classify AG News.

```python
from typing import Optional

import torch
from torch import nn
```

FNet - Mix tokens

This module simply implements $\mathcal{R}\big(\mathcal{F}_\text{seq}(\mathcal{F}_\text{hidden}(x))\big)$.

The module's interface mirrors that of a standard attention module, so it can be swapped in as a drop-in replacement.

```python
class FNetMix(nn.Module):
```

A standard attention module can take different token embeddings for `query`, `key`, and `value`, plus an optional mask.

We follow the same function signature so that we can replace it directly.

For FNet mixing, `query`, `key`, and `value` are all the same tensor $x$, and masking is not possible. The shape of `query` (and of `key` and `value`) is `[seq_len, batch_size, d_model]`.

```python
    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: Optional[torch.Tensor] = None):
```

`query`, `key`, and `value` must all be the same tensor $x$ for token mixing.

```python
        assert query is key and key is value
```
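The assertion compares object identity with `is`, not values, so a tensor with identical contents but a different identity is rejected. A small sketch with a hypothetical helper `check_same` that mirrors the check:

```python
import torch


def check_same(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> bool:
    """Mirror the FNetMix precondition: all three arguments must be the same object."""
    try:
        assert query is key and key is value
        return True
    except AssertionError:
        return False


x = torch.randn(4, 1, 8)
print(check_same(x, x, x))          # True: one object passed three times
print(check_same(x, x.clone(), x))  # False: equal values, but a different object
```

Like any `assert`, this check is skipped when Python runs with `-O`.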

Token mixing doesn't support masking: every token sees the embeddings of all other tokens.

```python
        assert mask is None
```
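Both transforms sum over all positions, so every output position depends on every input token, including later ones, and there is no causal structure for a mask to act on. A sketch that shifts only the last token and checks which output positions change; `fnet_mix` is a hypothetical one-line helper that performs the same computation as `forward`:

```python
import torch


def fnet_mix(x: torch.Tensor) -> torch.Tensor:
    """Token mixing as in FNetMix.forward, for x of shape [seq_len, batch_size, d_model]."""
    return torch.real(torch.fft.fft(torch.fft.fft(x, dim=2), dim=0))


torch.manual_seed(0)
x = torch.randn(6, 1, 4)

# Shift only the last token's embedding by 1.0
x_perturbed = x.clone()
x_perturbed[-1] += 1.0

out = fnet_mix(x)
out_perturbed = fnet_mix(x_perturbed)

# For each sequence position, did any output value change?
changed = (out - out_perturbed).abs().amax(dim=(1, 2)) > 1e-5
print(changed)
```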

Assign to x for clarity

```python
        x = query
```

Apply the Fourier transform along the hidden (embedding) dimension, $\mathcal{F}_\text{hidden}(x)$.

The output of the Fourier transform is a tensor of complex numbers.

```python
        fft_hidden = torch.fft.fft(x, dim=2)
```
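A quick sketch of what this step returns on a hypothetical random input: a `float32` tensor becomes `complex64`, and the shape is unchanged because no output length is passed to `torch.fft.fft`.

```python
import torch

# Random float32 input with shape [seq_len, batch_size, d_model]
x = torch.randn(5, 2, 8)
fft_hidden = torch.fft.fft(x, dim=2)

print(x.dtype, fft_hidden.dtype)    # torch.float32 torch.complex64
print(fft_hidden.shape == x.shape)  # True
```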

Apply the Fourier transform along the sequence dimension, $\mathcal{F}_\text{seq}\big(\mathcal{F}_\text{hidden}(x)\big)$.

```python
        fft_seq = torch.fft.fft(fft_hidden, dim=0)
```

Get the real component, $\mathcal{R}\big(\mathcal{F}_\text{seq}(\mathcal{F}_\text{hidden}(x))\big)$.

```python
        return torch.real(fft_seq)
```
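The fragments above, assembled into a single runnable module, with a hypothetical random input showing the attention-style call in which the same tensor is passed as `query`, `key`, and `value`:

```python
from typing import Optional

import torch
from torch import nn


class FNetMix(nn.Module):
    """FNet token mixing with the same call signature as an attention module."""

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # query, key and value must be the same tensor x of shape [seq_len, batch_size, d_model]
        assert query is key and key is value
        # Token mixing cannot be masked
        assert mask is None
        x = query
        # F_hidden(x): FFT along the embedding dimension
        fft_hidden = torch.fft.fft(x, dim=2)
        # F_seq(F_hidden(x)): FFT along the sequence dimension
        fft_seq = torch.fft.fft(fft_hidden, dim=0)
        # Keep only the real component
        return torch.real(fft_seq)


mix = FNetMix()
x = torch.randn(10, 3, 32)  # [seq_len, batch_size, d_model]
# Called the way a self-attention layer would be called
out = mix(query=x, key=x, value=x)
print(out.shape)  # torch.Size([10, 3, 32])
```

Since the module has no parameters, swapping it in for attention removes the query, key, value, and output projections from that sublayer.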

labml.ai