docs/unet/experiment.html
This trains a U-Net model on the Carvana dataset. You can find the download instructions on Kaggle.
Save the training images inside the carvana/train folder and the masks in the carvana/train_masks folder.
For simplicity, we do not use a training/validation split.
import numpy as np
import torchvision.transforms.functional

import torch
import torch.utils.data
from labml import lab, tracker, experiment, monit
from labml.configs import BaseConfigs
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.unet import UNet
from labml_nn.unet.carvana import CarvanaDataset
from torch import nn
class Configs(BaseConfigs):
Device to train the model on. DeviceConfigs picks up an available CUDA device or defaults to CPU.
    device: torch.device = DeviceConfigs()
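DeviceConfigs comes from labml; roughly, it resolves to something like the following (a simplified sketch, not the actual labml logic):

```python
import torch

# Rough equivalent of what DeviceConfigs resolves to:
# prefer a CUDA device when one is available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
```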
U-Net model
    model: UNet
Number of channels in the image. 3 for RGB.
    image_channels: int = 3
Number of channels in the output mask. 1 for binary mask.
    mask_channels: int = 1
Batch size
    batch_size: int = 1
Learning rate
    learning_rate: float = 2.5e-4
Number of training epochs
    epochs: int = 4
Dataset
    dataset: CarvanaDataset
Dataloader
    data_loader: torch.utils.data.DataLoader
Loss function
    loss_func = nn.BCELoss()
Sigmoid function for binary classification
    sigmoid = nn.Sigmoid()
Adam optimizer
    optimizer: torch.optim.Adam
    def init(self):
Initialize the Carvana dataset
        self.dataset = CarvanaDataset(lab.get_data_path() / 'carvana' / 'train',
                                      lab.get_data_path() / 'carvana' / 'train_masks')
Initialize the model
        self.model = UNet(self.image_channels, self.mask_channels).to(self.device)
Create dataloader
        self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size,
                                                       shuffle=True, pin_memory=True)
Create optimizer
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
Image logging
        tracker.set_image("sample", True)
    @torch.no_grad()
    def sample(self, idx=-1):
Get a random sample
        x, _ = self.dataset[np.random.randint(len(self.dataset))]
Move data to device
        x = x.to(self.device)
Get predicted mask
        mask = self.sigmoid(self.model(x[None, :]))
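Here x[None, :] adds a leading batch dimension, since the model expects batched input:

```python
import torch

x = torch.zeros(3, 96, 96)   # a single image: channels, height, width
batched = x[None, :]         # same as x.unsqueeze(0)
assert batched.shape == (1, 3, 96, 96)
```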
Crop the image to the size of the mask
        x = torchvision.transforms.functional.center_crop(x, [mask.shape[2], mask.shape[3]])
Log samples
        tracker.save('sample', x * mask)
    def train(self):
Iterate through the dataset. monit.mix interleaves the training batches with 50 calls to sample per epoch.
        for _, (image, mask) in monit.mix(('Train', self.data_loader), (self.sample, list(range(50)))):
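monit.mix spreads the sources out so they finish at roughly the same time: labeled sources yield their items, while callable sources have the callable invoked on each item. A simplified pure-Python sketch of that behavior (an assumption about its semantics, not the labml implementation, which also renders progress bars):

```python
def mix(*sources):
    """Interleave (label, iterable) and (callable, iterable) sources.

    Labeled sources yield (index, item) pairs; for callable sources the
    callable is invoked on each item instead of yielding it.
    Simplified stand-in for labml's monit.mix.
    """
    seqs = [(head, list(items)) for head, items in sources]
    pos = [0] * len(seqs)
    while any(p < len(seq) for p, (_, seq) in zip(pos, seqs)):
        # Advance the source that is proportionally furthest behind
        i = min((i for i, (_, seq) in enumerate(seqs) if pos[i] < len(seq)),
                key=lambda i: (pos[i] + 1) / len(seqs[i][1]))
        head, seq = seqs[i]
        item = seq[pos[i]]
        pos[i] += 1
        if callable(head):
            head(item)               # e.g. self.sample(idx) runs between batches
        else:
            yield pos[i] - 1, item


# Four "batches" interleaved with two sampling calls
calls = []
out = list(mix(('Train', [10, 20, 30, 40]), (calls.append, [0, 1])))
```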
Increment global step
            tracker.add_global_step()
Move data to device
            image, mask = image.to(self.device), mask.to(self.device)
Make the gradients zero
            self.optimizer.zero_grad()
Get predicted mask logits
            logits = self.model(image)
Crop the target mask to the size of the logits. The logits will be smaller than the input if the convolutional layers in the U-Net use no padding.
            mask = torchvision.transforms.functional.center_crop(mask, [logits.shape[2], logits.shape[3]])
Calculate loss
            loss = self.loss_func(self.sigmoid(logits), mask)
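Note that applying nn.BCELoss to sigmoid outputs is mathematically equivalent to applying nn.BCEWithLogitsLoss to the raw logits; the latter is the numerically safer choice when training from logits:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(2, 1, 4, 4)
target = torch.rand(2, 1, 4, 4)

# Sigmoid followed by BCELoss ...
a = nn.BCELoss()(torch.sigmoid(logits), target)
# ... matches BCEWithLogitsLoss on the logits directly
b = nn.BCEWithLogitsLoss()(logits, target)
assert torch.allclose(a, b, atol=1e-6)
```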
Compute gradients
            loss.backward()
Take an optimization step
            self.optimizer.step()
Track the loss
            tracker.save('loss', loss)
    def run(self):
        for _ in monit.loop(self.epochs):
Train the model
            self.train()
New line in the console
            tracker.new_line()
Save the model
            experiment.save_checkpoint()
def main():
Create experiment
    experiment.create(name='unet')
Create configurations
    configs = Configs()
Set configurations. You can override the defaults by passing the values in the dictionary.
    experiment.configs(configs, {})
Initialize
    configs.init()
Set models for saving and loading
    experiment.add_pytorch_models({'model': configs.model})
Start and run the training loop
    with experiment.start():
        configs.run()
if __name__ == '__main__':
    main()