docs/sampling/experiment.html
This experiment applies the above sampling techniques to HuggingFace's GPT-2 model.
```python
import torch

from labml import monit, logger, lab
from labml.logger import Text

from labml_nn.sampling import Sampler
from labml_nn.sampling.greedy import GreedySampler
from labml_nn.sampling.nucleus import NucleusSampler
from labml_nn.sampling.temperature import TemperatureSampler
from labml_nn.sampling.top_k import TopKSampler
from transformers import GPT2Tokenizer, GPT2LMHeadModel
```
- `model` is the model to sample from
- `tokenizer` is the tokenizer to use
- `sampler` is the sampler to use
- `n_samples` is the number of samples to generate
- `n_tokens` is the number of tokens to generate
- `seq_len` is the maximum sequence length for the model
- `prompt` is the starting prompt

```python
@torch.no_grad()
def sample(model: GPT2LMHeadModel, tokenizer: GPT2Tokenizer, sampler: Sampler,
           n_samples: int, n_tokens: int, seq_len: int, prompt: str):
```
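The `sampler` argument can be any object that satisfies the `Sampler` interface assumed here: a callable that takes a logits tensor of shape `[n_samples, vocab_size]` and returns one sampled token id per row. A minimal illustrative stand-in for that contract (not the labml implementation):

```python
import torch

class ArgmaxSampler:
    """Minimal stand-in for the Sampler interface: a callable mapping
    logits of shape [batch, vocab] to one token id per row."""

    def __call__(self, logits: torch.Tensor) -> torch.Tensor:
        # Pick the highest-scoring token in each row (a greedy choice)
        return logits.argmax(dim=-1)

logits = torch.tensor([[0.1, 2.0, -1.0],
                       [3.0, 0.0, 0.5]])
print(ArgmaxSampler()(logits).tolist())  # [1, 0]
```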
Tokenize the prompt and make n_samples copies of it
```python
    data = torch.tile(torch.tensor(tokenizer.encode(prompt))[None, :], (n_samples, 1))
```
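`torch.tile` turns the single encoded prompt into a batch by replicating it along a new leading dimension. A small shape check with made-up token ids standing in for `tokenizer.encode(prompt)`:

```python
import torch

token_ids = torch.tensor([10, 11, 12])         # stand-in for an encoded prompt
data = torch.tile(token_ids[None, :], (4, 1))  # replicate into a batch of 4
print(data.shape)                              # torch.Size([4, 3])
print(data[0].tolist())                        # [10, 11, 12]
```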
Collect output for printing
```python
    logs = [[(prompt, Text.meta)] for _ in range(n_samples)]
```
Sample n_tokens
```python
    for i in monit.iterate('Sample', n_tokens):
```
Truncate the data to the maximum sequence length
```python
        data = data[:, -seq_len:]
```
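The truncation keeps only the most recent `seq_len` tokens of each sample, so the input never exceeds the model's context window. A toy tensor makes the slicing concrete:

```python
import torch

data = torch.arange(12).reshape(2, 6)  # batch of 2 samples, 6 tokens each
seq_len = 4
truncated = data[:, -seq_len:]         # drop the oldest tokens, keep the last 4
print(truncated.tolist())              # [[2, 3, 4, 5], [8, 9, 10, 11]]
```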
Get the model output. The logits have shape [batch_size, seq_len, vocab_size]
```python
        logits = model(data)[0]
```
Get the logits of the last token
```python
        logits = logits[:, -1]
```
Sample from the logits
```python
        res = sampler(logits)
```
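This call reduces the `[n_samples, vocab_size]` logits to one token id per sample. For illustration, here is a simple temperature sampler that fits the same interface (a sketch of the idea, not the labml `TemperatureSampler`):

```python
import torch

class SimpleTemperatureSampler:
    """Sketch of temperature sampling: scale logits by 1/T,
    then sample from the resulting softmax distribution."""

    def __init__(self, temperature: float):
        self.temperature = temperature

    def __call__(self, logits: torch.Tensor) -> torch.Tensor:
        dist = torch.distributions.Categorical(logits=logits / self.temperature)
        return dist.sample()

# At a very low temperature the clear favourite is picked almost surely
logits = torch.tensor([[0.0, 10.0]])
res = SimpleTemperatureSampler(0.01)(logits)
print(res.tolist())  # [1]
```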
Add the sampled token to the data
```python
        data = torch.cat([data, res[:, None]], dim=1)
```
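`res[:, None]` adds a length-1 sequence dimension so the new tokens can be appended as one extra column. A toy shape check:

```python
import torch

data = torch.zeros(4, 3, dtype=torch.long)  # 4 samples, 3 tokens so far
res = torch.tensor([5, 6, 7, 8])            # one newly sampled token per sample
data = torch.cat([data, res[:, None]], dim=1)
print(data.shape)            # torch.Size([4, 4])
print(data[:, -1].tolist())  # [5, 6, 7, 8]
```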
Decode and add the sampled token for logging
```python
        for j in range(n_samples):
            logs[j] += [(tokenizer.decode(res[j]), Text.value)]
```
Print the sampled outputs
```python
    for j in range(n_samples):
        logger.log(logs[j])
```
```python
def main():
```
Load the model and tokenizer
```python
    with monit.section('Load tokenizer/model'):
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2', cache_dir=lab.get_data_path() / 'cache')
        model = GPT2LMHeadModel.from_pretrained('gpt2', cache_dir=lab.get_data_path() / 'cache')
```
Set the model to eval mode
```python
    model.eval()
```
Prompt to use for sampling
```python
    prompt = 'I saw an interesting dream last night. '
```
```python
    with monit.section('greedy'):
        sample(model, tokenizer, GreedySampler(), 4, 32, 128, prompt)
```
```python
    with monit.section('temperature=1.'):
        sample(model, tokenizer, TemperatureSampler(1.), 4, 32, 128, prompt)
    with monit.section('temperature=.1'):
        sample(model, tokenizer, TemperatureSampler(.1), 4, 32, 128, prompt)
    with monit.section('temperature=10.'):
        sample(model, tokenizer, TemperatureSampler(10.), 4, 32, 128, prompt)
```
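Why these three settings behave so differently: dividing logits by the temperature sharpens the softmax distribution when `T < 1` and flattens it when `T > 1`. A quick check of the maximum probability at each setting:

```python
import torch

logits = torch.tensor([2.0, 1.0, 0.0])
for t in (0.1, 1.0, 10.0):
    probs = torch.softmax(logits / t, dim=-1)
    # Lower T concentrates mass on the top token; higher T spreads it out
    print(f'T={t}: max prob = {probs.max().item():.3f}')
```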
```python
    with monit.section('top_k=2'):
        sample(model, tokenizer, TopKSampler(2, TemperatureSampler(1.)), 4, 32, 128, prompt)
```
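`TopKSampler` is composed with a `TemperatureSampler` that does the actual drawing; the top-k step only restricts which logits survive. An illustrative version of that filtering step (a sketch of the idea, not the labml code):

```python
import torch

def top_k_filter(logits: torch.Tensor, k: int) -> torch.Tensor:
    """Keep the k largest logits per row; set the rest to -inf
    so they get zero probability under softmax."""
    values, _ = torch.topk(logits, k, dim=-1)
    cutoff = values[..., -1:]  # k-th largest value per row
    return logits.masked_fill(logits < cutoff, float('-inf'))

logits = torch.tensor([[1.0, 4.0, 2.0, 3.0]])
filtered = top_k_filter(logits, 2)
print(torch.isfinite(filtered).sum().item())  # 2
```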
```python
    with monit.section('nucleus p=.95'):
        sample(model, tokenizer, NucleusSampler(0.95, TemperatureSampler(1.)), 4, 32, 128, prompt)
    with monit.section('nucleus p=.1'):
        sample(model, tokenizer, NucleusSampler(0.1, TemperatureSampler(1.)), 4, 32, 128, prompt)
```
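Nucleus (top-p) sampling keeps the smallest set of tokens whose cumulative probability reaches `p`, then samples only within that set. An illustrative filtering step (a sketch of the idea, not the labml `NucleusSampler`):

```python
import torch

def nucleus_filter(logits: torch.Tensor, p: float) -> torch.Tensor:
    """Mask out tokens outside the smallest set whose
    cumulative probability reaches p."""
    probs = torch.softmax(logits, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
    cum = torch.cumsum(sorted_probs, dim=-1)
    # Keep a token if the cumulative mass *before* it is still below p
    keep_sorted = (cum - sorted_probs) < p
    keep = keep_sorted.gather(-1, sorted_idx.argsort(dim=-1))  # undo the sort
    return logits.masked_fill(~keep, float('-inf'))

# Probabilities here are [0.5, 0.3, 0.2]; with p=0.6 the nucleus is the top two
logits = torch.log(torch.tensor([[0.5, 0.3, 0.2]]))
filtered = nucleus_filter(logits, 0.6)
print(torch.isfinite(filtered).tolist())  # [[True, True, False]]
```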
```python
if __name__ == '__main__':
    main()
```