We want to build an index of (f(ci), wi) pairs. We store f(ci) and wi in memory-mapped numpy arrays, and find the f(ci) nearest to f(ct) using FAISS. FAISS indexes (f(ci), i) pairs, and we query it with f(ct).
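As a minimal sketch of the datastore idea (plain NumPy exact search standing in for FAISS; the arrays and numbers are illustrative, not from this experiment):

```python
import numpy as np

# Toy datastore: keys are f(c_i) context vectors, values are the next tokens w_i
d_model = 4
keys = np.array([[1., 0., 0., 0.],
                 [0., 1., 0., 0.],
                 [0., 0., 1., 0.]], dtype=np.float32)  # f(c_i)
vals = np.array([10, 20, 30])                          # w_i

# Query with f(c_t): find the nearest key by L2 distance,
# then read off the stored next token w_i
query = np.array([0.9, 0.1, 0., 0.], dtype=np.float32)
dists = np.linalg.norm(keys - query, axis=1)
nearest = int(np.argmin(dists))
predicted_token = vals[nearest]  # -> 10
```

FAISS does the same nearest-neighbor lookup, but scales it to hundreds of millions of keys.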
from typing import Optional

import faiss
import numpy as np
import torch

from labml import experiment, monit, lab
from labml.utils.pytorch import get_modules
from labml_nn.transformers.knn.train_model import Configs
Load a saved experiment from the model training run.
def load_experiment(run_uuid: str, checkpoint: Optional[int] = None):
Create configurations object
conf = Configs()
Load custom configurations used in the experiment
conf_dict = experiment.load_configs(run_uuid)
We need to get inputs to the feed forward layer, f(ci)
conf_dict['is_save_ff_input'] = True
This experiment is just an evaluation; i.e. nothing is tracked or saved
experiment.evaluate()
Initialize configurations
experiment.configs(conf, conf_dict)
Set models for saving/loading
experiment.add_pytorch_models(get_modules(conf))
Specify the experiment to load from
experiment.load(run_uuid, checkpoint)
Start the experiment; this is when it actually loads models
experiment.start()

return conf
Note that these numpy arrays will take up a lot of space (even a few hundred gigabytes) depending on the size of your dataset.
def gather_keys(conf: Configs):
Dimensions of f(ci)
d_model = conf.transformer.d_model
Training data loader
data_loader = conf.trainer.data_loader
Number of contexts; i.e. the number of tokens in the training data minus one: (f(ci), wi) for i ∈ [2, T]
n_keys = data_loader.data.shape[0] * data_loader.data.shape[1] - 1
Numpy array for f(ci)
keys_store = np.memmap(str(lab.get_data_path() / 'keys.npy'), dtype=np.float32, mode='w+', shape=(n_keys, d_model))
Numpy array for wi
vals_store = np.memmap(str(lab.get_data_path() / 'vals.npy'), dtype=np.int64, mode='w+', shape=(n_keys, 1))
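To see why the key store gets so large, here is the arithmetic with illustrative sizes (roughly 100 million training tokens and d_model = 1024; not the actual numbers of this experiment):

```python
# Rough disk footprint of the float32 key store (illustrative sizes)
n_keys = 100_000_000   # ~1e8 training tokens (assumed for illustration)
d_model = 1024         # key dimensionality (assumed for illustration)

# n_keys x d_model float32 values, 4 bytes each
keys_bytes = n_keys * d_model * 4
keys_gb = keys_bytes / 1e9  # ~410 GB for the keys alone
```

This is why the arrays are memory-mapped to disk rather than held in RAM.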
Number of keys f(ci) collected
added = 0
with torch.no_grad():
Loop through data
for i, batch in monit.enum("Collect data", data_loader, is_children_silent=True):
wi the target labels
vals = batch[1].view(-1, 1)
Input data moved to the device of the model
data = batch[0].to(conf.device)
Run the model
_ = conf.model(data)
Get f(ci)
keys = conf.model.ff_input.view(-1, d_model)
Save keys, f(ci) in the memory mapped numpy array
keys_store[added: added + keys.shape[0]] = keys.cpu()
Save values, wi in the memory mapped numpy array
vals_store[added: added + keys.shape[0]] = vals
Increment the number of collected keys
added += keys.shape[0]
The FAISS tutorials on getting started, faster search, and lower memory footprint will help you learn more about FAISS usage.
def build_index(conf: Configs, n_centeroids: int = 2048, code_size: int = 64, n_probe: int = 8, n_train: int = 200_000):
Dimensions of f(ci)
d_model = conf.transformer.d_model
Training data loader
data_loader = conf.trainer.data_loader
Number of contexts; i.e. the number of tokens in the training data minus one: (f(ci), wi) for i ∈ [2, T]
n_keys = data_loader.data.shape[0] * data_loader.data.shape[1] - 1
Build an index with Voronoi cell based faster search and product quantization compression that doesn't store full vectors.
quantizer = faiss.IndexFlatL2(d_model)
index = faiss.IndexIVFPQ(quantizer, d_model, n_centeroids, code_size, 8)
index.nprobe = n_probe
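For intuition on the compression, here is the arithmetic with the defaults above and an assumed d_model of 1024 (illustrative only): IVFPQ stores a code_size-byte code per vector instead of the full float32 vector.

```python
d_model = 1024    # assumed model dimensionality (illustrative)
code_size = 64    # PQ code bytes per vector, matching the default above

full_bytes = d_model * 4              # full float32 key: 4096 bytes
compression = full_bytes / code_size  # 64x smaller per stored vector
```

The tradeoff is that distances are computed on the compressed codes, so the search is approximate.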
Load the memory mapped numpy array of keys
keys_store = np.memmap(str(lab.get_data_path() / 'keys.npy'), dtype=np.float32, mode='r', shape=(n_keys, d_model))
Pick a random sample of keys to train the index with
random_sample = np.random.choice(np.arange(n_keys), size=[min(n_train, n_keys)], replace=False)

with monit.section('Train index'):
Train the index to store the keys
index.train(keys_store[random_sample])
Add keys to the index; (f(ci),i)
for s in monit.iterate('Index', range(0, n_keys, 1024)):
    e = min(s + 1024, n_keys)
f(ci)
keys = keys_store[s:e]
i
idx = np.arange(s, e)
Add to index
index.add_with_ids(keys, idx)

with monit.section('Save'):
Save the index
faiss.write_index(index, str(lab.get_data_path() / 'faiss.index'))
def main():
Load the experiment. Replace the run UUID with your run UUID from training the model.
conf = load_experiment('4984b85c20bf11eb877a69c1a03717cd')
Set model to evaluation mode
conf.model.eval()
Collect (f(ci),wi)
gather_keys(conf)
Add them to the index for fast search
build_index(conf)


if __name__ == '__main__':
    main()
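The index built here is queried during evaluation, where the retrieved neighbors are turned into a next-token distribution. A sketch of that kNN-LM idea with illustrative numbers (this is not code from this repo):

```python
import numpy as np

# k retrieved neighbors: their distances to f(c_t) and stored tokens w_i
# (illustrative numbers)
dists = np.array([0.5, 1.0, 2.0])
neighbor_tokens = np.array([7, 7, 3])

# Softmax over negative distances weights each neighbor;
# summing weights per token gives a next-token distribution
weights = np.exp(-dists)
weights /= weights.sum()
p_token_7 = float(weights[neighbor_tokens == 7].sum())
p_token_3 = float(weights[neighbor_tokens == 3].sum())
```

Closer neighbors get more weight, and repeated tokens among the neighbors accumulate probability mass.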