Back to Annotated Deep Learning Paper Implementations

RETRO training

docs/transformers/retro/train.html

latest7.2 KB
Original Source

home / transformers / retro

View code on Github

#

RETRO training

This is the training code for RETRO.

import torch
from torch import nn
from torch.utils.data import DataLoader, RandomSampler

from labml import monit, lab, tracker, experiment, logger
from labml.logger import Text
from labml_nn.helpers.datasets import TextFileDataset
from labml_nn.optimizers.noam import Noam
from labml_nn.transformers.retro import model as retro
from labml_nn.transformers.retro.dataset import Dataset, RetroIndex
from labml_nn.transformers.retro.model import RetroModel, NearestNeighborEncoder

#

Sampler

This class greedily samples from a model.

class Sampler:
    """Greedy sampler for a RETRO model.

    At every step it feeds the growing prompt (plus retrieved neighbor
    chunks) through the model and appends the highest-probability token.

    * `device` is the device the model lives on
    * `model` is the RETRO model
    * `tds` is the text dataset (neighbor chunks are sliced out of its text)
    * `chunk_len` is the length of a chunk
    """

    def __init__(self, device: torch.device, model: retro.RetroModel, tds: TextFileDataset, chunk_len: int):
        self.chunk_len = chunk_len
        self.tds = tds
        self.model = model
        self.device = device

        # Index used to look up nearest-neighbor chunk offsets
        self.index = RetroIndex()

    def retrieve_nearest_neighbours(self, chunk: str):
        """Return the nearest-neighbor text spans for `chunk`."""
        # Offsets of the nearest neighbors in the training text
        neighbor_offsets = self.index([chunk], None)

        # Slice out each neighbor together with its continuation,
        # i.e. spans of length `chunk_len * 2`
        text = self.tds.train
        return [text[off: off + self.chunk_len * 2] for off in neighbor_offsets[0]]

    def sample(self, prompt: str, sample_len: int):
        """Greedily sample `sample_len` tokens continuing `prompt`."""
        # Nearest neighbors (as strings) for each completed prompt chunk
        neighbors_str = []

        # Text generated so far
        sampled = ''

        for _ in range(sample_len):
            # Retrieve neighbors for any completed chunk that does not
            # have them yet
            while len(neighbors_str) < len(prompt) // self.chunk_len:
                # The next chunk without retrieved neighbors
                off = len(neighbors_str) * self.chunk_len
                chunk = prompt[off: off + self.chunk_len]
                # Look up and store its nearest neighbors
                neighbors_str.append(self.retrieve_nearest_neighbours(chunk))

            # Tokenize the prompt
            src = self.tds.text_to_i(prompt)

            # Tokenize the retrieved neighbors:
            # shape is (chunks, neighbors, neighbor length)
            neighbors = torch.stack([torch.stack([self.tds.text_to_i(n) for n in chunk]) for chunk in neighbors_str])

            # Move inputs to the model's device
            src = src.to(self.device)
            neighbors = neighbors.to(self.device)

            # Run the model (add a batch dimension of 1)
            res = self.model(src[None, :], neighbors[None, :, :, :])

            # Greedy choice: argmax over the logits of the last position
            token = res[0, -1, :].argmax(dim=-1)

            # Append the token's text to both the prompt and the output
            token_text = self.tds.itos[token.item()]
            prompt += token_text
            sampled += token_text

        return sampled

#

Retro trainer

class Trainer:
    """RETRO trainer; calling the instance runs one epoch over the data.

    * `device` is the device to train on
    * `model` is the RETRO model
    * `dataloader` yields `(src, tgt, neighbors)` batches
    * `optimizer` updates the model parameters
    """

    def __init__(self, device: torch.device, model: retro.RetroModel,
                 dataloader: DataLoader, optimizer: torch.optim.Optimizer):
        self.optimizer = optimizer
        self.device = device
        self.dataloader = dataloader
        self.model = model
        # Token-level cross entropy over the flattened logits
        self.loss_func = nn.CrossEntropyLoss()

    def __call__(self):
        """Train the model for one epoch."""
        for i, (src, tgt, neighbors) in monit.enum('Train', self.dataloader):
            # Move the batch to the training device
            src, tgt, neighbors = src.to(self.device), tgt.to(self.device), neighbors.to(self.device)

            # Forward pass
            res = self.model(src, neighbors)

            # Flatten (batch, seq, vocab) -> (batch * seq, vocab) for the loss
            loss = self.loss_func(res.view(-1, res.shape[-1]), tgt.view(-1))

            # Standard optimization step: clear, backprop, update
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Log the loss and advance the global step by the batch size
            tracker.save({'loss.train': loss})
            tracker.add_global_step(len(src))

#

Create and train a small model

def train():
    """Create and train a small RETRO model on Tiny Shakespeare."""
    # Create an experiment
    experiment.create(name='retro_small')

    # GPU device
    device = torch.device('cuda:0')

    # Load the Tiny Shakespeare dataset (character level)
    tds = TextFileDataset(
        lab.get_data_path() / 'tiny_shakespeare.txt',
        list,
        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')

    # Load the pre-built RETRO dataset (chunks + retrieved neighbors)
    train_dataset = Dataset(lab.get_data_path() / 'retro_train_dataset.json', tds)

    # Dataloader with random sampling (with replacement)
    train_dl = DataLoader(train_dataset,
                          batch_size=4,
                          sampler=RandomSampler(train_dataset, replacement=True))

    # Hyper-parameters
    chunk_len = 16
    d_model = 128
    d_ff = 512
    n_heads = 16
    d_k = 16

    # Nearest-neighbor encoder: 6 layers, cross-attention at layer 3
    nearest_neighbor_encoder = NearestNeighborEncoder(chunk_len, 6, {3}, d_model, n_heads, d_k, d_ff)

    # RETRO model: 6 layers, chunked cross-attention at layers 3 and 5
    model = RetroModel(tds.n_tokens, d_model, 6,
                       {3, 5},
                       chunk_len, n_heads, d_k, d_ff,
                       encoder=nearest_neighbor_encoder)

    # Move the model to the training device
    model = model.to(device)

    # Noam learning-rate schedule optimizer
    optimizer = Noam(model.parameters(), lr=1., d_model=d_model, warmup=2_000)

    # Trainer and sampler
    trainer = Trainer(device, model, train_dl, optimizer)
    sampler = Sampler(device, model, tds, chunk_len)

    # Prompt used to monitor sample quality during training
    prompt = '''Second Citizen:\nOne word, good citizens.\n\nFirst Citizen:'''

    # Register the model for checkpointing
    experiment.add_pytorch_models(model=model)

    # Start the experiment
    with experiment.start():
        # Train for 32 epochs
        for epoch in monit.loop(32):
            # One epoch of training
            trainer()
            # Print a new line after the progress bar
            tracker.new_line()
            # Show a sample from the prompt (newlines rendered visibly)
            logger.log([(prompt.replace('\n', '\\n\n'), Text.subtle),
                        (sampler.sample(prompt, 128).replace('\n', '\\n\n'), Text.none)])

#

Save models

#

# Run training when executed as a script
if __name__ == '__main__':
    train()

labml.ai