Back to Annotated Deep Learning Paper Implementations

රෙට්රෝපුහුණුව

docs/si/transformers/retro/train.html

latest9.7 KB
Original Source

hometransformersretro

View code on Github

#

රෙට්රෝපුහුණුව

මෙය RETROසඳහා පුහුණු කේතය වේ.

16importtorch17fromtorchimportnn18fromtorch.utils.dataimportDataLoader,RandomSampler1920fromlabmlimportmonit,lab,tracker,experiment,logger21fromlabml.loggerimportText22fromlabml\_helpers.datasets.textimportTextFileDataset23fromlabml\_nn.optimizers.noamimportNoam24fromlabml\_nn.transformers.retroimportmodelasretro25fromlabml\_nn.transformers.retro.datasetimportDataset,RetroIndex26fromlabml\_nn.transformers.retro.modelimportRetroModel,NearestNeighborEncoder

#

නියැදිකරු

මෙමපන්තිය ආකෘතියකින් කෑදරකමින් සාම්පල.

29classSampler:

#

  • device ආකෘතියේ උපාංගය වේ
  • modelරෙට්රෝ ප්රකාරය
  • tds යනු පෙළ දත්ත සමුදාය (අසල්වැසියා කුට්ටි ලබා ගැනීමට භාවිතා කරයි)
  • chunk_len යනු කුට්ටියක දිග
36def\_\_init\_\_(self,device:torch.device,model:retro.RetroModel,tds:TextFileDataset,chunk\_len:int):

#

43self.chunk\_len=chunk\_len44self.tds=tds45self.model=model46self.device=device

#

රෙට්රෝ දර්ශකය

49self.index=RetroIndex()

#

ලබාදී ඇති කුට්ටියක ළඟම අසල්වැසියන් ලබා ගන්න

51defretrieve\_nearest\_neighbours(self,chunk:str):

#

ළඟමඅසල්වාසීන්ගේ හිලව් ලබා ගන්න

57neighbor\_offsets=self.index([chunk],None)

#

අසල්වැසියන්ලබා ගන්න (අසල්වැසියාගේ දිග සමාන chunk_len * 2 )

60text=self.tds.train61neighbors=[text[j:j+self.chunk\_len\*2]forjinneighbor\_offsets[0]]

#

64returnneighbors

#

ලබාදී ඇති විමසුමෙන් නියැදි පෙළ

66defsample(self,prompt:str,sample\_len:int):

#

ආසන්නතමඅසල්වැසියන් නූල් ලෙස ගබඩා කිරීම

72neighbors\_str=[]

#

නියැදිපෙළ

75sampled=''

#

නියැදි sample_len ටෝකන

78foriinrange(sample\_len):

#

අපදැනටමත් ලබා ගෙන ඇති ප්රමාණයට වඩා වැඩි නියැදි කුට්ටි තිබේ නම්, අසල්වැසියන් නැවත ලබා ගත යුතුය

81whilelen(neighbors\_str)\<len(prompt)//self.chunk\_len:

#

අපඅසල්වැසියන් ලබා නොගත් අවසාන කුට්ටිය ලබා ගන්න

83off=len(neighbors\_str)\*self.chunk\_len84chunk=prompt[off:off+self.chunk\_len]

#

ළඟමඅසල්වැසියන් ලබා ගන්න

86neighbors\_str.append(self.retrieve\_nearest\_neighbours(chunk))

#

ආදානයටෝකෙන්කරණය කරන්න

89src=self.tds.text\_to\_i(prompt)

#

ලබාගත් අසල්වැසියන් ටෝකීස් කරන්න

91neighbors=torch.stack([torch.stack([self.tds.text\_to\_i(n)forninchunk])forchunkinneighbors\_str])

#

ආකෘතියටසමාන උපාංගයකට ඒවා ගෙනයන්න

94src=src.to(self.device)95neighbors=neighbors.to(self.device)

#

ආදර්ශප්රතිදානය ලබා ගන්න

98res=self.model(src[None,:],neighbors[None,:,:,:])

#

කෑදරකමින්අවසාන ටෝකනය සාම්පලය

101token=res[0,-1,:].argmax(dim=-1)

#

නියැදිටෝකන් පෙළ විමසුමට සහ නියැදි පෙළට එක් කරන්න

104prompt+=self.tds.itos[token.item()]105sampled+=self.tds.itos[token.item()]

#

108returnsampled

#

රෙට්රොපුහුණුකරු

