Top-k Sampling
Here we first pick the top-k tokens from the distribution of logits, and then sample from them.
Here's an experiment that uses these sampling techniques.
import torch

from labml_nn.sampling import Sampler
class TopKSampler(Sampler):
`k` is the number of tokens to pick, and `sampler` is the sampler to use for the top-k tokens. `sampler` can be any sampler that takes a logits tensor as input and returns a token tensor; e.g. `TemperatureSampler` (a minimal sketch of a compatible sampler follows the class below).
    def __init__(self, k: int, sampler: Sampler):
        self.k = k
        self.sampler = sampler
Sample from logits
    def __call__(self, logits: torch.Tensor):
New logits filled with −∞; i.e. zero probability
        zeros = logits.new_ones(logits.shape) * float('-inf')
Pick the largest k logits and their indices
        values, indices = torch.topk(logits, self.k, dim=-1)
Set the values of the top-k selected indices to actual logits. Logits of other tokens remain −∞
        zeros.scatter_(-1, indices, values)
Sample from the top-k logits with the specified sampler.
        return self.sampler(zeros)
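The `Sampler` interface only requires a `__call__` that maps a logits tensor to sampled token ids. For illustration, here is a minimal sketch of a compatible temperature sampler along the lines of `TemperatureSampler`; the actual class in `labml_nn.sampling` may differ in detail.

from torch.distributions import Categorical

class TemperatureSampler(Sampler):
    # Sketch of a temperature sampler; assumes `Sampler` only
    # requires `__call__` to be implemented.
    def __init__(self, temperature: float = 1.0):
        self.temperature = temperature

    def __call__(self, logits: torch.Tensor):
        # Scale the logits by the temperature and sample one token per
        # distribution; the -inf logits get zero probability.
        return Categorical(logits=logits / self.temperature).sample()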
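And a hedged usage sketch, with arbitrary shapes and an arbitrary k chosen purely for illustration:

# Hypothetical example: a batch of 2 distributions over a 100-token vocabulary.
logits = torch.randn(2, 100)
sampler = TopKSampler(k=10, sampler=TemperatureSampler(temperature=1.0))
tokens = sampler(logits)  # shape (2,): one sampled token id per distribution

Since tokens outside the top-k are masked to −∞ rather than removed, the sampled indices remain valid ids in the full vocabulary.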