This code builds the database and retrieves the nearest neighbors for the RETRO model.
We use the FAISS library for the database, whereas the paper used the SCaNN library.
from typing import List, Optional

import faiss
import numpy as np
import torch

from labml import lab, monit
from labml_nn.helpers.datasets import TextFileDataset
from labml_nn.transformers.retro.bert_embeddings import BERTChunkEmbeddings
- chunk_len is the length of a chunk (number of characters)
- batch_size is the batch size to use when calculating BERT(N)
- d_emb is the number of features in the BERT(N) embeddings
- n_centeroids is the number of lists in the FAISS index
- code_size is the encoded vector size in the index
- n_probe is the number of lists to probe
- n_train is the number of keys to train the index on

def build_database(chunk_len: int = 16, batch_size: int = 64, d_emb: int = 768, n_centeroids: int = 256,
                   code_size: int = 64, n_probe: int = 8, n_train: int = 50_000):
Load the dataset text file
    dataset = TextFileDataset(
        lab.get_data_path() / 'tiny_shakespeare.txt',
        list,
        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
Get training data (a string)
    text = dataset.train
Split the text into chunks of chunk_len characters
    chunks = [text[i:i + chunk_len] for i in range(0, len(text), chunk_len) if i + chunk_len * 2 < len(text)]
Get the offsets of each of the chunks
    chunk_offsets = np.array([i for i in range(0, len(text), chunk_len) if i + chunk_len * 2 < len(text)])
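As a quick illustration (a hypothetical toy string, not part of the build script), the i + chunk_len * 2 < len(text) guard keeps only positions that still have a full chunk_len of text after the chunk, so each stored chunk has a continuation available:

chunk_len = 4
text = 'to be or not to be, that'

# Same list comprehensions as above, on the toy string
chunks = [text[i:i + chunk_len] for i in range(0, len(text), chunk_len) if i + chunk_len * 2 < len(text)]
chunk_offsets = [i for i in range(0, len(text), chunk_len) if i + chunk_len * 2 < len(text)]

print(chunks)         # ['to b', 'e or', ' not', ' to ']
print(chunk_offsets)  # [0, 4, 8, 12]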
Number of chunks
    n_chunks = len(chunks)
Initialize BERT to get BERT(N)
    bert = BERTChunkEmbeddings(torch.device('cuda:0'))
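As used below, BERTChunkEmbeddings takes a list of chunk strings and returns one embedding per chunk. A hypothetical shape check, assuming the default d_emb of 768:

# Hypothetical: two 16-character chunks in, a [2, 768] tensor out
emb = bert(['First Citizen:\nB', 'efore we proceed'])
print(emb.shape)  # torch.Size([2, 768]), one embedding per chunk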
Get chunk embeddings by processing batch_size number of chunks on each iteration
    chunk_emb = []
    for i in monit.iterate('Get embeddings', range(0, n_chunks, batch_size)):
        chunk_emb.append(bert(chunks[i: i + batch_size]).cpu())
Merge them into a single tensor
    chunk_emb = torch.cat(chunk_emb, dim=0).numpy()
Create the FAISS index
    quantizer = faiss.IndexFlatL2(d_emb)
    index = faiss.IndexIVFPQ(quantizer, d_emb, n_centeroids, code_size, 8)
    index.nprobe = n_probe
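For intuition, here is a minimal standalone sketch of the same index construction with random vectors and made-up sizes. The IndexFlatL2 quantizer assigns each vector to one of n_centeroids inverted lists, product quantization compresses each vector to code_size bytes (8 bits per sub-quantizer), and nprobe sets how many lists are scanned per query:

import faiss
import numpy as np

d_emb, n_centeroids, code_size, n_probe = 64, 16, 8, 4

# Random vectors standing in for the chunk embeddings
xs = np.random.random((10_000, d_emb)).astype('float32')

quantizer = faiss.IndexFlatL2(d_emb)
index = faiss.IndexIVFPQ(quantizer, d_emb, n_centeroids, code_size, 8)
index.train(xs)        # learn the coarse centroids and PQ codebooks
index.add(xs)
index.nprobe = n_probe # lists scanned per query

distances, ids = index.search(xs[:3], 2)  # both arrays have shape (3, 2)
print(ids)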
Get a random sample of the chunk indexes
    random_sample = np.random.choice(np.arange(n_chunks), size=[min(n_train, n_chunks)], replace=False)
Train the index to store the keys
    with monit.section('Train index'):
        index.train(chunk_emb[random_sample])
Add the chunks to the index in batches of size 1024
    for s in monit.iterate('Index', range(0, n_chunks, 1024)):
        e = min(s + 1024, n_chunks)
Add to index
        index.add_with_ids(chunk_emb[s:e], chunk_offsets[s:e])
Save the index
    with monit.section('Save'):
        faiss.write_index(index, str(lab.get_data_path() / 'retro.index'))
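Because the chunks were added with add_with_ids, the ids returned by a search are the chunk offsets, i.e. character positions in the training text. A minimal sketch of loading the saved index back and querying it, with a random vector standing in for a BERT embedding:

import faiss
import numpy as np
from labml import lab

index = faiss.read_index(str(lab.get_data_path() / 'retro.index'))
index.nprobe = 8

# Hypothetical query embedding with the index's dimensionality
query = np.random.random((1, index.d)).astype('float32')
distances, neighbor_offsets = index.search(query, 2)

# Each id is the starting offset of a neighbor chunk in the training text,
# so text[n : n + chunk_len] recovers the neighbor's characters.
print(neighbor_offsets)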
Index for retrieving nearest neighbors from the database

class RetroIndex:
- chunk_len is the chunk length
- n_probe is the number of lists to probe
- n_neighbors is the number of neighbors to retrieve
- n_extra is the number of extra neighbors to retrieve, since we will be removing neighbors that overlap with the query chunk
- exclude_neighbor_span is the extra text length to avoid when checking for overlaps

    def __init__(self, chunk_len: int = 16, n_probe: int = 8,
                 n_neighbors: int = 2, n_extra: int = 2,
                 exclude_neighbor_span: int = 8):
        self.n_neighbors = n_neighbors
        self.chunk_len = chunk_len
        self.exclude_neighbor_span = exclude_neighbor_span
        self.n_extra = n_extra
Initialize BERT to get BERT(N)
        self.bert = BERTChunkEmbeddings(torch.device('cuda:0'))
Load the database
        with monit.section('Load index'):
            self.index = faiss.read_index(str(lab.get_data_path() / 'retro.index'))
            self.index.nprobe = n_probe
Filter out neighbors that overlap with the query chunk. The positions of the neighbors are given by neighbor_offsets and the position of the query chunk is offset.
    def filter_neighbors(self, offset: int, neighbor_offsets: List[int]):
        return [n for n in neighbor_offsets
                if n < offset - (self.chunk_len + self.exclude_neighbor_span)
                or n > offset + (self.chunk_len + self.exclude_neighbor_span)]
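A standalone rerun of the same predicate with hypothetical numbers: with the defaults chunk_len = 16 and exclude_neighbor_span = 8, a neighbor survives only if it starts more than 24 characters before or after the query chunk's offset:

chunk_len, exclude_neighbor_span = 16, 8
offset, neighbor_offsets = 100, [60, 80, 100, 120, 140]

# Exclusion window is [offset - 24, offset + 24] = [76, 124]
kept = [n for n in neighbor_offsets
        if n < offset - (chunk_len + exclude_neighbor_span)
        or n > offset + (chunk_len + exclude_neighbor_span)]
print(kept)  # [60, 140]; 80, 100 and 120 overlap the query's span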
Retrieve the nearest neighbors of query_chunks. offsets are the positions of the query chunks in the text; they are optional and used only to filter out overlapping neighbors.

    def __call__(self, query_chunks: List[str], offsets: Optional[List[int]]):
Get BERT(N) of query chunks
        emb = self.bert(query_chunks).cpu()
Get n_neighbors + n_extra nearest neighbors from the database
        distance, neighbor_offsets = self.index.search(emb.numpy(), self.n_neighbors + self.n_extra)
If the query chunk offsets are given, filter out overlapping chunks
        if offsets is not None:
            neighbor_offsets = [self.filter_neighbors(off, n_off)
                                for off, n_off in zip(offsets, neighbor_offsets)]
Get the closest n_neighbors after filtering
        neighbor_offsets = [n_off[:self.n_neighbors] for n_off in neighbor_offsets]
        return neighbor_offsets
if __name__ == '__main__':
    build_database()
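Finally, a hypothetical usage sketch of the retrieval index, assuming build_database() has been run so that retro.index exists and a CUDA device is available for BERT:

retro_index = RetroIndex()

# Two 16-character query chunks and their offsets in the source text
query_chunks = ['First Citizen:\nB', 'efore we proceed']
neighbors = retro_index(query_chunks, offsets=[0, 16])

# One list per query chunk, each with up to n_neighbors chunk offsets into
# the training text; neighbors overlapping the query chunks were filtered out.
print(neighbors)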