docs/si/transformers/retro/train.html
මෙය RETROසඳහා පුහුණු කේතය වේ.
16importtorch17fromtorchimportnn18fromtorch.utils.dataimportDataLoader,RandomSampler1920fromlabmlimportmonit,lab,tracker,experiment,logger21fromlabml.loggerimportText22fromlabml\_helpers.datasets.textimportTextFileDataset23fromlabml\_nn.optimizers.noamimportNoam24fromlabml\_nn.transformers.retroimportmodelasretro25fromlabml\_nn.transformers.retro.datasetimportDataset,RetroIndex26fromlabml\_nn.transformers.retro.modelimportRetroModel,NearestNeighborEncoder
මෙමපන්තිය ආකෘතියකින් කෑදරකමින් සාම්පල.
29classSampler:
device ආකෘතියේ උපාංගය වේmodelරෙට්රෝ ප්රකාරයtds යනු පෙළ දත්ත සමුදාය (අසල්වැසියා කුට්ටි ලබා ගැනීමට භාවිතා කරයි)chunk_len යනු කුට්ටියක දිග36def\_\_init\_\_(self,device:torch.device,model:retro.RetroModel,tds:TextFileDataset,chunk\_len:int):
43self.chunk\_len=chunk\_len44self.tds=tds45self.model=model46self.device=device
49self.index=RetroIndex()
51defretrieve\_nearest\_neighbours(self,chunk:str):
ළඟමඅසල්වාසීන්ගේ හිලව් ලබා ගන්න
57neighbor\_offsets=self.index([chunk],None)
අසල්වැසියන්ලබා ගන්න (අසල්වැසියාගේ දිග සමාන chunk_len * 2 )
60text=self.tds.train61neighbors=[text[j:j+self.chunk\_len\*2]forjinneighbor\_offsets[0]]
64returnneighbors
66defsample(self,prompt:str,sample\_len:int):
ආසන්නතමඅසල්වැසියන් නූල් ලෙස ගබඩා කිරීම
72neighbors\_str=[]
නියැදිපෙළ
75sampled=''
නියැදි sample_len ටෝකන
78foriinrange(sample\_len):
අපදැනටමත් ලබා ගෙන ඇති ප්රමාණයට වඩා වැඩි නියැදි කුට්ටි තිබේ නම්, අසල්වැසියන් නැවත ලබා ගත යුතුය
81whilelen(neighbors\_str)\<len(prompt)//self.chunk\_len:
අපඅසල්වැසියන් ලබා නොගත් අවසාන කුට්ටිය ලබා ගන්න
83off=len(neighbors\_str)\*self.chunk\_len84chunk=prompt[off:off+self.chunk\_len]
ළඟමඅසල්වැසියන් ලබා ගන්න
86neighbors\_str.append(self.retrieve\_nearest\_neighbours(chunk))
ආදානයටෝකෙන්කරණය කරන්න
89src=self.tds.text\_to\_i(prompt)
ලබාගත් අසල්වැසියන් ටෝකීස් කරන්න
91neighbors=torch.stack([torch.stack([self.tds.text\_to\_i(n)forninchunk])forchunkinneighbors\_str])
ආකෘතියටසමාන උපාංගයකට ඒවා ගෙනයන්න
94src=src.to(self.device)95neighbors=neighbors.to(self.device)
ආදර්ශප්රතිදානය ලබා ගන්න
98res=self.model(src[None,:],neighbors[None,:,:,:])
කෑදරකමින්අවසාන ටෝකනය සාම්පලය
101token=res[0,-1,:].argmax(dim=-1)
නියැදිටෝකන් පෙළ විමසුමට සහ නියැදි පෙළට එක් කරන්න
104prompt+=self.tds.itos[token.item()]105sampled+=self.tds.itos[token.item()]
108returnsampled
111classTrainer:
device ආකෘතියේ උපාංගය වේmodelරෙට්රෝ ප්රකාරයdataloaderපෙර ලබා ගත් අසල්වැසියන් සමඟ දත්ත කට්ටලය සඳහා දත්ත සමුදාය වේoptimizer ප්රශස්තකරණය වේ116def\_\_init\_\_(self,device:torch.device,model:retro.RetroModel,117dataloader:DataLoader,optimizer:torch.optim.Optimizer):
124self.optimizer=optimizer125self.device=device126self.dataloader=dataloader127self.model=model128self.loss\_func=nn.CrossEntropyLoss()
130def\_\_call\_\_(self):
පුහුණුදත්ත හරහා නැවත භාවිතා කරන්න
136fori,(src,tgt,neighbors)inmonit.enum('Train',self.dataloader):
උපාංගයවෙත දත්ත ගෙනයන්න
138src,tgt,neighbors=src.to(self.device),tgt.to(self.device),neighbors.to(self.device)
ඉදිරිසාමාර්ථය
141res=self.model(src,neighbors)
අලාභයගණනය කරන්න
143loss=self.loss\_func(res.view(-1,res.shape[-1]),tgt.view(-1))
අනුක්රමිකඉවත්
146self.optimizer.zero\_grad()
පසුගාමීපාස්
148loss.backward()
ආකෘතියප්රශස්ත කරන්න
150self.optimizer.step()
පුහුණුසංඛ්යාලේඛන සුරකින්න සහ ගෝලීය පියවර කවුන්ටරය වැඩි කරන්න
153tracker.save({'loss.train':loss})154tracker.add\_global\_step(len(src))
157deftrain():
අත්හදාබැලීමක් සාදන්න
163experiment.create(name='retro\_small')
GPUඋපාංගය
166device=torch.device('cuda:0')
කුඩාෂේක්ස්පියර් දත්ත කට්ටලය පටවන්න
169tds=TextFileDataset(170lab.get\_data\_path()/'tiny\_shakespeare.txt',171list,172url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
රෙට්රෝ දත්ත කට්ටලය පටවන්න
175train\_dataset=Dataset(lab.get\_data\_path()/'retro\_train\_dataset.json',tds)
දත්තකාරකය සාදන්න
178train\_dl=DataLoader(train\_dataset,179batch\_size=4,180sampler=RandomSampler(train\_dataset,replacement=True))
අධිපරාමිතීන්
183chunk\_len=16184d\_model=128185d\_ff=512186n\_heads=16187d\_k=16
ළඟමඅසල්වැසියාගේ එන්කෝඩරය සාදන්න
190nearest\_neighbor\_encoder=NearestNeighborEncoder(chunk\_len,6,{3},d\_model,n\_heads,d\_k,d\_ff)
ආකෘතියසාදන්න
192model=RetroModel(tds.n\_tokens,d\_model,6,193{3,5},194chunk\_len,n\_heads,d\_k,d\_ff,195encoder=nearest\_neighbor\_encoder)
උපාංගයවෙත ආකෘතිය ගෙනයන්න
197model=model.to(device)
ප්රශස්තකරණයසාදන්න
199optimizer=Noam(model.parameters(),lr=1.,d\_model=d\_model,warmup=2\_000)
සාදන්න Trainer
201trainer=Trainer(device,model,train\_dl,optimizer)
සාදන්න Sampler
203sampler=Sampler(device,model,tds,chunk\_len)
205prompt='''Second Citizen:\nOne word, good citizens.\n\nFirst Citizen:'''
ඉතිරිකිරීම සහ පැටවීම සඳහා ආකෘති සකසන්න
208experiment.add\_pytorch\_models(model=model)
අත්හදාබැලීම ආරම්භ කරන්න
211withexperiment.start():
32 Epochs සඳහා දුම්රිය
213forepochinmonit.loop(32):
දුම්රිය
215trainer()
නවරේඛාවක් මුද්රණය කරන්න
217tracker.new\_line()
වෙතින්නියැදිය prompt
219logger.log([(prompt.replace('\n','\\n\n'),Text.subtle),220(sampler.sample(prompt,128).replace('\n','\\n\n'),Text.none)])
ආකෘතිසුරකින්න
222experiment.save\_checkpoint()
226if\_\_name\_\_=='\_\_main\_\_':227train()