
RETRO training dataset

We pre-retrieve nearest neighbors from the key-value database and create the dataset to train the RETRO model.

import json
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import Dataset as PyTorchDataset

from labml import lab, monit
from labml_nn.helpers.datasets import TextFileDataset, TextDataset
from labml_nn.transformers.retro.database import RetroIndex

Build the dataset

  • chunk_len is the chunk length
  • chunks_per_sample is the number of chunks per training sample
  • skip_range is the maximum number of characters to skip between two samples. We skip a few characters between samples to make sure the samples aren't perfectly aligned with the chunks in the database. With the default arguments, each sample spans chunks_per_sample * chunk_len = 512 characters, so consecutive samples start 512 to 519 characters apart.

def build_dataset(chunk_len: int = 16, chunks_per_sample: int = 32, skip_range: int = 8):

Load the text file

    dataset = TextFileDataset(
        lab.get_data_path() / 'tiny_shakespeare.txt',
        list,
        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')

Training portion of it

    text = dataset.train

Load the index for retrieving neighbors

    index = RetroIndex()

The input sample offsets

    sample_offsets = []

Cursor for the text

    i = 0
    while i < len(text):

Skip a few characters to make sure it's not aligned with the neighbors

        skip = np.random.randint(skip_range)
        i += skip

Stop if we've reached the end of the text

        if i + chunks_per_sample * chunk_len > len(text):
            break

Collect the offset

        sample_offsets.append(i)

Increment the cursor

        i += chunks_per_sample * chunk_len
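To make the stride concrete, here is a tiny standalone sketch of how these offsets advance; the 2048-character text length is a hypothetical toy value, and the default arguments are assumed:

import numpy as np

chunk_len, chunks_per_sample, skip_range = 16, 32, 8
sample_len = chunks_per_sample * chunk_len  # 512 characters per sample

i, offsets = 0, []
while i < 2048:  # 2048 stands in for len(text)
    i += np.random.randint(skip_range)  # random 0-7 character skip
    if i + sample_len > 2048:
        break
    offsets.append(i)
    i += sample_len
print(offsets)  # e.g. [5, 519, 1037]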

List to collect the samples

    samples = []

Iterate through sample offsets

    for i in monit.iterate('Gather Neighbors', sample_offsets):

Get the sample including an extra character (for prediction)

        sample = text[i: i + chunks_per_sample * chunk_len + 1]

The input

        src = sample[:-1]

Break it into chunks

        chunks = [src[j: j + chunk_len] for j in range(0, len(src), chunk_len)]
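As an illustration, with hypothetical toy values (a 16-character input and chunk_len = 4), the comprehension above splits the input into equal-length chunks:

src = 'to be or not to '  # a toy 16-character input
chunks = [src[j: j + 4] for j in range(0, len(src), 4)]
print(chunks)  # ['to b', 'e or', ' not', ' to ']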

The chunk offsets

        chunk_offsets = [j + i for j in range(0, len(src), chunk_len)]

Retrieve nearest neighbors. The chunk offsets are passed so that the index can filter out neighbors that overlap the query chunks themselves.

        neighbor_offsets = index(chunks, chunk_offsets)

Get neighbor texts. The neighbor length is twice chunk_len: the retrieved chunk plus the continuation that follows it.

        neighbors = [[text[j: j + chunk_len * 2] for j in n_off] for n_off in neighbor_offsets]

Add to list of samples

        samples.append((sample[:-1], sample[1:], neighbors))

Save the samples in JSON. We don't need to use complex dataset storage mechanisms or pre-tokenize since our dataset is small.

    with open(str(lab.get_data_path() / 'retro_train_dataset.json'), 'w') as f:
        f.write(json.dumps(samples))
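As a quick sanity check, a minimal sketch (assuming build_dataset() has already been run with the default arguments) that loads the file back and inspects one sample:

import json

from labml import lab

with open(str(lab.get_data_path() / 'retro_train_dataset.json'), 'r') as f:
    samples = json.loads(f.read())

src, tgt, neighbors = samples[0]
print(len(src))   # 512 = chunks_per_sample * chunk_len
print(len(tgt))   # 512; tgt is src shifted by one character (next-character targets)
print(len(neighbors))  # 32 = chunks_per_sample, one list of neighbor texts per chunk
print(len(neighbors[0][0]))  # usually 32 = 2 * chunk_len characters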

Dataset

This is the PyTorch dataset that loads the dataset created by build_dataset.

class Dataset(PyTorchDataset):

  • file_path is the path of the saved JSON file
  • tds is the TextDataset

    def __init__(self, file_path: Path, tds: TextDataset):

        self.tds = tds

Load the samples

        with open(str(file_path), 'r') as f:
            self.samples = json.loads(f.read())

Number of samples

    def __len__(self):

        return len(self.samples)

Get a sample

    def __getitem__(self, idx: int):

Get the sample

        s = self.samples[idx]

Tokenize

        src = self.tds.text_to_i(s[0])
        tgt = self.tds.text_to_i(s[1])
        # Stack the neighbors into a tensor of shape [n_chunks, n_neighbors, 2 * chunk_len]
        neighbors = torch.stack([torch.stack([self.tds.text_to_i(n) for n in chunks]) for chunks in s[2]])

        return src, tgt, neighbors

if __name__ == '__main__':
    build_dataset()
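To put the pieces together, a minimal usage sketch (assuming the JSON file built above exists; batch_size = 4 is an arbitrary illustrative value). Batching with the default DataLoader collation works because every sample has the same fixed lengths:

from torch.utils.data import DataLoader

# The character-level text dataset provides the text_to_i tokenizer
tds = TextFileDataset(
    lab.get_data_path() / 'tiny_shakespeare.txt',
    list,
    url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')

# Load the pre-retrieved samples saved by build_dataset()
train_dataset = Dataset(lab.get_data_path() / 'retro_train_dataset.json', tds)
train_dl = DataLoader(train_dataset, batch_size=4, shuffle=True)

# src and tgt have shape [batch_size, chunks_per_sample * chunk_len];
# neighbors has shape [batch_size, n_chunks, n_neighbors, 2 * chunk_len]
src, tgt, neighbors = next(iter(train_dl))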
