
[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/knn/__init__.py)

# k-Nearest Neighbor Language Models

This is a PyTorch implementation of the paper [Generalization through Memorization: Nearest Neighbor Language Models](https://arxiv.org/abs/1911.00172). It uses k-nearest neighbors to improve the perplexity of autoregressive transformer models.

An autoregressive language model estimates $p(w_t \mid c_t)$, where $w_t$ is the token at step $t$ and $c_t$ is the context, $c_t = (w_1, w_2, \dots, w_{t-1})$.

This paper improves $p(w_t \mid c_t)$ using a k-nearest-neighbor search over key-value pairs $(f(c_i), w_i)$, with search key $f(c_t)$. Here $f(c_t)$ is an embedding of the context $c_t$. The paper (and this implementation) uses the input to the feed-forward layer of the final transformer layer as $f(c_t)$.
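As a rough illustration of how $f(c_t)$ can be captured, the sketch below registers a forward pre-hook on the feed-forward module of a toy transformer block; the hook saves the tensor entering the feed-forward layer, which plays the role of $f(c_t)$. The `ToyBlock` module is a stand-in for illustration, not this repo's actual model.

```python
import torch
from torch import nn

class ToyBlock(nn.Module):
    """A stand-in transformer block: self-attention followed by a feed-forward layer."""
    def __init__(self, d_model: int):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, num_heads=4)
        self.ff = nn.Sequential(nn.Linear(d_model, 4 * d_model),
                                nn.ReLU(),
                                nn.Linear(4 * d_model, d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(x, x, x)[0]
        return x + self.ff(x)

d_model = 64
block = ToyBlock(d_model)

captured = {}
def save_ff_input(module, inputs):
    # `inputs` is the tuple of positional arguments to `ff`;
    # inputs[0] is the tensor entering the feed-forward layer, i.e. f(c_t)
    captured['f'] = inputs[0].detach()

hook = block.ff.register_forward_pre_hook(save_ff_input)

x = torch.randn(10, 1, d_model)   # [seq_len, batch, d_model]
with torch.no_grad():
    block(x)
keys = captured['f']              # one f(c_t) vector per position t
hook.remove()
```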

We use FAISS to index $f(c_i)$.
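A minimal sketch of the indexing step with FAISS, assuming the context embeddings $f(c_i)$ have been collected into a float32 array `keys` (one row per token) and the corresponding next tokens into `values`; these names, and the exact-search `IndexFlatL2`, are illustrative choices, since at the paper's scale a quantized approximate index would be used instead.

```python
import numpy as np
import faiss

d = 64                                                # dimensionality of f(c_i)
keys = np.random.rand(10_000, d).astype('float32')    # f(c_i) for every token (dummy data here)
values = np.random.randint(0, 30_000, size=10_000)    # the next token w_i for each key

index = faiss.IndexFlatL2(d)   # exact L2 search over the stored keys
index.add(keys)                # store all f(c_i)

query = np.random.rand(1, d).astype('float32')        # f(c_t) at test time
distances, neighbors = index.search(query, 10)        # 10 nearest (f(c_i), w_i) pairs
neighbor_tokens = values[neighbors[0]]                # the tokens w_i of those neighbors
```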

## Implementation

So to run kNN-LM we need to:

* Train a transformer model
* Build an index of $f(c_i)$
* Evaluate kNN-LM

At evaluation time, the retrieved neighbors define a distribution $p_{\text{kNN}}(w_t \mid c_t)$ that is interpolated with the model's own distribution, as sketched below.
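The paper forms $p_{\text{kNN}}$ by softmaxing negative neighbor distances and mixes it with the model's distribution, $p(w_t \mid c_t) = \lambda \, p_{\text{kNN}}(w_t \mid c_t) + (1 - \lambda) \, p_{\text{LM}}(w_t \mid c_t)$, where $\lambda$ is a tuned hyperparameter. The sketch below illustrates this step; it is not this repo's exact code, and the default $\lambda = 0.25$ is assumed here for illustration.

```python
import torch

def knn_lm_probs(p_lm: torch.Tensor,             # [vocab_size] model probabilities p_LM
                 distances: torch.Tensor,        # [k] L2 distances to the neighbors
                 neighbor_tokens: torch.Tensor,  # [k] token ids w_i of the neighbors
                 lam: float = 0.25) -> torch.Tensor:
    # p_kNN(w) is proportional to the sum of exp(-distance) over neighbors with w_i == w
    weights = torch.softmax(-distances, dim=0)
    p_knn = torch.zeros_like(p_lm)
    p_knn.index_add_(0, neighbor_tokens, weights)
    # Interpolate: p(w|c) = lam * p_kNN(w|c) + (1 - lam) * p_LM(w|c)
    return lam * p_knn + (1.0 - lam) * p_lm

# Usage with dummy inputs:
vocab_size, k = 30_000, 10
p_lm = torch.softmax(torch.randn(vocab_size), dim=0)
distances = torch.rand(k)
neighbor_tokens = torch.randint(0, vocab_size, (k,))
p = knn_lm_probs(p_lm, distances, neighbor_tokens)   # still sums to 1
```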

This experiment uses a small dataset so that we can run it without using up a few hundred gigabytes of disk space for the index.

The official implementation of kNN-LM can be found [here](https://github.com/urvashik/knnlm).
