
Trying out Sampling Techniques for Language Models


This experiment tries out the above sampling techniques (greedy, temperature, top-k, and nucleus sampling) on HuggingFace's GPT-2 model.

import torch

from labml import monit, logger, lab

from labml.logger import Text

from labml_nn.sampling import Sampler
from labml_nn.sampling.greedy import GreedySampler
from labml_nn.sampling.nucleus import NucleusSampler
from labml_nn.sampling.temperature import TemperatureSampler
from labml_nn.sampling.top_k import TopKSampler
from transformers import GPT2Tokenizer, GPT2LMHeadModel
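
All of the samplers imported above share one tiny contract: they take the logits of the next token and return sampled token ids. A minimal sketch of that contract, reusing the torch import above (the actual Sampler base class in labml_nn.sampling may differ in details):

class MySamplerSketch:
    # Hypothetical illustration of the interface assumed by sample() below:
    # map last-token logits of shape [batch_size, vocab_size]
    # to token ids of shape [batch_size].
    def __call__(self, logits: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError()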

Sample from model

  • model is the model to sample from
  • tokenizer is the tokenizer to use
  • sampler is the sampler to use
  • n_samples is the number of samples to generate
  • n_tokens is the number of tokens to generate
  • seq_len is the maximum sequence length for the model
  • prompt is the starting prompt
@torch.no_grad()
def sample(model: GPT2LMHeadModel, tokenizer: GPT2Tokenizer, sampler: Sampler,
           n_samples: int, n_tokens: int, seq_len: int, prompt: str):

Tokenize the prompt and make n_samples copies of it

    data = torch.tile(torch.tensor(tokenizer.encode(prompt))[None, :], (n_samples, 1))
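
As a toy illustration (the token ids and n_samples=4 are made up), torch.tile repeats the encoded prompt along a new batch dimension so each sample can evolve independently:

ids = torch.tensor([40, 341])             # hypothetical token ids for a prompt
batch = torch.tile(ids[None, :], (4, 1))  # [None, :] adds a batch dim: (2,) -> (1, 2)
print(batch.shape)                        # torch.Size([4, 2]): one row per sample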

Collect output for printing

    logs = [[(prompt, Text.meta)] for _ in range(n_samples)]

Sample n_tokens

    for i in monit.iterate('Sample', n_tokens):

Truncate the data to the maximum sequence length

        data = data[:, -seq_len:]

Get the model output. The logits have shape [batch_size, seq_len, vocab_size]

        logits = model(data)[0]

Get the logits of the last token

        logits = logits[:, -1]

Sample from the logits

        res = sampler(logits)

Add the sampled token to the data

        data = torch.cat([data, res[:, None]], dim=1)
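
Continuing the toy example from above (the sampled ids are again made up), res has shape (n_samples,), so res[:, None] turns it into a column that can be concatenated along the sequence dimension:

res = torch.tensor([7, 9, 7, 11])                # hypothetical sampled ids, shape (4,)
batch = torch.cat([batch, res[:, None]], dim=1)  # appends one column: (4, 2) -> (4, 3)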

Decode and add the sampled token for logging

        for j in range(n_samples):
            logs[j] += [('' + tokenizer.decode(res[j]), Text.value)]

Print the sampled outputs

    for j in range(n_samples):
        logger.log(logs[j])

Try different sampling techniques

def main():

Load the model and tokenizer

    with monit.section('Load tokenizer/model'):
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2', cache_dir=lab.get_data_path() / 'cache')
        model = GPT2LMHeadModel.from_pretrained('gpt2', cache_dir=lab.get_data_path() / 'cache')

Set the model to eval mode

    model.eval()

Prompt to use for sampling

    prompt = 'I saw an interesting dream last night. '

Greedy Sampling

    with monit.section('greedy'):
        sample(model, tokenizer, GreedySampler(), 4, 32, 128, prompt)
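
Greedy decoding simply picks the most likely token at every step, which is why all four samples come out identical. A rough sketch of the idea, not necessarily the library's exact code:

class MyGreedySampler:
    # Hypothetical re-implementation for illustration.
    def __call__(self, logits: torch.Tensor) -> torch.Tensor:
        # Pick the arg-max token: [batch_size, vocab_size] -> [batch_size]
        return logits.argmax(dim=-1)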

Temperature Sampling

    with monit.section('temperature=1.'):
        sample(model, tokenizer, TemperatureSampler(1.), 4, 32, 128, prompt)
    with monit.section('temperature=.1'):
        sample(model, tokenizer, TemperatureSampler(.1), 4, 32, 128, prompt)
    with monit.section('temperature=10.'):
        sample(model, tokenizer, TemperatureSampler(10.), 4, 32, 128, prompt)
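
Temperature sampling divides the logits by the temperature before sampling from the softmax, so temperatures below 1 sharpen the distribution (approaching greedy as T goes to 0) and temperatures above 1 flatten it towards uniform. A rough sketch of the idea, not necessarily the library's exact code:

from torch.distributions import Categorical

class MyTemperatureSampler:
    # Hypothetical re-implementation for illustration.
    def __init__(self, temperature: float = 1.0):
        self.temperature = temperature

    def __call__(self, logits: torch.Tensor) -> torch.Tensor:
        # softmax(logits / T): T < 1 sharpens, T > 1 flattens.
        dist = Categorical(logits=logits / self.temperature)
        return dist.sample()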

Top-k Sampling

    with monit.section('top_k=2'):
        sample(model, tokenizer, TopKSampler(2, TemperatureSampler(1.)), 4, 32, 128, prompt)
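
Top-k sampling keeps only the k highest logits and delegates the actual draw to an inner sampler (temperature sampling here). A rough sketch of the idea, not necessarily the library's exact code:

class MyTopKSampler:
    # Hypothetical re-implementation for illustration.
    def __init__(self, k: int, sampler):
        self.k = k
        self.sampler = sampler

    def __call__(self, logits: torch.Tensor) -> torch.Tensor:
        # Keep the k largest logits; everything else becomes -inf,
        # i.e. zero probability after softmax.
        values, indices = torch.topk(logits, self.k, dim=-1)
        masked = torch.full_like(logits, float('-inf'))
        masked.scatter_(-1, indices, values)
        return self.sampler(masked)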

Nucleus Sampling

    with monit.section('nucleus p=.95'):
        sample(model, tokenizer, NucleusSampler(0.95, TemperatureSampler(1.)), 4, 32, 128, prompt)
    with monit.section('nucleus p=.1'):
        sample(model, tokenizer, NucleusSampler(0.1, TemperatureSampler(1.)), 4, 32, 128, prompt)
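
Nucleus (top-p) sampling keeps the smallest set of tokens whose cumulative probability exceeds p, so the candidate pool adapts to how peaked the distribution is. A rough sketch of the idea, not necessarily the library's exact code:

import torch.nn.functional as F

class MyNucleusSampler:
    # Hypothetical re-implementation for illustration.
    def __init__(self, p: float, sampler):
        self.p = p
        self.sampler = sampler

    def __call__(self, logits: torch.Tensor) -> torch.Tensor:
        probs = F.softmax(logits, dim=-1)
        # Sort tokens by probability, highest first.
        sorted_probs, indices = torch.sort(probs, dim=-1, descending=True)
        cum_probs = torch.cumsum(sorted_probs, dim=-1)
        # Drop a token once the cumulative probability before it
        # already exceeds p; the most likely token is always kept.
        drop = cum_probs - sorted_probs > self.p
        sorted_logits = torch.gather(logits, -1, indices)
        sorted_logits[drop] = float('-inf')
        # Sample in sorted order, then map back to vocabulary ids.
        sampled = self.sampler(sorted_logits)
        return indices.gather(-1, sampled.unsqueeze(-1)).squeeze(-1)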

if __name__ == '__main__':
    main()
