
Evaluate k-nearest neighbor language model


```python
from typing import Optional, List

import faiss
import numpy as np
import torch

from labml import monit, lab
from labml.logger import inspect
from labml_nn.transformers.knn.train_model import Configs
```

#

$k$-NN to get $p(w_t, c_t)$

Here we refer to $f(c_t)$ as queries, $f(c_i)$ as keys and $w_i$ as values.

```python
def knn(queries: torch.Tensor, index: faiss.IndexFlatL2, keys_store: np.ndarray, vals_store: np.ndarray, n_tokens: int):
```

#

Save shape of queries to reshape results

```python
    queries_shape = queries.shape
```

#

Flatten the batch and sequence dimensions of queries

```python
    queries = queries.view(-1, queries_shape[-1])
```

#

Find the 10 nearest neighbors of $f(c_t)$ among $f(c_i)$. `distance` is the distance given by FAISS and `idx`, $i$, is the index of each neighbor in `keys_store`.

```python
    distance, idx = index.search(queries.numpy(), 10)
```
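`index.search` returns two `[n_queries, 10]` arrays: the distances and the row numbers of the nearest stored keys. As a mental model only, an exact search in the style of FAISS's `IndexFlatL2` (which reports squared L2 distances) can be sketched in NumPy by brute force. The function name `brute_force_search` and the toy data are hypothetical, and an approximate index, such as one tuned with `nprobe`, may return slightly different neighbors.

```python
import numpy as np


def brute_force_search(queries: np.ndarray, keys: np.ndarray, k: int):
    """Exact k-NN by squared L2 distance; returns (distances, indices), each [n_queries, k]."""
    # Squared distance between every query and every key: [n_queries, n_keys]
    d2 = ((queries[:, None, :] - keys[None, :, :]) ** 2).sum(-1)
    # Indices of the k smallest distances, nearest first
    idx = np.argsort(d2, axis=1, kind="stable")[:, :k]
    distances = np.take_along_axis(d2, idx, axis=1)
    return distances, idx


keys = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0], [5.0, 5.0]], dtype=np.float32)
queries = np.array([[0.9, 0.1]], dtype=np.float32)
distances, idx = brute_force_search(queries, keys, k=2)
print(idx.tolist())  # [[1, 0]]
```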

#

Get $f(c_i)$

```python
    keys_found = queries.new_tensor(keys_store[idx])
```

#

Get $w_i$

```python
    vals_found = torch.tensor(vals_store[idx]).squeeze(-1)
```
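Indexing an array, including a read-only `np.memmap`, with the `[n_queries, k]` integer array `idx` gathers one stored row per neighbor. The toy stores below are made up but shaped like the real ones; they show why the values need `.squeeze(-1)`: the values store keeps a trailing dimension of size one.

```python
import numpy as np

n_keys, d_model = 6, 4  # hypothetical sizes
# Toy stores shaped like the real ones: keys [n_keys, d_model], values [n_keys, 1]
keys_store = np.arange(n_keys * d_model, dtype=np.float32).reshape(n_keys, d_model)
vals_store = np.array([[7], [3], [7], [1], [0], [3]], dtype=np.int64)

# Neighbor indices for 2 queries with k = 3, as a k-NN search would return them
idx = np.array([[0, 2, 5], [1, 3, 4]])

keys_found = keys_store[idx]              # [2, 3, d_model]
vals_found = vals_store[idx].squeeze(-1)  # [2, 3, 1] -> [2, 3]
print(keys_found.shape, vals_found.shape)  # (2, 3, 4) (2, 3)
print(vals_found.tolist())                 # [[7, 7, 3], [3, 1, 0]]
```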

#

We are going to calculate the cosine similarity between normalized vectors

#

Normalize $f(c_i)$

```python
    keys_found_n = keys_found / torch.sqrt((keys_found ** 2).sum(-1, keepdims=True) + 1e-10)
```

#

Normalize $f(c_t)$

```python
    queries_n = queries / torch.sqrt((queries ** 2).sum(-1, keepdims=True) + 1e-10)
```

#

Get the dot-product, or cosine similarity

```python
    dot_prod = (keys_found_n * queries_n.unsqueeze(1)).sum(-1)
```
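Dividing each vector by its L2 norm, with the small `1e-10` guarding against division by zero, and then taking a dot product gives the cosine similarity. The NumPy sketch below mirrors the two normalization steps and the broadcasted product: `[:, None, :]` plays the role of `unsqueeze(1)`, so each query is compared only with its own `k` neighbors. The helper name `l2_normalize` and the vectors are illustrative.

```python
import numpy as np


def l2_normalize(x: np.ndarray, eps: float = 1e-10) -> np.ndarray:
    # Same formula as above: x / sqrt(sum(x^2) + eps) along the last axis
    return x / np.sqrt((x ** 2).sum(-1, keepdims=True) + eps)


# One query [n_queries, d] and its k = 3 retrieved keys [n_queries, k, d]
queries = np.array([[1.0, 0.0]])
keys_found = np.array([[[2.0, 0.0], [0.0, 3.0], [-1.0, 0.0]]])

# Broadcast each query against its own neighbors, then sum over d
dot_prod = (l2_normalize(keys_found) * l2_normalize(queries)[:, None, :]).sum(-1)
print(np.round(dot_prod, 6).tolist())  # [[1.0, 0.0, -1.0]]
```

Because of the epsilon, a zero vector normalizes to zeros rather than `NaN`, so a degenerate key simply contributes a similarity of 0.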

#

Token-wise logits

```python
    logits_token = dot_prod.new_zeros(queries.shape[0], n_tokens)
```

#

Scatter and accumulate token logits based on the nearest neighbors

```python
    _ = logits_token.scatter_(dim=1, index=vals_found, src=dot_prod, reduce='add')
```
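`scatter_` with `reduce='add'` adds each neighbor's similarity into the column of its token id, so a token retrieved several times accumulates all of its similarities; `Tensor.scatter_add_` performs the same accumulation. A NumPy equivalent uses `np.add.at`, which, unlike `a[indices] += b`, accumulates elements indexed more than once. The vocabulary size and numbers are made up for illustration.

```python
import numpy as np

n_tokens = 5  # hypothetical vocabulary size
# Token ids w_i and cosine similarities of k = 3 neighbors for 2 queries
vals_found = np.array([[4, 1, 4], [0, 2, 3]])
dot_prod = np.array([[0.75, 0.5, 0.25], [0.125, 0.375, 0.625]])

logits_token = np.zeros((vals_found.shape[0], n_tokens))
# Row index of every (query, neighbor) pair, broadcast to the shape of vals_found
rows = np.broadcast_to(np.arange(vals_found.shape[0])[:, None], vals_found.shape)
# Unbuffered add: repeated (row, token) pairs are summed, not overwritten
np.add.at(logits_token, (rows, vals_found), dot_prod)
print(logits_token.tolist())
# [[0.0, 0.5, 0.0, 0.0, 1.0], [0.125, 0.0, 0.375, 0.625, 0.0]]
```

In the first row, token 4 was retrieved twice, so its logit is `0.75 + 0.25 = 1.0`.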

#

Reshape the logits

```python
    logits_token = logits_token.reshape(queries_shape[0], queries_shape[1], -1)

    return logits_token
```

#

Calculate validation loss

We calculate the validation loss of the combined $k$-NN and transformer prediction. The weight given to the $k$-NN model is given by `knn_weights`. It's a list of weights, and we calculate the validation loss for each weight $c$, using the combined logits `res_knn * c + (1 - c) * res`.

```python
def validation_loss(knn_weights: List[float], last_n: Optional[int], conf: Configs, index: faiss.IndexFlatL2,
                    keys_store: np.ndarray, vals_store: np.ndarray):
```

#

List of losses for each knn_weights

```python
    losses = [[] for _ in knn_weights]
```

#

Number of samples in each batch

```python
    n_samples = []
    with torch.no_grad():
```

#

Iterate through validation data

```python
        for i, batch in monit.enum("Validation", conf.validator.data_loader, is_children_silent=True):
```

#

Get data and target labels

```python
            data, target = batch[0].to(conf.device), batch[1].to(conf.device)
```

#

Run the model and get predictions $p(w_t, c_t)$

```python
            res = conf.model(data)
```

#

Get k-NN predictions

```python
            res_knn = knn(conf.model.ff_input.cpu(), index, keys_store, vals_store, conf.n_tokens)
            res_knn = res_knn.to(conf.device)
```

#

This restricts the loss to the last `last_n` tokens. This is important because the first predictions along the sequence have very few past tokens for the transformer model to look at.

```python
            if last_n:
                res = res[-last_n:]
                res_knn = res_knn[-last_n:]
                target = target[-last_n:]
```
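The slices act on the first dimension, which is the sequence dimension here; this is inferred from the sample count below, which multiplies `res.shape[0]` by the batch size `data.shape[1]`. A NumPy sketch with made-up sizes confirms that `[-last_n:]` keeps the final positions of every sequence in the batch.

```python
import numpy as np

seq_len, batch_size, n_tokens, last_n = 6, 2, 5, 3  # hypothetical sizes
# Sequence-first layout: [seq_len, batch_size, n_tokens]
res = np.arange(seq_len * batch_size * n_tokens).reshape(seq_len, batch_size, n_tokens)
target = np.arange(seq_len * batch_size).reshape(seq_len, batch_size)

res_last = res[-last_n:]        # positions 3, 4 and 5 of every sequence
target_last = target[-last_n:]
n_s = res_last.shape[0] * batch_size  # number of scored tokens in this batch
print(res_last.shape, target_last.shape, n_s)  # (3, 2, 5) (3, 2) 6
```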

#

Number of samples

```python
            n_s = res.shape[0] * data.shape[1]
            n_samples.append(n_s)
```

#

Calculate scores for each of `knn_weights`.

```python
            for i, c in enumerate(knn_weights):
```

#

Calculate the loss

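```python
                loss = conf.loss_func(res_knn * c + (1 - c) * res, target)
                # `.item()` turns the (possibly GPU) scalar tensor into a float so `np.sum` can add it later
                losses[i].append(loss.item() * n_s)

    return losses, n_samples
```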
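The code mixes the two outputs as logits, `res_knn * c + (1 - c) * res`, before the loss is applied; the $k$NN-LM paper instead interpolates the two probability distributions. The NumPy sketch below reproduces the logit mixing, assuming `conf.loss_func` is an ordinary mean cross-entropy over logits (an assumption; the loss is not shown here). The `cross_entropy` helper and the logits are hypothetical.

```python
import numpy as np


def cross_entropy(logits: np.ndarray, target: np.ndarray) -> float:
    """Mean cross-entropy of integer targets under softmax(logits) along the last axis."""
    # Subtract the max for numerical stability before exponentiating
    shifted = logits - logits.max(-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(-1, keepdims=True))
    picked = np.take_along_axis(log_probs, target[..., None], axis=-1)
    return float(-picked.mean())


res = np.array([[2.0, 0.0, 0.0]])      # hypothetical transformer logits for one token
res_knn = np.array([[0.0, 1.8, 0.0]])  # hypothetical k-NN logits (summed cosine similarities)
target = np.array([1])

for c in [0.0, 0.25, 0.45]:
    mixed = res_knn * c + (1 - c) * res
    print(c, round(cross_entropy(mixed, target), 4))
```

Here the $k$-NN logits favor the correct token, so the loss falls as `c` grows; on real data the best weight has to be found by the sweep in `main`.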

#

Load the index

```python
def load_index(conf: Configs, n_probe: int = 8):
```

#

Dimensions of $f(c_i)$

```python
    d_model = conf.transformer.d_model
```

#

Training data loader

```python
    data_loader = conf.trainer.data_loader
```

#

Number of contexts, i.e. the number of tokens in the training data minus one: one pair $\big(f(c_i), w_i\big)$ for each $i \in [2, T]$.

```python
    n_keys = data_loader.data.shape[0] * data_loader.data.shape[1] - 1
```
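The key count is plain arithmetic: every token except the first is preceded by a context, so a stream of $T$ tokens yields $T - 1$ pairs, and the code takes $T$ to be `data.shape[0] * data.shape[1]`. A pure-Python sketch with a hypothetical token stream spells out the pairs; in the real stores the context is represented by $f(c_i)$ rather than by the raw tokens.

```python
tokens = ["the", "cat", "sat", "on", "the", "mat"]  # hypothetical training stream, T = 6

# Pair the context c_i = (w_1, ..., w_{i-1}) with the token w_i that follows it, for i in [2, T]
pairs = [(tokens[: i - 1], tokens[i - 1]) for i in range(2, len(tokens) + 1)]

n_keys = len(tokens) - 1
print(n_keys, pairs[0], pairs[-1][1])  # 5 (['the'], 'cat') mat
```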

#

Load FAISS index

```python
    with monit.section('Load index'):
        index = faiss.read_index(str(lab.get_data_path() / 'faiss.index'))
```

#

Set number of cells to probe

```python
    index.nprobe = n_probe
```
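`nprobe` is a parameter of FAISS's inverted-file (IVF) indexes: stored vectors are bucketed by their nearest coarse centroid, and a search scans only the `nprobe` cells whose centroids are closest to the query, trading recall for speed. The toy NumPy version below is not how FAISS implements it; it uses hand-picked centroids instead of trained ones, and the name `ivf_search` is hypothetical. Probing every cell reproduces the exact search.

```python
import numpy as np


def ivf_search(query, keys, centroids, nprobe, k):
    """Toy inverted-file search: scan only the nprobe cells nearest to the query."""
    # Assign every key to its nearest centroid (its "cell")
    cell_of_key = ((keys[:, None, :] - centroids[None]) ** 2).sum(-1).argmin(1)
    # Rank the cells by the distance from the query to their centroid
    cell_order = ((centroids - query) ** 2).sum(-1).argsort()
    # Candidate keys are the ones that fall in the nprobe closest cells
    candidates = np.flatnonzero(np.isin(cell_of_key, cell_order[:nprobe]))
    # Exact squared-L2 ranking restricted to the candidates
    d2 = ((keys[candidates] - query) ** 2).sum(-1)
    best = np.argsort(d2, kind="stable")[:k]
    return candidates[best]


rng = np.random.default_rng(0)
keys = rng.normal(size=(200, 8)).astype(np.float32)
centroids = keys[:16]  # hypothetical coarse quantizer with 16 cells
query = rng.normal(size=8).astype(np.float32)

approx = ivf_search(query, keys, centroids, nprobe=2, k=5)
exact = ivf_search(query, keys, centroids, nprobe=16, k=5)
```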

#

Load memory mapped numpy arrays

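```python
    keys_store = np.memmap(str(lab.get_data_path() / 'keys.npy'), dtype=np.float32, mode='r', shape=(n_keys, d_model))
    # `np.int` was removed in NumPy 1.24; `np.int64` is what it meant on 64-bit Linux,
    # and it must match the dtype the values store was written with
    vals_store = np.memmap(str(lab.get_data_path() / 'vals.npy'), dtype=np.int64, mode='r', shape=(n_keys, 1))

    return index, keys_store, vals_store
```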
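A raw `np.memmap` file carries no header, so the reader has to supply the same `dtype` and `shape` the writer used; that is why `load_index` recomputes `n_keys` and `d_model` before opening the stores. A round trip through a temporary file, with hypothetical sizes, illustrates this.

```python
import os
import tempfile

import numpy as np

n_keys, d_model = 4, 3  # hypothetical sizes
path = os.path.join(tempfile.mkdtemp(), "keys.npy")

# Writer: a raw buffer on disk, with no header recording dtype or shape
writer = np.memmap(path, dtype=np.float32, mode="w+", shape=(n_keys, d_model))
writer[:] = np.arange(n_keys * d_model, dtype=np.float32).reshape(n_keys, d_model)
writer.flush()
del writer

# Reader: must repeat the same dtype and shape; data is paged in lazily on access
keys_store = np.memmap(path, dtype=np.float32, mode="r", shape=(n_keys, d_model))
print(keys_store[2].tolist())  # [6.0, 7.0, 8.0]
print(os.path.getsize(path))   # 48, i.e. 4 * 3 values * 4 bytes
```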

#

```python
def main():
    from labml_nn.transformers.knn.build_index import load_experiment
```

#

Load the experiment. Replace the run UUID with the UUID of your own run from training the model.

```python
    conf = load_experiment('4984b85c20bf11eb877a69c1a03717cd')
```

#

Set model to evaluation mode

```python
    conf.model.eval()
```

#

Load index

```python
    index, keys_store, vals_store = load_index(conf)
```

#

List of weights given to the $k$-NN prediction ($0, 0.05, \dots, 0.45$). We will evaluate the validation loss for each of the weights.

```python
    knn_weights = [i / 20 for i in range(10)]
```

#

Evaluate validation loss

```python
    losses, n_samples = validation_loss(knn_weights, None, conf, index, keys_store, vals_store)
```

#

Output the losses for each of `knn_weights`.

```python
    inspect({c: np.sum(losses[i]) / np.sum(n_samples) for i, c in enumerate(knn_weights)})


if __name__ == '__main__':
    main()
```
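Because `validation_loss` stores each batch loss multiplied by that batch's sample count, dividing the sum of those products by the total count gives a sample-weighted mean loss per weight, so a smaller final batch does not count as much as a full one. The sketch below uses invented losses and batch sizes, and also picks the weight with the lowest mean loss.

```python
import numpy as np

knn_weights = [0.0, 0.25]  # hypothetical subset of the weights
n_samples = [512, 128]     # hypothetical batch sizes; the last batch is smaller
# Hypothetical per-batch mean losses, stored as loss * n_s the way validation_loss does
batch_losses = [[3.0, 4.0], [2.5, 4.5]]
losses = [[loss * n for loss, n in zip(row, n_samples)] for row in batch_losses]

# Sample-weighted mean per weight, matching the expression passed to inspect()
mean_loss = {c: np.sum(losses[i]) / np.sum(n_samples) for i, c in enumerate(knn_weights)}
best_weight = min(mean_loss, key=mean_loss.get)
print(mean_loss, best_weight)
```

The weighted means are 3.2 and 2.9, so `0.25` wins; a plain average of the batch means would rate both weights at 3.5.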

labml.ai