Top-k Sampling


Here we first pick the top-k tokens from the distribution of logits, and then sample from them.

Here's an experiment that uses these sampling techniques.

import torch

from labml_nn.sampling import Sampler


Top-k Sampler

class TopKSampler(Sampler):


  • k is the number of tokens to pick
  • sampler is the sampler to use for the top-k tokens

sampler can be any sampler that takes a logits tensor as input and returns a token tensor; e.g. `TemperatureSampler`.

    def __init__(self, k: int, sampler: Sampler):


        self.k = k
        self.sampler = sampler


Sample from logits

    def __call__(self, logits: torch.Tensor):


New logits filled with −∞; i.e. zero probability

        zeros = logits.new_ones(logits.shape) * float('-inf')


Pick the largest k logits and their indices

        values, indices = torch.topk(logits, self.k, dim=-1)


Set the values of the top-k selected indices to actual logits. Logits of other tokens remain −∞

        zeros.scatter_(-1, indices, values)


Sample from the top-k logits with the specified sampler.

        return self.sampler(zeros)
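To see how the class composes with an inner sampler, here is a hedged usage sketch. `GreedySampler` is a hypothetical stand-in for something like `TemperatureSampler`, and the `TopKSampler` body is reproduced inline (without the `labml_nn` base class) so the snippet runs on its own.

```python
import torch


class GreedySampler:
    """Hypothetical stand-in sampler: deterministically picks the largest logit."""

    def __call__(self, logits: torch.Tensor) -> torch.Tensor:
        return logits.argmax(dim=-1)


class TopKSampler:
    """Same logic as the class above, repeated here so the snippet is self-contained."""

    def __init__(self, k: int, sampler):
        self.k = k
        self.sampler = sampler

    def __call__(self, logits: torch.Tensor) -> torch.Tensor:
        zeros = logits.new_ones(logits.shape) * float('-inf')
        values, indices = torch.topk(logits, self.k, dim=-1)
        zeros.scatter_(-1, indices, values)
        return self.sampler(zeros)


logits = torch.tensor([[0.5, 3.0, 1.0, 2.0]])
sampler = TopKSampler(k=2, sampler=GreedySampler())
token = sampler(logits)  # argmax over the top-2 masked logits -> index 1
```

Only indices 1 and 3 survive the masking here, and the greedy inner sampler then picks index 1; with a stochastic inner sampler, either could be returned.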
