docs/transformers/knn/eval_knn.html
from typing import Optional, List

import faiss
import numpy as np
import torch

from labml import monit, lab
from labml.logger import inspect
from labml_nn.transformers.knn.train_model import Configs
Here we refer to f(ct) as queries, f(ci) as keys and wi as values.
def knn(queries: torch.Tensor, index: faiss.IndexFlatL2, keys_store: np.ndarray, vals_store: np.ndarray,
        n_tokens: int):
Save shape of queries to reshape results
    queries_shape = queries.shape
Flatten the batch and sequence dimensions of queries
    queries = queries.view(-1, queries_shape[-1])
Find the 10 nearest neighbors of f(ct) among f(ci). distance is the distance given by FAISS, and idx is the index of each neighbor in keys_store .
    distance, idx = index.search(queries.numpy(), 10)
Get f(ci)
    keys_found = queries.new_tensor(keys_store[idx])
Get wi
    vals_found = torch.tensor(vals_store[idx]).squeeze(-1)
We are going to calculate the cosine similarity between normalized vectors
Normalize f(ci)
    keys_found_n = keys_found / torch.sqrt((keys_found ** 2).sum(-1, keepdims=True) + 1e-10)
Normalize f(ct)
    queries_n = queries / torch.sqrt((queries ** 2).sum(-1, keepdims=True) + 1e-10)
Get the dot-product, or cosine similarity
    dot_prod = (keys_found_n * queries_n.unsqueeze(1)).sum(-1)
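Normalizing both vectors to unit length and taking the dot product is exactly the cosine similarity. A minimal NumPy sketch (with hypothetical vectors) checking the equivalence against the textbook definition:

```python
import numpy as np

# Two hypothetical feature vectors
a = np.array([3.0, 4.0])
b = np.array([1.0, 2.0])

# Normalize each vector to unit length (the epsilon guards against zero vectors)
a_n = a / np.sqrt((a ** 2).sum() + 1e-10)
b_n = b / np.sqrt((b ** 2).sum() + 1e-10)

# Dot product of unit vectors is the cosine of the angle between them
cos_sim = (a_n * b_n).sum()

# Same value from the direct definition: a·b / (|a| |b|)
expected = a.dot(b) / (np.linalg.norm(a) * np.linalg.norm(b))
assert np.isclose(cos_sim, expected)
```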
Token-wise logits
    logits_token = dot_prod.new_zeros(queries.shape[0], n_tokens)
Scatter and accumulate token logits based on the nearest neighbors
    _ = logits_token.scatter_(dim=1, index=vals_found, src=dot_prod, reduce='add')
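The scatter with reduce='add' sums the similarities of all neighbors that map to the same token, so a token retrieved several times accumulates a larger logit. A small NumPy sketch of the same accumulation (toy numbers, using np.add.at in place of scatter_):

```python
import numpy as np

n_queries, n_tokens = 2, 5
# Hypothetical cosine similarities for 3 neighbors per query
dot_prod = np.array([[0.9, 0.8, 0.7],
                     [0.6, 0.5, 0.4]])
# Token ids of those neighbors; query 0 retrieved token 2 twice
vals_found = np.array([[2, 2, 4],
                       [0, 1, 1]])

logits_token = np.zeros((n_queries, n_tokens))
# Row index for each neighbor, so (rows, vals_found) addresses one logit slot each
rows = np.repeat(np.arange(n_queries)[:, None], vals_found.shape[1], axis=1)
# Accumulate similarities into each neighbor's token slot, summing duplicates
np.add.at(logits_token, (rows, vals_found), dot_prod)

assert np.allclose(logits_token[0], [0.0, 0.0, 1.7, 0.0, 0.7])
assert np.allclose(logits_token[1], [0.6, 0.9, 0.0, 0.0, 0.0])
```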
Reshape the logits
    logits_token = logits_token.reshape(queries_shape[0], queries_shape[1], -1)

    return logits_token
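faiss.IndexFlatL2.search does an exhaustive scan over all stored keys using squared L2 distance. A minimal NumPy equivalent (toy data, not the FAISS implementation) showing the (distance, idx) shapes the function above relies on:

```python
import numpy as np

def flat_l2_search(queries: np.ndarray, keys: np.ndarray, k: int):
    """Exhaustive k-nearest-neighbor search by squared L2 distance,
    mimicking what faiss.IndexFlatL2.search returns."""
    # (n_queries, n_keys) matrix of squared distances
    d2 = ((queries[:, None, :] - keys[None, :, :]) ** 2).sum(-1)
    # Indices of the k closest keys per query, nearest first
    idx = np.argsort(d2, axis=1)[:, :k]
    distance = np.take_along_axis(d2, idx, axis=1)
    return distance, idx

keys = np.array([[0.0, 0.0], [1.0, 0.0], [5.0, 5.0]])
queries = np.array([[0.9, 0.1]])
distance, idx = flat_l2_search(queries, keys, k=2)

assert idx.tolist() == [[1, 0]]          # key 1 is nearest
assert np.isclose(distance[0, 0], 0.02)  # (0.1)^2 + (0.1)^2
```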
We calculate the validation loss of the combined k-NN and transformer predictions. The weight given to the k-NN model is knn_weights . It's a list of weights, and we calculate the validation loss for each.
def validation_loss(knn_weights: List[float], last_n: Optional[int], conf: Configs, index: faiss.IndexFlatL2,
                    keys_store: np.ndarray, vals_store: np.ndarray):
List of losses for each weight in knn_weights
    losses = [[] for _ in knn_weights]
Number of samples in each batch
    n_samples = []
    with torch.no_grad():
Iterate through validation data
        for i, batch in monit.enum("Validation", conf.validator.data_loader, is_children_silent=True):
Get data and target labels
            data, target = batch[0].to(conf.device), batch[1].to(conf.device)
Run the model and get predictions p(wt,ct)
            res = conf.model(data)
Get k-NN predictions
            res_knn = knn(conf.model.ff_input.cpu(), index, keys_store, vals_store, conf.n_tokens)
            res_knn = res_knn.to(conf.device)
This calculates the loss only for the last last_n tokens. This is important because the first predictions (along the sequence) of the transformer model have very few past tokens to look at.
            if last_n:
                res = res[-last_n:]
                res_knn = res_knn[-last_n:]
                target = target[-last_n:]
Number of samples
            n_s = res.shape[0] * data.shape[1]
            n_samples.append(n_s)
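The logits here are laid out [seq, batch, vocab], so res[-last_n:] keeps the last positions along the sequence, and res.shape[0] * data.shape[1] counts the predictions actually evaluated. A toy NumPy shape check (dimensions hypothetical):

```python
import numpy as np

seq_len, batch_size, n_tokens = 6, 2, 4
# Hypothetical logits laid out [seq, batch, vocab] and seq-first targets
res = np.zeros((seq_len, batch_size, n_tokens))
target = np.zeros((seq_len, batch_size), dtype=np.int64)

# Keep only the last `last_n` positions along the sequence dimension
last_n = 4
res = res[-last_n:]
target = target[-last_n:]

# Number of predictions evaluated in this batch: positions kept x batch size
n_s = res.shape[0] * batch_size
assert res.shape == (4, 2, 4)
assert n_s == 8
```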
Calculate the loss for each weight in knn_weights .
            for i, c in enumerate(knn_weights):
Calculate the loss
                loss = conf.loss_func(res_knn * c + (1 - c) * res, target)
                losses[i].append(loss * n_s)

    return losses, n_samples
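Note that this combines the two predictors' logits linearly before the loss. A small NumPy sketch (toy logits, with a hypothetical cross_entropy helper) of the sweep, showing that c = 0 recovers the pure transformer loss and c = 1 the pure k-NN loss:

```python
import numpy as np

def cross_entropy(logits: np.ndarray, target: np.ndarray) -> float:
    """Mean negative log-likelihood with a softmax over the last axis."""
    logits = logits - logits.max(-1, keepdims=True)
    log_probs = logits - np.log(np.exp(logits).sum(-1, keepdims=True))
    return -log_probs[np.arange(len(target)), target].mean()

# Hypothetical logits from the two predictors and the true tokens
res = np.array([[2.0, 0.0, 0.0], [0.0, 2.0, 0.0]])
res_knn = np.array([[0.0, 0.0, 2.0], [0.0, 2.0, 0.0]])
target = np.array([0, 1])

# Sweep the interpolation weight, as the loop above does
losses = {c: cross_entropy(res_knn * c + (1 - c) * res, target)
          for c in [0.0, 0.5, 1.0]}

# c = 0 is the pure transformer loss, c = 1 the pure k-NN loss
assert np.isclose(losses[0.0], cross_entropy(res, target))
assert np.isclose(losses[1.0], cross_entropy(res_knn, target))
```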
def load_index(conf: Configs, n_probe: int = 8):
Dimensions of f(ci)
    d_model = conf.transformer.d_model
Training data loader
    data_loader = conf.trainer.data_loader
Number of contexts; i.e. number of tokens in the training data minus one. (f(ci),wi) for i∈[2,T]
    n_keys = data_loader.data.shape[0] * data_loader.data.shape[1] - 1
Load FAISS index
    with monit.section('Load index'):
        index = faiss.read_index(str(lab.get_data_path() / 'faiss.index'))
Set number of cells to probe
    index.nprobe = n_probe
Load memory mapped numpy arrays
    keys_store = np.memmap(str(lab.get_data_path() / 'keys.npy'), dtype=np.float32, mode='r', shape=(n_keys, d_model))
    # np.int was removed in recent NumPy versions; int64 matches the saved token ids
    vals_store = np.memmap(str(lab.get_data_path() / 'vals.npy'), dtype=np.int64, mode='r', shape=(n_keys, 1))

    return index, keys_store, vals_store
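np.memmap maps the file into memory lazily, so the full key store never has to fit in RAM; the dtype and shape must match what was written when the index was built. A self-contained round-trip sketch (file name and shapes illustrative):

```python
import os
import tempfile

import numpy as np

# Write a small array to disk, then map it back read-only without
# loading it all into memory
path = os.path.join(tempfile.mkdtemp(), 'keys.npy')
n_keys, d_model = 4, 3

writer = np.memmap(path, dtype=np.float32, mode='w+', shape=(n_keys, d_model))
writer[:] = np.arange(n_keys * d_model, dtype=np.float32).reshape(n_keys, d_model)
writer.flush()

# Reopen with the same dtype and shape; rows are read from disk on access
reader = np.memmap(path, dtype=np.float32, mode='r', shape=(n_keys, d_model))
assert reader.shape == (n_keys, d_model)
assert float(reader[2, 1]) == 7.0
```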
def main():
    from labml_nn.transformers.knn.build_index import load_experiment
Load the experiment. Replace the run uuid with your run uuid from training the model.
    conf = load_experiment('4984b85c20bf11eb877a69c1a03717cd')
Set model to evaluation mode
    conf.model.eval()
Load index
    index, keys_store, vals_store = load_index(conf)
List of weights given to k-NN prediction. We will evaluate the validation loss for each of the weights
    knn_weights = [i / 20 for i in range(10)]
Evaluate validation loss
    losses, n_samples = validation_loss(knn_weights, None, conf, index, keys_store, vals_store)
Output the losses for each of knn_weights .
    inspect({c: np.sum(losses[i]) / np.sum(n_samples) for i, c in enumerate(knn_weights)})


if __name__ == '__main__':
    main()