Back to Annotated Deep Learning Paper Implementations

Database for nearest neighbor retrieval

docs/transformers/retro/database.html

latest5.9 KB
Original Source

hometransformersretro

View code on Github

#

Database for nearest neighbor retrieval

This builds the database and retrieves nearest neighbors for the RETRO model.

We use FAISS library for the database whilst the paper had used the SCaNN library.

16fromtypingimportList,Optional1718importfaiss19importnumpyasnp20importtorch2122fromlabmlimportlab,monit23fromlabml\_nn.helpers.datasetsimportTextFileDataset24fromlabml\_nn.transformers.retro.bert\_embeddingsimportBERTChunkEmbeddings

#

Build Database

  • chunk_len is the length of a chunk (number of characters)
  • batch_size is the batch size to use when calculating BERT(N)
  • d_emb is the number of features in BERT(N) embeddings
  • n_centeroids is the number of lists in the index
  • code_size is the encoded vector size in the index
  • n_probe is the number of lists to probe
  • n_train is the number of keys to train the index on
27defbuild\_database(chunk\_len:int=16,batch\_size:int=64,d\_emb:int=768,n\_centeroids:int=256,28code\_size:int=64,n\_probe:int=8,n\_train:int=50\_000):

#

Load the dataset text file

43dataset=TextFileDataset(44lab.get\_data\_path()/'tiny\_shakespeare.txt',45list,46url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')

#

Get training data (a string)

49text=dataset.train

#

Split the text into chunks of chunk_length

52chunks=[text[i:i+chunk\_len]foriinrange(0,len(text),chunk\_len)ifi+chunk\_len\*2\<len(text)]

#

Get the offsets of each of the chunks

54chunk\_offsets=np.array([iforiinrange(0,len(text),chunk\_len)ifi+chunk\_len\*2\<len(text)])

#

Number of chunks

56n\_chunks=len(chunks)

#

Initialize BERT to get BERT(N)

59bert=BERTChunkEmbeddings(torch.device('cuda:0'))

#

Get chunk embeddings by processing batch_size number of chunks on each iteration

62chunk\_emb=[]63foriinmonit.iterate('Get embeddings',range(0,n\_chunks,batch\_size)):64chunk\_emb.append(bert(chunks[i:i+batch\_size]).cpu())

#

Merge them into a single tensor

66chunk\_emb=torch.cat(chunk\_emb,dim=0).numpy()

#

Create the FAISS index

69quantizer=faiss.IndexFlatL2(d\_emb)70index=faiss.IndexIVFPQ(quantizer,d\_emb,n\_centeroids,code\_size,8)71index.nprobe=n\_probe

#

Get a random sample of the the chunk indexes

74random\_sample=np.random.choice(np.arange(n\_chunks),size=[min(n\_train,n\_chunks)],replace=False)

#

Train the index to store the keys

77withmonit.section('Train index'):78index.train(chunk\_emb[random\_sample])

#

Add the chunks to the index in batches of size 1024

81forsinmonit.iterate('Index',range(0,n\_chunks,1024)):82e=min(s+1024,n\_chunks)

#

Add to index

84index.add\_with\_ids(chunk\_emb[s:e],chunk\_offsets[s:e])

#

Save the index

87withmonit.section('Save'):88faiss.write\_index(index,str(lab.get\_data\_path()/'retro.index'))

#

Index for retrieving nearest neighbors

class RetroIndex:
    """
    ## Index for retrieving nearest neighbors

    * `chunk_len` is the chunk length
    * `n_probe` is the number of lists to probe
    * `n_neighbors` is the number of neighbors to retrieve
    * `n_extra` is the number of extra neighbors to retrieve, since we
      remove neighbors that overlap with the query chunk
    * `exclude_neighbor_span` is the extra text length to avoid when
      checking for overlaps
    """

    def __init__(self, chunk_len: int = 16, n_probe: int = 8,
                 n_neighbors: int = 2, n_extra: int = 2,
                 exclude_neighbor_span: int = 8):
        self.n_neighbors = n_neighbors
        self.chunk_len = chunk_len
        self.exclude_neighbor_span = exclude_neighbor_span
        self.n_extra = n_extra

        # Initialize BERT to get BERT(N); fall back to CPU when CUDA is
        # unavailable instead of crashing on CPU-only machines.
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.bert = BERTChunkEmbeddings(device)

        # Load the saved FAISS database
        with monit.section('Load index'):
            self.index = faiss.read_index(str(lab.get_data_path() / 'retro.index'))
            self.index.nprobe = n_probe

    def filter_neighbors(self, offset: int, neighbor_offsets: List[int]) -> List[int]:
        """
        Filter out neighbors that overlap with the query chunk.

        The neighbor positions are given by `neighbor_offsets` and the
        position of the query chunk is `offset`. A neighbor is kept only
        when it is more than `chunk_len + exclude_neighbor_span`
        characters away from the query on either side.
        """
        # Invariant exclusion span, hoisted out of the comprehension
        span = self.chunk_len + self.exclude_neighbor_span
        return [n for n in neighbor_offsets if n < offset - span or n > offset + span]

    def __call__(self, query_chunks: List[str], offsets: Optional[List[int]]):
        """
        ### Retrieve nearest neighbors

        * `query_chunks` are the chunks to retrieve neighbors for
        * `offsets` are the positions of the query chunks in the text;
          when given, neighbors overlapping a query chunk are filtered out
        """
        # Get BERT(N) of query chunks
        emb = self.bert(query_chunks).cpu()

        # Get `n_neighbors + n_extra` nearest neighbors from the database
        distance, neighbor_offsets = self.index.search(emb.numpy(), self.n_neighbors + self.n_extra)

        # If the query chunk offsets are given, filter out overlapping chunks
        if offsets is not None:
            neighbor_offsets = [self.filter_neighbors(off, n_off)
                                for off, n_off in zip(offsets, neighbor_offsets)]

        # Keep the closest `n_neighbors` after filtering
        neighbor_offsets = [n_off[:self.n_neighbors] for n_off in neighbor_offsets]

        return neighbor_offsets

#

# Build the database when this module is run as a script.
if __name__ == '__main__':
    build_database()

labml.ai