This is the training code for RETRO.
import torch

from labml import monit, lab, tracker, experiment, logger
from labml.logger import Text
from labml_nn.helpers.datasets import TextFileDataset
from labml_nn.optimizers.noam import Noam
from labml_nn.transformers.retro import model as retro
from labml_nn.transformers.retro.dataset import Dataset, RetroIndex
from labml_nn.transformers.retro.model import RetroModel, NearestNeighborEncoder
from torch import nn
from torch.utils.data import DataLoader, RandomSampler
This class greedily samples from a model.
class Sampler:
- device is the device of the model
- model is the RETRO model
- tds is the text dataset (used to get neighbor chunks)
- chunk_len is the length of a chunk

    def __init__(self, device: torch.device, model: retro.RetroModel, tds: TextFileDataset, chunk_len: int):
        self.chunk_len = chunk_len
        self.tds = tds
        self.model = model
        self.device = device
        self.index = RetroIndex()
    def retrieve_nearest_neighbours(self, chunk: str):
Retrieve the offsets of the nearest neighbors
        neighbor_offsets = self.index([chunk], None)
Get the neighbors (with neighbor length equal to chunk_len * 2)
        text = self.tds.train
        neighbors = [text[j: j + self.chunk_len * 2] for j in neighbor_offsets[0]]
        return neighbors
    def sample(self, prompt: str, sample_len: int):
To store nearest neighbors as strings
        neighbors_str = []
Sampled text
        sampled = ''
Sample sample_len tokens
        for i in range(sample_len):
We need to retrieve neighbors if there are more complete chunks in the prompt than we have already retrieved neighbors for
            while len(neighbors_str) < len(prompt) // self.chunk_len:
Get the last chunk for which we haven't retrieved neighbors
                off = len(neighbors_str) * self.chunk_len
                chunk = prompt[off: off + self.chunk_len]
Retrieve nearest neighbors
                neighbors_str.append(self.retrieve_nearest_neighbours(chunk))
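The retrieval loop above only fetches neighbors for complete chunks of the prompt. A minimal stand-alone sketch of that bookkeeping (the `chunk_len = 4` and prompt values are illustrative, not from this code):

```python
chunk_len = 4
prompt = 'hello worl'  # 10 characters, so 2 complete chunks
retrieved = []         # stands in for neighbors_str

# Handle each complete chunk we haven't retrieved neighbors for yet
while len(retrieved) < len(prompt) // chunk_len:
    off = len(retrieved) * chunk_len
    retrieved.append(prompt[off:off + chunk_len])

print(retrieved)  # ['hell', 'o wo'] - the trailing 'rl' is not a full chunk
```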
Tokenize the input
            src = self.tds.text_to_i(prompt)
Tokenize the retrieved neighbors
            neighbors = torch.stack([torch.stack([self.tds.text_to_i(n) for n in chunk]) for chunk in neighbors_str])
Move them to the same device as the model
            src = src.to(self.device)
            neighbors = neighbors.to(self.device)
Get model output
            res = self.model(src[None, :], neighbors[None, :, :, :])
Greedily sample the last token
            token = res[0, -1, :].argmax(dim=-1)
Add the sampled token text to the prompt and sample text
            prompt += self.tds.itos[token.item()]
            sampled += self.tds.itos[token.item()]
        return sampled
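Greedy sampling means taking the argmax over the vocabulary at the last position, as `res[0, -1, :].argmax(dim=-1)` does above. A toy illustration with made-up logits (the shapes and values are illustrative only, not model output):

```python
import torch

# Fake model output: batch of 1, sequence length 2, vocabulary of 5 tokens
res = torch.tensor([[[0.5, 0.1, 0.1, 0.2, 0.1],
                     [0.1, 0.1, 0.1, 0.1, 0.6]]])

# Pick the most likely token at the last position of the first sequence
token = res[0, -1, :].argmax(dim=-1)
print(token.item())  # 4
```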
class Trainer:
- device is the device of the model
- model is the RETRO model
- dataloader is the dataloader for the dataset with pre-retrieved neighbors
- optimizer is the optimizer

    def __init__(self, device: torch.device, model: retro.RetroModel,
                 dataloader: DataLoader, optimizer: torch.optim.Optimizer):
        self.optimizer = optimizer
        self.device = device
        self.dataloader = dataloader
        self.model = model
        self.loss_func = nn.CrossEntropyLoss()
    def __call__(self):
Iterate through training data
        for i, (src, tgt, neighbors) in monit.enum('Train', self.dataloader):
Move data to the device
            src, tgt, neighbors = src.to(self.device), tgt.to(self.device), neighbors.to(self.device)
Forward pass
            res = self.model(src, neighbors)
Calculate loss
            loss = self.loss_func(res.view(-1, res.shape[-1]), tgt.view(-1))
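`nn.CrossEntropyLoss` expects 2-D logits of shape `(N, C)` and 1-D targets of shape `(N,)`, so the batch and sequence dimensions are flattened together before computing the loss. A minimal sketch of that reshape (the tensor sizes are illustrative):

```python
import torch
from torch import nn

batch, seq_len, n_tokens = 4, 8, 65          # illustrative sizes
res = torch.randn(batch, seq_len, n_tokens)  # logits from the model
tgt = torch.randint(0, n_tokens, (batch, seq_len))

loss_func = nn.CrossEntropyLoss()
# Flatten to (batch * seq_len, n_tokens) logits and (batch * seq_len,) targets
loss = loss_func(res.view(-1, res.shape[-1]), tgt.view(-1))
print(loss.dim())  # 0 - the loss is a scalar
```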
Clear the gradients
            self.optimizer.zero_grad()
Backward pass
            loss.backward()
Optimize the model
            self.optimizer.step()
Save training statistics and increment the global step counter
            tracker.save({'loss.train': loss})
            tracker.add_global_step(len(src))
def train():
Create an experiment
    experiment.create(name='retro_small')
GPU device
    device = torch.device('cuda:0')
Load Tiny Shakespeare dataset
    tds = TextFileDataset(
        lab.get_data_path() / 'tiny_shakespeare.txt',
        list,
        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
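Passing the built-in `list` as the tokenizer makes this a character-level dataset, since calling `list` on a string splits it into its characters:

```python
# `list` acts as a character-level tokenizer
tokens = list('First Citizen:')
print(tokens[:5])  # ['F', 'i', 'r', 's', 't']
```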
Load Retro dataset
    train_dataset = Dataset(lab.get_data_path() / 'retro_train_dataset.json', tds)
Create dataloader
    train_dl = DataLoader(train_dataset,
                          batch_size=4,
                          sampler=RandomSampler(train_dataset, replacement=True))
Hyper-parameters
    chunk_len = 16
    d_model = 128
    d_ff = 512
    n_heads = 16
    d_k = 16
Create the nearest neighbor encoder
    nearest_neighbor_encoder = NearestNeighborEncoder(chunk_len, 6, {3}, d_model, n_heads, d_k, d_ff)
Create the model
    model = RetroModel(tds.n_tokens, d_model, 6,
                       {3, 5},
                       chunk_len, n_heads, d_k, d_ff,
                       encoder=nearest_neighbor_encoder)
Move the model to the device
    model = model.to(device)
Create the optimizer
    optimizer = Noam(model.parameters(), lr=1., d_model=d_model, warmup=2_000)
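The Noam schedule (from "Attention Is All You Need") warms the learning rate up and then decays it with the inverse square root of the step: lr * d_model^-0.5 * min(step^-0.5, step * warmup^-1.5). A stand-alone sketch of that formula (independent of the labml implementation, which may differ in details):

```python
def noam_lr(step: int, d_model: int = 128, warmup: int = 2_000, lr: float = 1.0) -> float:
    """Warmup-then-inverse-square-root learning-rate schedule."""
    return lr * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

# The rate rises during warmup, peaks around `warmup` steps, then decays
print(noam_lr(1_000) < noam_lr(2_000) > noam_lr(8_000))  # True
```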
Create the Trainer
    trainer = Trainer(device, model, train_dl, optimizer)
Create the Sampler
    sampler = Sampler(device, model, tds, chunk_len)
    prompt = '''Second Citizen:\nOne word, good citizens.\n\nFirst Citizen:'''
Set models for saving and loading
    experiment.add_pytorch_models(model=model)
Start the experiment
    with experiment.start():
Train for 32 epochs
        for epoch in monit.loop(32):
Train
            trainer()
Print a new line
            tracker.new_line()
Sample from the prompt
            logger.log([(prompt.replace('\n', '\\n\n'), Text.subtle),
                        (sampler.sample(prompt, 128).replace('\n', '\\n\n'), Text.none)])
Save models
            experiment.save_checkpoint()

if __name__ == '__main__':
    train()