# k-Nearest Neighbor Language Models
[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/knn/__init__.py)
This is a PyTorch implementation of the paper [Generalization through Memorization: Nearest Neighbor Language Models](https://arxiv.org/abs/1911.00172). It uses k-nearest neighbors to improve the perplexity of autoregressive transformer models.
An autoregressive language model estimates $p(w_t \mid c_t)$, where $w_t$ is the token at step $t$ and $c_t = (w_1, w_2, \dots, w_{t-1})$ is the context.
This paper improves $p(w_t \mid c_t)$ using a $k$-nearest neighbor search on key-value pairs $\big(f(c_i), w_i\big)$, with search key $f(c_t)$. Here $f(c_t)$ is an embedding of the context $c_t$. The paper (and this implementation) uses the input to the feed-forward layer of the final layer of the transformer as $f(c_t)$.
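One convenient way to extract such an embedding without modifying the model is a forward pre-hook on the feed-forward block of the last layer. The sketch below is illustrative, not the repository's actual code: it uses a toy `nn.TransformerEncoder` as a stand-in for the trained autoregressive model, and hooks `linear1`, the first linear of that layer's feed-forward network, whose input is $f(c_t)$.

```python
import torch
import torch.nn as nn

# Toy stand-in for the trained transformer; in practice, hook the real model.
d_model = 16
layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=2)
model = nn.TransformerEncoder(layer, num_layers=2)

captured = {}

def save_ff_input(module, inputs):
    # `inputs[0]` is the tensor entering the feed-forward block,
    # i.e. f(c_t) for every position t in the batch.
    captured['f_c'] = inputs[0].detach()

# Hook the first linear of the final layer's feed-forward network.
handle = model.layers[-1].linear1.register_forward_pre_hook(save_ff_input)

x = torch.randn(10, 1, d_model)   # dummy input: [seq_len, batch, d_model]
with torch.no_grad():
    model(x)                      # ordinary forward pass
keys = captured['f_c']            # f(c_t) embeddings, shape [10, 1, d_model]
handle.remove()
```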
We use [FAISS](https://github.com/facebookresearch/faiss) to index $f(c_i)$.
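As a rough illustration of the indexing step, here is a minimal sketch with dummy data and an exact (flat) L2 index; a real run would fill `keys` and `values` by passing the training set through the model, and might use an approximate index for scale.

```python
import faiss
import numpy as np

d = 16                                   # key dimensionality (d_model)
# Dummy keys f(c_i) and paired next tokens w_i; in practice these come
# from running the trained model over the training data.
keys = np.random.rand(10_000, d).astype('float32')
values = np.random.randint(0, 1_000, size=10_000)   # token ids w_i

index = faiss.IndexFlatL2(d)             # exact L2 search; fine for small data
index.add(keys)

# Retrieve the 10 nearest stored contexts for a query embedding f(c_t)
query = np.random.rand(1, d).astype('float32')
distances, neighbors = index.search(query, 10)
retrieved_tokens = values[neighbors[0]]  # the w_i paired with nearest f(c_i)
```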
So to run kNN-LM we need to:

- Train an autoregressive transformer language model
- Build a FAISS index of the key-value pairs $\big(f(c_i), w_i\big)$ from the training data
- Evaluate perplexity with the model's predictions interpolated with the kNN retrievals (a sketch of this interpolation follows below)
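The last step combines the two distributions as $p(w \mid c) = \lambda \, p_{\text{knn}}(w \mid c) + (1 - \lambda) \, p_{\text{model}}(w \mid c)$, where $p_{\text{knn}}$ is a softmax over negative neighbor distances. The following is a minimal sketch of that interpolation with dummy tensors; the interpolation weight $\lambda$ (`knn_weight` here) is a tuning parameter, not a value taken from this implementation.

```python
import torch

vocab_size, k, knn_weight = 1_000, 10, 0.3

# p_model(w|c): the transformer's predicted distribution (dummy here)
model_probs = torch.softmax(torch.randn(vocab_size), dim=-1)

# Retrieved neighbors: distances to f(c_t) and their paired tokens w_i
distances = torch.rand(k)
neighbor_tokens = torch.randint(0, vocab_size, (k,))

# p_knn(w|c): softmax over negative distances, with the probability
# mass scattered onto the retrieved tokens
knn_scores = torch.softmax(-distances, dim=-1)
knn_probs = torch.zeros(vocab_size).index_add_(0, neighbor_tokens, knn_scores)

# Interpolate: p = λ · p_knn + (1 − λ) · p_model
probs = knn_weight * knn_probs + (1 - knn_weight) * model_probs
```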
This experiment uses a small dataset so that we can run it without using up a few hundred gigabytes of disk space for the index.
The official implementation of kNN-LM can be found [here](https://github.com/urvashik/knnlm).