We pre-retrieve nearest neighbors from the key-value database and create the dataset to train the RETRO model.
import json
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import Dataset as PyTorchDataset

from labml import lab, monit
from labml_nn.helpers.datasets import TextFileDataset, TextDataset
from labml_nn.transformers.retro.database import RetroIndex
- chunk_len is the chunk length
- chunks_per_sample is the number of chunks per training sample
- skip_range is the maximum number of characters to skip between two samples. We skip a few characters between samples to make sure the samples aren't aligned perfectly with the chunks in the database

def build_dataset(chunk_len: int = 16, chunks_per_sample: int = 32, skip_range: int = 8):
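With these defaults, the span of text covered by one training sample follows directly from the parameters (a quick standalone check; the numbers just mirror the defaults above):

```python
# Default parameters from build_dataset above
chunk_len = 16
chunks_per_sample = 32

# Characters covered by one training sample; one extra character is read
# later as the prediction target
sample_len = chunk_len * chunks_per_sample
print(sample_len)  # 512
```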
Load the text file
    dataset = TextFileDataset(
        lab.get_data_path() / 'tiny_shakespeare.txt',
        list,
        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
Training portion of it
    text = dataset.train
Load the index for retrieving neighbors
    index = RetroIndex()
The input sample offsets
    sample_offsets = []
Cursor for the text
    i = 0
    while i < len(text):
Skip a few characters to make sure it's not aligned with the neighbors
        skip = np.random.randint(skip_range)
        i += skip
Stop if we've reached the end of the text
        if i + chunks_per_sample * chunk_len > len(text):
            break
Collect the offset
        sample_offsets.append(i)
Increment the cursor
        i += chunks_per_sample * chunk_len
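The offset loop above can be sketched on toy numbers (a standalone sketch; the text length and parameter values here are made up for illustration):

```python
import numpy as np

# Made-up miniature parameters
chunk_len, chunks_per_sample, skip_range = 4, 2, 3
text_len = 40
sample_len = chunk_len * chunks_per_sample  # 8 characters per sample

offsets, i = [], 0
while i < text_len:
    # Random skip de-aligns sample starts from chunk boundaries in the database
    i += np.random.randint(skip_range)
    if i + sample_len > text_len:
        break
    offsets.append(i)
    i += sample_len

# Consecutive offsets are at least sample_len apart and all samples fit in the text
print(offsets)
```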
List of samples
    samples = []
Iterate through sample offsets
    for i in monit.iterate('Gather Neighbors', sample_offsets):
Get the sample including an extra character (for prediction)
        sample = text[i: i + chunks_per_sample * chunk_len + 1]
The input
        src = sample[:-1]
Break it into chunks
        chunks = [src[j:j + chunk_len] for j in range(0, len(src), chunk_len)]
The chunk offsets
        chunk_offsets = [j + i for j in range(0, len(src), chunk_len)]
Retrieve nearest neighbors
        neighbor_offsets = index(chunks, chunk_offsets)
Get the neighbor texts. Each neighbor is twice chunk_len long, so it covers the matched chunk plus its continuation
        neighbors = [[text[j:j + chunk_len * 2] for j in n_off] for n_off in neighbor_offsets]
Add to list of samples
        samples.append((sample[:-1], sample[1:], neighbors))
Save the samples in JSON. We don't need to use complex dataset storage mechanisms or pre-tokenize since our dataset is small.
    with open(str(lab.get_data_path() / 'retro_train_dataset.json'), 'w') as f:
        f.write(json.dumps(samples))
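The stored layout of one sample can be sketched with plain json (a standalone sketch; the strings and chunk sizes are made up, only the (src, tgt, neighbors) structure mirrors what build_dataset writes):

```python
import json

# A made-up miniature sample in the layout build_dataset stores:
# (src, tgt, neighbors), where neighbors[c] holds the retrieved texts
# for chunk c, each 2 * chunk_len characters long (here chunk_len = 4)
src = "abcdefgh"
tgt = "bcdefghi"  # src shifted by one character
neighbors = [["abcdWXYZ", "abceWXYZ"],   # neighbors of chunk 'abcd'
             ["efghWXYZ", "efgiWXYZ"]]   # neighbors of chunk 'efgh'

samples = [(src, tgt, neighbors)]
# A JSON round-trip turns the tuple into a list, which is what Dataset reads back
restored = json.loads(json.dumps(samples))
print(restored[0][0])  # 'abcdefgh'
```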
This is the PyTorch dataset that loads the dataset created by build_dataset.
class Dataset(PyTorchDataset):
- file_path is the path of the saved JSON file
- tds is the TextDataset

    def __init__(self, file_path: Path, tds: TextDataset):
        self.tds = tds
Load the samples
        with open(str(file_path), 'r') as f:
            self.samples = json.loads(f.read())
Number of samples
    def __len__(self):
        return len(self.samples)
Get a sample
    def __getitem__(self, idx: int):
Get the sample
        s = self.samples[idx]
Tokenize
        src = self.tds.text_to_i(s[0])
        tgt = self.tds.text_to_i(s[1])
        neighbors = torch.stack([torch.stack([self.tds.text_to_i(n) for n in chunks]) for chunks in s[2]])
        return src, tgt, neighbors
if __name__ == '__main__':
    build_dataset()
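The tensor shapes Dataset produces can be checked with a minimal sketch, assuming a toy character-level tokenizer as a stand-in for TextDataset's text_to_i (ToyTokenizer, the sample strings, and the file path here are all hypothetical):

```python
import json
import tempfile
from pathlib import Path

import torch


class ToyTokenizer:
    """Hypothetical stand-in for TextDataset's text_to_i, for illustration only."""
    def __init__(self, text: str):
        self.stoi = {c: i for i, c in enumerate(sorted(set(text)))}

    def text_to_i(self, text: str) -> torch.Tensor:
        return torch.tensor([self.stoi[c] for c in text], dtype=torch.long)


# A made-up saved dataset: one sample, 2 chunks of length 4, 2 neighbors per
# chunk, each neighbor 2 * chunk_len = 8 characters long
sample = ("abcdefgh", "bcdefghi",
          [["abcdabcd", "bcdabcda"], ["efghefgh", "fghefghe"]])
with tempfile.TemporaryDirectory() as d:
    path = Path(d) / 'retro_train_dataset.json'
    path.write_text(json.dumps([sample]))

    # Mirrors Dataset.__getitem__ on the reloaded JSON
    tok = ToyTokenizer("abcdefghi")
    s = json.loads(path.read_text())[0]
    src = tok.text_to_i(s[0])
    tgt = tok.text_to_i(s[1])
    neighbors = torch.stack([torch.stack([tok.text_to_i(n) for n in chunks])
                             for chunks in s[2]])

# Shapes: [8], [8], [2, 2, 8] = (chunks, neighbors per chunk, 2 * chunk_len)
print(src.shape, tgt.shape, neighbors.shape)
```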