You can find the download instructions on Kaggle.
Save the training images in the carvana/train folder and the masks in the carvana/train_masks folder.
from pathlib import Path

import torchvision.transforms.functional
from PIL import Image

import torch.utils.data
from labml import lab
class CarvanaDataset(torch.utils.data.Dataset):
image_path is the path to the images
mask_path is the path to the masks
    def __init__(self, image_path: Path, mask_path: Path):
Get a dictionary of images by id
        self.images = {p.stem: p for p in image_path.iterdir()}
Get a dictionary of masks by id
        self.masks = {p.stem[:-5]: p for p in mask_path.iterdir()}
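The `[:-5]` slice above drops a five-character suffix from each mask file's stem so that mask ids line up with image ids. A minimal sketch of that matching, assuming mask files carry a `_mask` suffix (the file names below are made up for illustration):

```python
from pathlib import PurePosixPath

# Hypothetical file names, assumed to follow an `<id>.jpg` / `<id>_mask.gif` pattern
image_file = PurePosixPath('carvana/train/0cdf5b5d0ce1_01.jpg')
mask_file = PurePosixPath('carvana/train_masks/0cdf5b5d0ce1_01_mask.gif')

# `stem` is the file name without its final extension
image_id = image_file.stem
# Dropping the last 5 characters removes the assumed `_mask` suffix
mask_id = mask_file.stem[:-5]
```

With this naming, both ids are the same string, so the mask can be looked up with the image id.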
Image ids list
        self.ids = list(self.images.keys())
Transformations
        self.transforms = torchvision.transforms.Compose([
            torchvision.transforms.Resize(572),
            torchvision.transforms.ToTensor(),
        ])
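Note that `Resize` given a single integer scales the shorter edge to that length and keeps the aspect ratio, so the output is generally not square. A rough sketch of the size arithmetic, using a hypothetical helper `resized_shape` (not part of torchvision) and assuming an input that is 1280 pixels high and 1918 pixels wide:

```python
def resized_shape(height: int, width: int, size: int) -> tuple:
    """Hypothetical helper: (height, width) after scaling the shorter edge to `size`."""
    short, long = min(height, width), max(height, width)
    # The longer edge is scaled by the same factor; truncating to an integer is an assumption
    new_long = int(size * long / short)
    return (size, new_long) if height <= width else (new_long, size)


# Assumed input size of 1280 x 1918 (height x width)
shape = resized_shape(1280, 1918, 572)
```

Because the image and its mask go through the same transform, they end up with matching spatial sizes.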
idx is the index of the image
    def __getitem__(self, idx: int):
Get image id
        id_ = self.ids[idx]
Load image
        image = Image.open(self.images[id_])
Transform image and convert it to a PyTorch tensor
        image = self.transforms(image)
Load mask
        mask = Image.open(self.masks[id_])
Transform mask and convert it to a PyTorch tensor
        mask = self.transforms(mask)
The mask values do not peak at 1 after the transform, so we divide by the maximum to scale them to the range [0, 1].
        mask = mask / mask.max()
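To illustrate the scaling step, here is a small sketch with a made-up mask tensor whose foreground value is 1/255 (an assumed value, chosen only for illustration). Dividing by the maximum maps the foreground to 1 and leaves the background at 0:

```python
import torch

# Made-up 2x2 mask: background 0, foreground 1/255 (assumed values)
mask = torch.tensor([[0.0, 1.0 / 255], [1.0 / 255, 0.0]])

# Divide by the largest value so the foreground becomes 1
scaled = mask / mask.max()
```

If a mask contained no foreground at all, `mask.max()` would be 0 and the division would produce NaN values; whether such masks occur in this dataset is not stated here.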
Return the image and the mask
        return image, mask

    def __len__(self):
        return len(self.ids)
Testing code
if __name__ == '__main__':
    ds = CarvanaDataset(lab.get_data_path() / 'carvana' / 'train', lab.get_data_path() / 'carvana' / 'train_masks')