Build FAISS index for k-NN search

We want to build an index of (f(c_i), w_i) pairs. We store f(c_i) and w_i in memory-mapped numpy arrays, and we find the f(c_i) nearest to f(c_t) using FAISS. FAISS indexes (f(c_i), i) pairs, and we query it with f(c_t).
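Conceptually, the lookup that FAISS accelerates is a plain nearest-neighbour search over the stored keys. A brute-force numpy sketch (toy data; the `knn_lookup` helper is hypothetical, not part of this codebase):

```python
import numpy as np

def knn_lookup(keys: np.ndarray, vals: np.ndarray, query: np.ndarray, k: int):
    """Return the values w_i whose keys f(c_i) are nearest to the query f(c_t)."""
    # Squared L2 distance from the query to every stored key
    dists = ((keys - query) ** 2).sum(axis=-1)
    # Indices of the k nearest keys
    nearest = np.argsort(dists)[:k]
    return vals[nearest]

keys = np.array([[0., 0.], [1., 0.], [10., 10.]])  # toy f(c_i) vectors
vals = np.array([7, 8, 9])                         # toy target tokens w_i
print(knn_lookup(keys, vals, np.array([0.9, 0.1]), k=2))  # -> [8 7]
```

FAISS does the same job, but with compressed vectors and sub-linear search instead of this O(n) scan.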

from typing import Optional

import faiss
import numpy as np
import torch

from labml import experiment, monit, lab
from labml.utils.pytorch import get_modules
from labml_nn.transformers.knn.train_model import Configs

Load a saved experiment from training the model.

def load_experiment(run_uuid: str, checkpoint: Optional[int] = None):

Create the configurations object

    conf = Configs()

Load the custom configurations used in the experiment

    conf_dict = experiment.load_configs(run_uuid)

We need the inputs to the feed-forward layer, f(c_i)

    conf_dict['is_save_ff_input'] = True

This experiment is just an evaluation; i.e. nothing is tracked or saved

    experiment.evaluate()

Initialize configurations

    experiment.configs(conf, conf_dict)

Set models for saving/loading

    experiment.add_pytorch_models(get_modules(conf))

Specify the experiment to load from

    experiment.load(run_uuid, checkpoint)

Start the experiment; this is when it actually loads the models

    experiment.start()

    return conf

Gather (f(c_i), w_i) pairs and save them in numpy arrays

Note that these numpy arrays will take up a lot of space (even a few hundred gigabytes), depending on the size of your dataset.
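A back-of-the-envelope estimate of that footprint (the token count and d_model below are illustrative assumptions, not values from the experiment):

```python
# Rough disk footprint of the two memory-mapped arrays.
n_tokens = 100_000_000   # assumed number of tokens in the training set
d_model = 1024           # assumed dimensionality of f(c_i)
keys_bytes = n_tokens * d_model * 4   # float32 keys
vals_bytes = n_tokens * 8             # int64 values
print(f"keys: {keys_bytes / 1e9:.0f} GB, vals: {vals_bytes / 1e9:.1f} GB")
# -> keys: 410 GB, vals: 0.8 GB
```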

def gather_keys(conf: Configs):

Dimensionality of f(c_i)

    d_model = conf.transformer.d_model

Training data loader

    data_loader = conf.trainer.data_loader

Number of contexts; i.e. the number of tokens in the training data minus one: (f(c_i), w_i) for i ∈ [2, T]

    n_keys = data_loader.data.shape[0] * data_loader.data.shape[1] - 1
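For example, with a hypothetical training data tensor of shape (100, 32):

```python
import numpy as np

# data has shape (sequence length, batch size); every token except the
# last one contributes one (f(c_i), w_i) pair
data = np.zeros((100, 32))  # hypothetical stand-in for data_loader.data
n_keys = data.shape[0] * data.shape[1] - 1
print(n_keys)  # -> 3199
```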

Memory-mapped numpy array for f(c_i)

    keys_store = np.memmap(str(lab.get_data_path() / 'keys.npy'), dtype=np.float32, mode='w+', shape=(n_keys, d_model))

Memory-mapped numpy array for w_i

    vals_store = np.memmap(str(lab.get_data_path() / 'vals.npy'), dtype=np.int64, mode='w+', shape=(n_keys, 1))

Number of keys f(c_i) collected so far

    added = 0
    with torch.no_grad():

Loop through the data

        for i, batch in monit.enum("Collect data", data_loader, is_children_silent=True):

w_i, the target labels

            vals = batch[1].view(-1, 1)

Input data, moved to the device of the model

            data = batch[0].to(conf.device)

Run the model

            _ = conf.model(data)

Get f(c_i)

            keys = conf.model.ff_input.view(-1, d_model)

Save the keys, f(c_i), in the memory-mapped numpy array

            keys_store[added: added + keys.shape[0]] = keys.cpu()

Save the values, w_i, in the memory-mapped numpy array

            vals_store[added: added + keys.shape[0]] = vals

Increment the number of collected keys

            added += keys.shape[0]

Build the FAISS index

The FAISS tutorials on getting started, faster search, and lower memory footprint will help you learn more about using FAISS.

def build_index(conf: Configs, n_centeroids: int = 2048, code_size: int = 64, n_probe: int = 8, n_train: int = 200_000):

Dimensionality of f(c_i)

    d_model = conf.transformer.d_model

Training data loader

    data_loader = conf.trainer.data_loader

Number of contexts; i.e. the number of tokens in the training data minus one: (f(c_i), w_i) for i ∈ [2, T]

    n_keys = data_loader.data.shape[0] * data_loader.data.shape[1] - 1

Build an index with Voronoi-cell-based faster search and compression, so that full vectors don't have to be stored.

    quantizer = faiss.IndexFlatL2(d_model)
    index = faiss.IndexIVFPQ(quantizer, d_model, n_centeroids, code_size, 8)
    index.nprobe = n_probe

Load the memory-mapped numpy array of keys

    keys_store = np.memmap(str(lab.get_data_path() / 'keys.npy'), dtype=np.float32, mode='r', shape=(n_keys, d_model))

Pick a random sample of keys to train the index with

    random_sample = np.random.choice(np.arange(n_keys), size=[min(n_train, n_keys)], replace=False)

    with monit.section('Train index'):

Train the index to store the keys

        index.train(keys_store[random_sample])

Add the keys to the index in chunks; (f(c_i), i)

    for s in monit.iterate('Index', range(0, n_keys, 1024)):
        e = min(s + 1024, n_keys)

f(c_i)

        keys = keys_store[s:e]

i

        idx = np.arange(s, e)

Add to the index

        index.add_with_ids(keys, idx)

    with monit.section('Save'):

Save the index

        faiss.write_index(index, str(lab.get_data_path() / 'faiss.index'))

def main():

Load the experiment. Replace the run UUID with your run UUID from training the model.

    conf = load_experiment('4984b85c20bf11eb877a69c1a03717cd')

Set the model to evaluation mode

    conf.model.eval()

Collect (f(c_i), w_i) pairs

    gather_keys(conf)

Add them to the index for fast search

    build_index(conf)


if __name__ == '__main__':
    main()