111classTrainer:

#

116def\_\_init\_\_(self,device:torch.device,model:retro.RetroModel,117dataloader:DataLoader,optimizer:torch.optim.Optimizer):

#

124self.optimizer=optimizer125self.device=device126self.dataloader=dataloader127self.model=model128self.loss\_func=nn.CrossEntropyLoss()

#

එපෝච්සඳහා ආකෘතිය පුහුණු කරන්න

130def\_\_call\_\_(self):

#

පුහුණුදත්ත හරහා නැවත භාවිතා කරන්න

136fori,(src,tgt,neighbors)inmonit.enum('Train',self.dataloader):

#

උපාංගයවෙත දත්ත ගෙනයන්න

138src,tgt,neighbors=src.to(self.device),tgt.to(self.device),neighbors.to(self.device)

#

ඉදිරිසාමාර්ථය

141res=self.model(src,neighbors)

#

අලාභයගණනය කරන්න

143loss=self.loss\_func(res.view(-1,res.shape[-1]),tgt.view(-1))

#

අනුක්රමිකඉවත්

146self.optimizer.zero\_grad()

#

පසුගාමීපාස්

148loss.backward()

#

ආකෘතියප්රශස්ත කරන්න

150self.optimizer.step()

#

පුහුණුසංඛ්යාලේඛන සුරකින්න සහ ගෝලීය පියවර කවුන්ටරය වැඩි කරන්න

153tracker.save({'loss.train':loss})154tracker.add\_global\_step(len(src))

#

කුඩාආකෘතියක් සාදන්න සහ පුහුණු කරන්න

157deftrain():

#

අත්හදාබැලීමක් සාදන්න

163experiment.create(name='retro\_small')

#

GPUඋපාංගය

166device=torch.device('cuda:0')

#

කුඩාෂේක්ස්පියර් දත්ත කට්ටලය පටවන්න

169tds=TextFileDataset(170lab.get\_data\_path()/'tiny\_shakespeare.txt',171list,172url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')

#

රෙට්රෝ දත්ත කට්ටලය පටවන්න

175train\_dataset=Dataset(lab.get\_data\_path()/'retro\_train\_dataset.json',tds)

#

දත්තකාරකය සාදන්න

178train\_dl=DataLoader(train\_dataset,179batch\_size=4,180sampler=RandomSampler(train\_dataset,replacement=True))

#

අධිපරාමිතීන්

183chunk\_len=16184d\_model=128185d\_ff=512186n\_heads=16187d\_k=16

#

ළඟමඅසල්වැසියාගේ එන්කෝඩරය සාදන්න

190nearest\_neighbor\_encoder=NearestNeighborEncoder(chunk\_len,6,{3},d\_model,n\_heads,d\_k,d\_ff)

#

ආකෘතියසාදන්න

192model=RetroModel(tds.n\_tokens,d\_model,6,193{3,5},194chunk\_len,n\_heads,d\_k,d\_ff,195encoder=nearest\_neighbor\_encoder)

#

උපාංගයවෙත ආකෘතිය ගෙනයන්න

197model=model.to(device)

#

ප්රශස්තකරණයසාදන්න

199optimizer=Noam(model.parameters(),lr=1.,d\_model=d\_model,warmup=2\_000)

#

සාදන්න Trainer

201trainer=Trainer(device,model,train\_dl,optimizer)

#

සාදන්න Sampler

203sampler=Sampler(device,model,tds,chunk\_len)

#

205prompt='''Second Citizen:\nOne word, good citizens.\n\nFirst Citizen:'''

#

ඉතිරිකිරීම සහ පැටවීම සඳහා ආකෘති සකසන්න

208experiment.add\_pytorch\_models(model=model)

#

අත්හදාබැලීම ආරම්භ කරන්න

211withexperiment.start():

#

32 Epochs සඳහා දුම්රිය

213forepochinmonit.loop(32):

#

දුම්රිය

215trainer()

#

නවරේඛාවක් මුද්රණය කරන්න

217tracker.new\_line()

#

වෙතින්නියැදිය prompt

219logger.log([(prompt.replace('\n','\\n\n'),Text.subtle),220(sampler.sample(prompt,128).replace('\n','\\n\n'),Text.none)])

#

ආකෘතිසුරකින්න

222experiment.save\_checkpoint()

#

226if\_\_name\_\_=='\_\_main\_\_':227train()

Trending Research Paperslabml.ai