docs/neox/utils/text_dataset.html
from pathlib import PurePath, Path
from typing import Optional, List

import torch
import torch.utils.data
from labml import lab
from labml import monit
from labml.logger import inspect
from labml.utils.download import download_file

from labml_nn.neox.tokenizer import get_tokenizer
- `path` is the location of the text file
- `url` is the URL to download the file from
- `filter_subset` is the number of characters to keep. Use this during testing when trying out large datasets

Returns the text content
def load_text(path: PurePath, url: Optional[str] = None, *, filter_subset: Optional[int] = None):
    path = Path(path)
Download if it doesn't exist
    if not path.exists():
        if not url:
            raise FileNotFoundError(str(path))
        else:
            download_file(url, path)

    with monit.section("Load data"):
Load data
        with open(str(path), 'r') as f:
            text = f.read()
Filter
        if filter_subset:
            text = text[:filter_subset]
    return text
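As a rough illustration, the same logic can be sketched without the `labml` helpers. The name `load_text_simple` and the use of `urllib.request.urlretrieve` are assumptions made for this sketch, not part of the library:

```python
from pathlib import Path
from urllib.request import urlretrieve


def load_text_simple(path, url=None, *, filter_subset=None):
    """Hypothetical stand-in for load_text, without labml's monitoring/download helpers."""
    path = Path(path)
    # Download if the file doesn't exist
    if not path.exists():
        if not url:
            raise FileNotFoundError(str(path))
        urlretrieve(url, path)  # stand-in for labml's download_file
    # Load the data
    text = path.read_text()
    # Optionally keep only the first filter_subset characters
    if filter_subset:
        text = text[:filter_subset]
    return text
```

The `filter_subset` slice is what makes quick experiments on large corpora cheap: only the first few characters are tokenized.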
This is not optimized for very large datasets.
class NeoXDataset(torch.utils.data.Dataset):
- `tokens` is the list of token ids
- `seq_len` is the sequence length of a single training sample

    def __init__(self, tokens: List[int], seq_len: int):
        self.seq_len = seq_len
Number of samples
        n_samples = len(tokens) // seq_len
        self.n_samples = n_samples
Truncate
        tokens = tokens[:n_samples * seq_len + 1]
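The `+ 1` in the truncation matters: each sample's target is its input shifted by one token, so the final sample needs one token beyond the last full sequence. A toy example (the token ids here are made up for illustration):

```python
tokens = list(range(11))  # 11 toy token ids: 0 .. 10
seq_len = 3

# 11 // 3 == 3 full samples fit in the data
n_samples = len(tokens) // seq_len

# Keep 3 * 3 + 1 == 10 tokens; the extra token (id 9) is used
# only as the target of the last position of the last sample
tokens = tokens[:n_samples * seq_len + 1]
```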
Create a PyTorch tensor
        self.tokens = torch.tensor(tokens)
    def __len__(self):
        return self.n_samples
- `idx` is the index of the sample

Returns the input and the target
    def __getitem__(self, idx: int):
        offset = idx * self.seq_len
        return self.tokens[offset: offset + self.seq_len], self.tokens[offset + 1: offset + 1 + self.seq_len]


DATASETS = {
    'tiny_shakespeare': {
        'file': 'tiny_shakespeare.txt',
        'url': 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
    }
}
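The indexing can be checked on plain Python lists. `ToyDataset` below is a torch-free re-implementation written for this sketch, not part of the library:

```python
class ToyDataset:
    """Mirrors NeoXDataset's slicing, but on plain lists instead of torch tensors."""

    def __init__(self, tokens, seq_len):
        self.seq_len = seq_len
        self.n_samples = len(tokens) // seq_len
        # Keep one extra token so the target of the last sample exists
        self.tokens = tokens[:self.n_samples * seq_len + 1]

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        offset = idx * self.seq_len
        # Input and target are the same window, shifted by one token
        return (self.tokens[offset: offset + self.seq_len],
                self.tokens[offset + 1: offset + 1 + self.seq_len])


ds = ToyDataset(list(range(11)), seq_len=3)
```

Sample 0 is the input `[0, 1, 2]` with target `[1, 2, 3]`: every position is trained to predict the next token.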
- `seq_len` is the sequence length of a single training sample
- `dataset_name` is the name of the dataset

Returns the dataset
def get_training_data(seq_len: int = 32, dataset_name: str = 'tiny_shakespeare', truncate: int = -1):
    ds = DATASETS[dataset_name]
Load the content
    text = load_text(lab.get_data_path() / ds['file'], ds['url'])
Tokenize
    tokenizer = get_tokenizer()
    tokens = tokenizer.encode_batch([text])[0]

    if truncate > 0:
        token_ids = tokens.ids[:truncate * seq_len]
    else:
        token_ids = tokens.ids
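Note that `truncate` is measured in samples rather than tokens: `truncate * seq_len` token ids are kept. A quick check with made-up numbers:

```python
seq_len = 4
ids = list(range(100))  # stand-in for tokens.ids from the tokenizer
truncate = 5

# Same branch as above: keep truncate * seq_len ids when truncate > 0
token_ids = ids[:truncate * seq_len] if truncate > 0 else ids
# 5 samples of 4 tokens each -> 20 token ids kept
```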
    return NeoXDataset(token_ids, seq_len)
def _test():
    dataset = get_training_data()

    inspect(tokens=len(dataset.tokens))
if __name__ == '__main__':
    _test()