docs/helpers/datasets.html
import random
from pathlib import PurePath, Path
from typing import List, Callable, Dict, Optional

from torchvision import datasets, transforms

import torch
from labml import lab
from labml import monit
from labml.configs import BaseConfigs
from labml.configs import aggregate, option
from labml.utils.download import download_file
from torch.utils.data import DataLoader
from torch.utils.data import IterableDataset, Dataset
def _mnist_dataset(is_train, transform):
    return datasets.MNIST(str(lab.get_data_path()),
                          train=is_train,
                          download=True,
                          transform=transform)
Configurable MNIST data set.
Arguments:
    dataset_name (str): name of the data set, 'MNIST'
    dataset_transforms (torchvision.transforms.Compose): image transformations
    train_dataset (torchvision.datasets.MNIST): training dataset
    valid_dataset (torchvision.datasets.MNIST): validation dataset
    train_loader (torch.utils.data.DataLoader): training data loader
    valid_loader (torch.utils.data.DataLoader): validation data loader
    train_batch_size (int): training batch size
    valid_batch_size (int): validation batch size
    train_loader_shuffle (bool): whether to shuffle the training data
    valid_loader_shuffle (bool): whether to shuffle the validation data
class MNISTConfigs(BaseConfigs):
    dataset_name: str = 'MNIST'
    dataset_transforms: transforms.Compose
    train_dataset: datasets.MNIST
    valid_dataset: datasets.MNIST

    train_loader: DataLoader
    valid_loader: DataLoader

    train_batch_size: int = 64
    valid_batch_size: int = 1024

    train_loader_shuffle: bool = True
    valid_loader_shuffle: bool = False
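The typed attributes without assigned values (`dataset_transforms`, `train_dataset`, the loaders) are computed lazily by the `@option` calculators registered below. A short sketch of the typical consumption pattern (the `Configs` subclass and its `model` attribute here are hypothetical, not part of this module): an experiment inherits the dataset configs and adds its own components.

class Configs(MNISTConfigs):
    # Hypothetical extension: experiment-specific components live alongside
    # the inherited dataset and loader options
    model: torch.nn.Module
    epochs: int = 10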
@option(MNISTConfigs.dataset_transforms)
def mnist_transforms():
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])


@option(MNISTConfigs.train_dataset)
def mnist_train_dataset(c: MNISTConfigs):
    return _mnist_dataset(True, c.dataset_transforms)


@option(MNISTConfigs.valid_dataset)
def mnist_valid_dataset(c: MNISTConfigs):
    return _mnist_dataset(False, c.dataset_transforms)


@option(MNISTConfigs.train_loader)
def mnist_train_loader(c: MNISTConfigs):
    return DataLoader(c.train_dataset,
                      batch_size=c.train_batch_size,
                      shuffle=c.train_loader_shuffle)


@option(MNISTConfigs.valid_loader)
def mnist_valid_loader(c: MNISTConfigs):
    return DataLoader(c.valid_dataset,
                      batch_size=c.valid_batch_size,
                      shuffle=c.valid_loader_shuffle)


aggregate(MNISTConfigs.dataset_name, 'MNIST',
          (MNISTConfigs.dataset_transforms, 'mnist_transforms'),
          (MNISTConfigs.train_dataset, 'mnist_train_dataset'),
          (MNISTConfigs.valid_dataset, 'mnist_valid_dataset'),
          (MNISTConfigs.train_loader, 'mnist_train_loader'),
          (MNISTConfigs.valid_loader, 'mnist_valid_loader'))
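`aggregate` ties `dataset_name = 'MNIST'` to all five option functions, so selecting the name is enough to get matching transforms, datasets, and loaders. A hedged usage sketch, assuming labml's usual `experiment` API; the experiment name and the batch-size override are illustrative:

from labml import experiment

conf = MNISTConfigs()
experiment.create(name='mnist_example')              # hypothetical experiment name
experiment.configs(conf, {'train_batch_size': 128})  # override a plain default
with experiment.start():
    # Accessing the attribute resolves it through `mnist_train_loader`
    for data, target in conf.train_loader:
        break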
def _cifar_dataset(is_train, transform):
    return datasets.CIFAR10(str(lab.get_data_path()),
                            train=is_train,
                            download=True,
                            transform=transform)

Configurable CIFAR 10 data set.

Arguments:
    dataset_name (str): name of the data set, 'CIFAR10'
    dataset_transforms (torchvision.transforms.Compose): image transformations
    train_dataset (torchvision.datasets.CIFAR10): training dataset
    valid_dataset (torchvision.datasets.CIFAR10): validation dataset
    train_loader (torch.utils.data.DataLoader): training data loader
    valid_loader (torch.utils.data.DataLoader): validation data loader
    train_batch_size (int): training batch size
    valid_batch_size (int): validation batch size
    train_loader_shuffle (bool): whether to shuffle the training data
    valid_loader_shuffle (bool): whether to shuffle the validation data

class CIFAR10Configs(BaseConfigs):
    dataset_name: str = 'CIFAR10'
    dataset_transforms: transforms.Compose
    train_dataset: datasets.CIFAR10
    valid_dataset: datasets.CIFAR10

    train_loader: DataLoader
    valid_loader: DataLoader

    train_batch_size: int = 64
    valid_batch_size: int = 1024

    train_loader_shuffle: bool = True
    valid_loader_shuffle: bool = False
@option(CIFAR10Configs.dataset_transforms)
def cifar10_transforms():
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])


@option(CIFAR10Configs.train_dataset)
def cifar10_train_dataset(c: CIFAR10Configs):
    return _cifar_dataset(True, c.dataset_transforms)


@option(CIFAR10Configs.valid_dataset)
def cifar10_valid_dataset(c: CIFAR10Configs):
    return _cifar_dataset(False, c.dataset_transforms)


@option(CIFAR10Configs.train_loader)
def cifar10_train_loader(c: CIFAR10Configs):
    return DataLoader(c.train_dataset,
                      batch_size=c.train_batch_size,
                      shuffle=c.train_loader_shuffle)


@option(CIFAR10Configs.valid_loader)
def cifar10_valid_loader(c: CIFAR10Configs):
    return DataLoader(c.valid_dataset,
                      batch_size=c.valid_batch_size,
                      shuffle=c.valid_loader_shuffle)


aggregate(CIFAR10Configs.dataset_name, 'CIFAR10',
          (CIFAR10Configs.dataset_transforms, 'cifar10_transforms'),
          (CIFAR10Configs.train_dataset, 'cifar10_train_dataset'),
          (CIFAR10Configs.valid_dataset, 'cifar10_valid_dataset'),
          (CIFAR10Configs.train_loader, 'cifar10_train_loader'),
          (CIFAR10Configs.valid_loader, 'cifar10_valid_loader'))
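Because every default is registered through `@option`, an experiment can register an alternative and pick it by name. A sketch using only the APIs shown above; `cifar10_augmented_transforms` is a hypothetical option name:

@option(CIFAR10Configs.dataset_transforms)
def cifar10_augmented_transforms():
    # Standard CIFAR-10 augmentation in front of the same normalization
    # used by the default `cifar10_transforms` option
    return transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

Presumably it can then be selected per experiment, e.g. `experiment.configs(conf, {'dataset_transforms': 'cifar10_augmented_transforms'})`.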
class TextDataset:
    itos: List[str]
    stoi: Dict[str, int]
    n_tokens: int
    train: str
    valid: str
    standard_tokens: List[str] = []

    @staticmethod
    def load(path: PurePath):
        with open(str(path), 'r') as f:
            return f.read()

    def __init__(self, path: PurePath, tokenizer: Callable, train: str, valid: str, test: str, *,
                 n_tokens: Optional[int] = None,
                 stoi: Optional[Dict[str, int]] = None,
                 itos: Optional[List[str]] = None):
        self.test = test
        self.valid = valid
        self.train = train
        self.tokenizer = tokenizer
        self.path = path

        if n_tokens or stoi or itos:
            # Use a pre-built vocabulary; all three must be given together
            assert stoi and itos and n_tokens
            self.n_tokens = n_tokens
            self.stoi = stoi
            self.itos = itos
        else:
            # Build the vocabulary from the training and validation text
            self.n_tokens = len(self.standard_tokens)
            self.stoi = {t: i for i, t in enumerate(self.standard_tokens)}

            with monit.section("Tokenize"):
                tokens = self.tokenizer(self.train) + self.tokenizer(self.valid)
                tokens = sorted(list(set(tokens)))

            for t in monit.iterate("Build vocabulary", tokens):
                self.stoi[t] = self.n_tokens
                self.n_tokens += 1

            self.itos = [''] * self.n_tokens
            for t, n in self.stoi.items():
                self.itos[n] = t

    def text_to_i(self, text: str) -> torch.Tensor:
        # Map text to a tensor of token ids, dropping out-of-vocabulary tokens
        tokens = self.tokenizer(text)
        return torch.tensor([self.stoi[s] for s in tokens if s in self.stoi], dtype=torch.long)

    def __repr__(self):
        return f'{len(self.train) / 1_000_000 :,.2f}M, {len(self.valid) / 1_000_000 :,.2f}M - {str(self.path)}'


class SequentialDataLoader(IterableDataset):
    def __init__(self, *, text: str, dataset: TextDataset,
                 batch_size: int, seq_len: int):
        self.seq_len = seq_len
        data = dataset.text_to_i(text)
        # Trim to a multiple of the batch size and reshape to [n_batch, batch_size]
        n_batch = data.shape[0] // batch_size
        data = data.narrow(0, 0, n_batch * batch_size)
        data = data.view(batch_size, -1).t().contiguous()
        self.data = data

    def __len__(self):
        return self.data.shape[0] // self.seq_len

    def __iter__(self):
        self.idx = 0
        return self

    def __next__(self):
        if self.idx >= self.data.shape[0] - 1:
            raise StopIteration()

        seq_len = min(self.seq_len, self.data.shape[0] - 1 - self.idx)
        i = self.idx + seq_len
        data = self.data[self.idx: i]
        # Targets are the inputs shifted forward by one step
        target = self.data[self.idx + 1: i + 1]
        self.idx = i
        return data, target

    def __getitem__(self, idx):
        seq_len = min(self.seq_len, self.data.shape[0] - 1 - idx)
        i = idx + seq_len
        data = self.data[idx: i]
        target = self.data[idx + 1: i + 1]
        return data, target


class SequentialUnBatchedDataset(Dataset):
    def __init__(self, *, text: str, dataset: TextDataset,
                 seq_len: int,
                 is_random_offset: bool = True):
        self.is_random_offset = is_random_offset
        self.seq_len = seq_len
        self.data = dataset.text_to_i(text)

    def __len__(self):
        return (self.data.shape[0] - 1) // self.seq_len

    def __getitem__(self, idx):
        start = idx * self.seq_len
        assert start + self.seq_len + 1 <= self.data.shape[0]
        if self.is_random_offset:
            # Shift the window by a random amount so epochs see different chunks
            start += random.randint(0, min(self.seq_len - 1,
                                           self.data.shape[0] - (start + self.seq_len + 1)))

        end = start + self.seq_len
        data = self.data[start: end]
        target = self.data[start + 1: end + 1]
        return data, target


class TextFileDataset(TextDataset):
    standard_tokens = []

    def __init__(self, path: PurePath, tokenizer: Callable, *,
                 url: Optional[str] = None,
                 filter_subset: Optional[int] = None):
        path = Path(path)
        if not path.exists():
            if not url:
                raise FileNotFoundError(str(path))
            else:
                download_file(url, path)

        with monit.section("Load data"):
            text = self.load(path)
            if filter_subset:
                text = text[:filter_subset]
            # Use 90% of the text for training and the rest for validation
            split = int(len(text) * .9)
            train = text[:split]
            valid = text[split:]

        super().__init__(path, tokenizer, train, valid, '')


def _test_tiny_shakespeare():
    from labml import lab
    _ = TextFileDataset(lab.get_data_path() / 'tiny_shakespeare.txt', lambda x: list(x),
                        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')


if __name__ == '__main__':
    _test_tiny_shakespeare()
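To close the loop, a hedged end-to-end sketch of the text classes above (`_demo_text_pipeline` is illustrative, not part of the module): a character-level `TextFileDataset` over the same Tiny Shakespeare file used by the test, wrapped in both sequential datasets. Note the shape difference: `SequentialUnBatchedDataset` items collate to batch-major `[batch_size, seq_len]`, while `SequentialDataLoader` yields time-major `[seq_len, batch_size]` chunks.

def _demo_text_pipeline():
    dataset = TextFileDataset(
        lab.get_data_path() / 'tiny_shakespeare.txt', lambda x: list(x),
        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
    # Character-level vocabulary built in TextDataset.__init__
    print(dataset.n_tokens, repr(dataset))

    # Batch-major: a standard DataLoader collates items to [batch_size, seq_len]
    train = SequentialUnBatchedDataset(text=dataset.train, dataset=dataset, seq_len=32)
    data, target = next(iter(DataLoader(train, batch_size=4, shuffle=True)))
    assert data.shape == target.shape == (4, 32)

    # Time-major: SequentialDataLoader yields [seq_len, batch_size] chunks directly
    seq = SequentialDataLoader(text=dataset.valid, dataset=dataset,
                               batch_size=4, seq_len=32)
    data, target = next(iter(seq))
    assert data.shape == (32, 4)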