Nucleus Sampling

This is an implementation of nucleus sampling, introduced in the paper The Curious Case of Neural Text Degeneration.

The paper discusses the problems with other decoding methods such as beam search, pure sampling, temperature sampling, and top-k sampling. It introduces nucleus sampling, which in practice performs better than these methods for text generation.

Nucleus sampling first picks a subset of the vocabulary $V^{(p)} \subset V$, where $V^{(p)}$ is the smallest set of tokens such that

$$\sum_{x_i \in V^{(p)}} P(x_i \mid x_{1:i-1}) \ge p$$

That is, we pick the most probable tokens, in descending order of probability, until the sum of their probabilities reaches at least p.

Then we sample from the selected tokens, with their probabilities renormalized to sum to 1.
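The two steps above can be sketched in plain Python for a single distribution. This is a minimal illustration, not the labml_nn implementation: the helper names `nucleus_indices` and `nucleus_sample` are hypothetical, and the sketch assumes `probs` already sums to 1.

```python
import random
from typing import List, Sequence


def nucleus_indices(probs: Sequence[float], p: float) -> List[int]:
    """Return the indices of the smallest set of most probable tokens
    whose total probability is at least p (the nucleus V^(p))."""
    # Visit token indices from most to least probable (ties keep original order).
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    nucleus: List[int] = []
    total = 0.0
    for i in order:
        nucleus.append(i)
        total += probs[i]
        if total >= p:  # stop as soon as the threshold is reached
            break
    return nucleus


def nucleus_sample(probs: Sequence[float], p: float, rng: random.Random) -> int:
    """Sample one token index from the nucleus, renormalized to sum to 1."""
    nucleus = nucleus_indices(probs, p)
    weights = [probs[i] for i in nucleus]
    # random.choices normalizes the weights itself, so no explicit division is needed.
    return rng.choices(nucleus, weights=weights, k=1)[0]
```

For example, with `probs = [0.125, 0.5, 0.0625, 0.25, 0.0625]` and `p = 0.75`, the nucleus is `[1, 3]`, because 0.5 + 0.25 already reaches 0.75; with `p = 0.8` it grows to `[1, 3, 0]`.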

Here's an experiment that uses these sampling techniques.

    import torch
    from torch import nn

    from labml_nn.sampling import Sampler

Nucleus Sampler

    class NucleusSampler(Sampler):

  • p is the cumulative probability threshold: the nucleus is the smallest set of the most probable tokens whose probabilities sum to at least p
  • sampler is the sampler used to draw from the selected tokens

        def __init__(self, p: float, sampler: Sampler):

            self.p = p
            self.sampler = sampler

Softmax to compute $P(x_i \mid x_{1:i-1})$ from the logits

            self.softmax = nn.Softmax(dim=-1)
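For a single row of logits, `nn.Softmax(dim=-1)` computes the quantity sketched below in plain Python (the function name `softmax` is ours, not part of the original code). Subtracting the maximum logit first leaves the result mathematically unchanged but keeps `exp` from overflowing.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Convert one row of logits into probabilities P(x_i | x_{1:i-1})."""
    m = max(logits)  # shift by the max logit for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]
```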

Sample from logits with Nucleus Sampling

        def __call__(self, logits: torch.Tensor):

Get probabilities $P(x_i \mid x_{1:i-1})$

            probs = self.softmax(logits)

Sort probabilities in descending order

            sorted_probs, indices = torch.sort(probs, dim=-1, descending=True)

Get the cumulative sum of probabilities in the sorted order

            cum_sum_probs = torch.cumsum(sorted_probs, dim=-1)
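On one row, the sort and cumulative sum behave like this plain-Python sketch (the helper `sort_and_cumsum` is hypothetical). It breaks ties by original position; `torch.sort` does not promise that by default, which only affects the order of equally probable tokens.

```python
from itertools import accumulate
from typing import List, Sequence, Tuple


def sort_and_cumsum(probs: Sequence[float]) -> Tuple[List[float], List[int], List[float]]:
    """Mirror a descending sort followed by a cumulative sum on one row."""
    # Positions of the tokens, most probable first (stable for ties).
    indices = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    sorted_probs = [probs[i] for i in indices]
    cum_sum_probs = list(accumulate(sorted_probs))  # running totals
    return sorted_probs, indices, cum_sum_probs
```

With `probs = [0.125, 0.5, 0.0625, 0.25, 0.0625]` this gives `indices = [1, 3, 0, 2, 4]` and `cum_sum_probs = [0.5, 0.75, 0.875, 0.9375, 1.0]`; the sums are exact because every probability is a power of two.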

Find the cumulative sums less than p.

            nucleus = cum_sum_probs < self.p

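Shift the mask right by one position and prepend True, so that the token whose cumulative probability first reaches p is also kept. The result is the smallest set of tokens with cumulative probability at least p, and the most probable token is always kept.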

            nucleus = torch.cat([nucleus.new_ones(nucleus.shape[:-1] + (1,)), nucleus[..., :-1]], dim=-1)
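To see why the shift is needed, here is the same two-step masking on one row in plain Python (the name `nucleus_mask` is hypothetical). With cumulative sums `[0.5, 0.75, 0.875, 0.9375, 1.0]` and `p = 0.8`, the comparison `cum_sum_probs < p` alone keeps only the first two tokens, whose total of 0.75 falls short of p. After the shift, the third token, which brings the total to 0.875, is included as well.

```python
from typing import List, Sequence


def nucleus_mask(cum_sum_probs: Sequence[float], p: float) -> List[bool]:
    """Mirror the compare-and-shift step on one row of sorted cumulative sums."""
    # True where the running total is still below p.
    below = [c < p for c in cum_sum_probs]
    # Shift right by one and prepend True: position k is kept iff the tokens
    # before it sum to less than p, so the token that crosses p is included
    # and the most probable token is always kept.
    return [True] + below[:-1]
```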

Get log probabilities and mask out the tokens outside the nucleus by setting their log probabilities to $-\infty$

            sorted_log_probs = torch.log(sorted_probs)
            sorted_log_probs[~nucleus] = float('-inf')
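Masking with $-\infty$ is what restricts sampling to the nucleus: if the masked log probabilities are passed through a softmax, as a sampler that treats its input as logits would do, the masked tokens get probability exactly 0 and the kept ones are renormalized. A plain-Python sketch of that effect, with hypothetical function names:

```python
import math
from typing import List, Sequence


def masked_log_probs(sorted_probs: Sequence[float], nucleus: Sequence[bool]) -> List[float]:
    """Log probabilities with tokens outside the nucleus set to -inf.
    Zero-probability tokens also map to -inf, matching log(0) = -inf."""
    return [math.log(q) if keep and q > 0 else float("-inf")
            for q, keep in zip(sorted_probs, nucleus)]


def renormalize(log_probs: Sequence[float]) -> List[float]:
    """Softmax over log probabilities; exp(-inf) == 0.0, so masked tokens get no mass."""
    m = max(log_probs)
    exps = [math.exp(v - m) for v in log_probs]
    total = sum(exps)
    return [e / total for e in exps]
```

With sorted probabilities `[0.5, 0.25, 0.125, 0.0625, 0.0625]` and the first three tokens kept, the kept mass is 0.875, so the renormalized probabilities are 4/7, 2/7 and 1/7, followed by two zeros.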

Sample with the inner sampler, passing it the masked log probabilities as logits

            sampled_sorted_indexes = self.sampler(sorted_log_probs)
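The inner `sampler` receives the masked log probabilities as if they were logits and returns one position per row, in sorted order. Below is a minimal single-row stand-in, assuming the inner sampler draws from the softmax of its input; the function `sample_from_logits` is hypothetical and is not the labml_nn `Sampler` API.

```python
import math
import random
from typing import Sequence


def sample_from_logits(logits: Sequence[float], rng: random.Random) -> int:
    """Hypothetical inner sampler: draw index i with probability softmax(logits)[i].
    Entries equal to -inf get weight exp(-inf) == 0.0 and are never drawn."""
    m = max(logits)
    weights = [math.exp(z - m) for z in logits]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]
```

If the inner sampler instead took the argmax, it would always return position 0, the most probable token, so nucleus sampling would reduce to greedy decoding.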

Map the sampled positions in the sorted order back to the original token indexes

            res = indices.gather(-1, sampled_sorted_indexes.unsqueeze(-1))

            return res.squeeze(-1)
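Because sampling happened in sorted order, the final `gather` translates each sampled position back into a vocabulary index. For a batch of rows it is equivalent to the plain-Python lookup below (the name `to_vocab_indices` is hypothetical).

```python
from typing import List, Sequence


def to_vocab_indices(indices: Sequence[Sequence[int]], sampled_sorted: Sequence[int]) -> List[int]:
    """Equivalent of indices.gather(-1, sampled.unsqueeze(-1)).squeeze(-1):
    for each row b, return indices[b][sampled_sorted[b]]."""
    return [row[k] for row, k in zip(indices, sampled_sorted)]
```

For instance, a row whose sort order is `[1, 3, 0, 2, 4]` and whose sampled sorted position is 2 yields vocabulary index 0.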

labml.ai