Back to Annotated Deep Learning Paper Implementations

experiment_tiny.py

docs/sampling/experiment_tiny.html

latest · 2.8 KB
Original Source

home / sampling

View code on Github

#

1fromtypingimportTuple23importtorch45fromlabmlimportexperiment,monit6fromlabmlimportlogger7fromlabml.loggerimportText8fromlabml\_nn.helpers.datasetsimportTextDataset9fromlabml\_nn.samplingimportSampler10fromlabml\_nn.sampling.greedyimportGreedySampler11fromlabml\_nn.sampling.nucleusimportNucleusSampler12fromlabml\_nn.sampling.temperatureimportTemperatureSampler13fromlabml\_nn.sampling.top\_kimportTopKSampler14fromlabml\_nn.transformers.basic.autoregressive\_experimentimportConfigs,AutoregressiveTransformer

#

17defget\_model\_dataset(run\_uuid:str)-\>Tuple[AutoregressiveTransformer,TextDataset]:18experiment.evaluate()1920conf=Configs()2122experiment.configs(conf,experiment.load\_configs(run\_uuid))2324experiment.load(run\_uuid)2526experiment.add\_pytorch\_models({'model':conf.model})2728experiment.start()2930returnconf.model,conf.text

#

def sample(model, ds, sampler: Sampler, n_samples: int, n_tokens: int, seq_len: int, prompt: str):
    """
    Generate text from the model with a given sampling strategy and print it.

    :param model: autoregressive model; called as ``model(data)``, its first
        output is taken as the logits (last position is used for sampling)
    :param ds: dataset providing ``text_to_i`` (encode) and ``itos`` (decode)
    :param sampler: strategy that maps logits to sampled token indices
    :param n_samples: number of sequences generated in parallel
    :param n_tokens: number of tokens to generate per sequence
    :param seq_len: maximum context length fed to the model
    :param prompt: starting text shared by all sequences
    """
    with torch.no_grad():
        # Encode the prompt and replicate it across the batch dimension
        data = torch.tile(ds.text_to_i(prompt)[:, None], (1, n_samples))

        # Collect output for printing: one log per sample, seeded with the prompt
        logs = [[(prompt, Text.meta)] for _ in range(n_samples)]

        # Generate `n_tokens` tokens
        for _ in monit.iterate('Sample', n_tokens):
            # Truncate the context to the most recent `seq_len` tokens
            data = data[-seq_len:]
            # Run the model; the first output holds the logits
            logits, *_ = model(data)
            # Keep the logits of the final position only
            logits = logits[-1]
            # Sample the next token for every sequence
            res = sampler(logits)
            # Extend each sequence with its sampled token
            data = torch.cat([data, res[None, :]], dim=0)

            # Record the decoded token for logging
            for j in range(n_samples):
                logs[j] += [('' + ds.itos[res[j]], Text.value)]

    # Print the sampled output
    for j in range(n_samples):
        logger.log(logs[j])

#

58defmain():59model,ds=get\_model\_dataset('074d4004cc6b11ecad7a0242ac1c0002')60model.eval()6162withmonit.section('greedy'):63sample(model,ds,GreedySampler(),4,32,128,'It is')6465withmonit.section('temperature=1.'):66sample(model,ds,TemperatureSampler(1.),4,32,128,'It is')67withmonit.section('temperature=.1'):68sample(model,ds,TemperatureSampler(.1),4,32,128,'It is')69withmonit.section('temperature=10.'):70sample(model,ds,TemperatureSampler(10.),4,32,128,'It is')7172withmonit.section('top\_k=5'):73sample(model,ds,TopKSampler(2,TemperatureSampler(1.)),4,32,128,'It is')7475withmonit.section('nucles p=.95'):76sample(model,ds,NucleusSampler(0.95,TemperatureSampler(1.)),4,32,128,'It is')77withmonit.section('nucles p=.95'):78sample(model,ds,NucleusSampler(0.1,TemperatureSampler(1.)),4,32,128,'It is')798081if\_\_name\_\_=='\_\_main\_\_':82main()

labml.ai