Back to Annotated Deep Learning Paper Implementations

Training U-Net

docs/unet/experiment.html

latest5.5 KB
Original Source

home › unet

View code on Github

#

Training U-Net

This trains a U-Net model on the Carvana dataset. Download instructions are available on Kaggle.

Save the training images in the `carvana/train` folder and the masks in the `carvana/train_masks` folder, both inside the labml data path (`lab.get_data_path()`).

For simplicity, we do not split the data into training and validation sets; all images are used for training.

import numpy as np
import torch
import torch.utils.data
import torchvision.transforms.functional
from labml import experiment, lab, monit, tracker
from labml.configs import BaseConfigs
from torch import nn

from labml_nn.helpers.device import DeviceConfigs
from labml_nn.unet import UNet
from labml_nn.unet.carvana import CarvanaDataset

#

Configurations

class Configs(BaseConfigs):
    """
    ## Configurations

    Training configuration for the U-Net on the Carvana dataset.
    The annotated attributes are configuration options; `init` builds the
    dataset, model, data loader and optimizer from them, and `run` trains for
    `epochs` epochs, saving a checkpoint after each one.
    """

    # Device to train the model on.
    # `DeviceConfigs` picks up an available CUDA device or defaults to CPU.
    device: torch.device = DeviceConfigs()

    # U-Net model
    model: UNet

    # Number of channels in the image. 3 for RGB.
    image_channels: int = 3
    # Number of channels in the output mask. 1 for binary mask.
    mask_channels: int = 1

    # Batch size
    batch_size: int = 1
    # Learning rate
    learning_rate: float = 2.5e-4

    # Number of training epochs
    epochs: int = 4

    # Dataset
    dataset: CarvanaDataset
    # Dataloader
    data_loader: torch.utils.data.DataLoader

    # Loss function (expects probabilities, so `train` applies `sigmoid` first).
    # NOTE(review): `nn.BCEWithLogitsLoss` on raw logits is the numerically
    # more stable equivalent of sigmoid + `nn.BCELoss`; not switched here
    # because a caller overriding `loss_func` may rely on receiving probabilities.
    loss_func = nn.BCELoss()
    # Sigmoid function for binary classification
    sigmoid = nn.Sigmoid()

    # Adam optimizer
    optimizer: torch.optim.Adam

    def init(self):
        """
        Build the dataset, model, data loader and optimizer from the
        configuration values, and enable image logging for `"sample"`.
        """
        # Initialize the Carvana dataset
        self.dataset = CarvanaDataset(lab.get_data_path() / 'carvana' / 'train',
                                      lab.get_data_path() / 'carvana' / 'train_masks')
        # Initialize the model
        self.model = UNet(self.image_channels, self.mask_channels).to(self.device)

        # Create dataloader
        self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size,
                                                       shuffle=True, pin_memory=True)
        # Create optimizer
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)

        # Image logging
        tracker.set_image("sample", True)

    @torch.no_grad()
    def sample(self, idx=-1):
        """
        ### Sample images

        Predict a mask for one randomly chosen training image and log the
        (center-cropped) image multiplied by the predicted mask as `"sample"`.

        `idx` is not read by the body; presumably it receives the per-call
        item from `monit.mix` in `train` -- TODO confirm.
        """
        # Get a random sample
        x, _ = self.dataset[np.random.randint(len(self.dataset))]
        # Move data to device
        x = x.to(self.device)

        # Get predicted mask; `x[None, :]` adds a batch dimension
        mask = self.sigmoid(self.model(x[None, :]))
        # Crop the image to the size of the mask
        x = torchvision.transforms.functional.center_crop(x, [mask.shape[2], mask.shape[3]])
        # Log samples
        tracker.save('sample', x * mask)

    def train(self):
        """
        ### Train for an epoch
        """
        # Iterate through the dataset.
        # Use `monit.mix` to sample 50 times per epoch.
        for _, (image, mask) in monit.mix(('Train', self.data_loader),
                                          (self.sample, list(range(50)))):
            # Increment global step
            tracker.add_global_step()
            # Move data to device
            image, mask = image.to(self.device), mask.to(self.device)

            # Make the gradients zero
            self.optimizer.zero_grad()
            # Get predicted mask logits
            logits = self.model(image)
            # Crop the target mask to the size of the logits. Size of the logits
            # will be smaller if we don't use padding in convolutional layers in the U-Net.
            mask = torchvision.transforms.functional.center_crop(mask, [logits.shape[2], logits.shape[3]])
            # Calculate loss
            loss = self.loss_func(self.sigmoid(logits), mask)
            # Compute gradients
            loss.backward()
            # Take an optimization step
            self.optimizer.step()
            # Track the loss
            tracker.save('loss', loss)

    def run(self):
        """
        ### Training loop

        Train for `epochs` epochs, saving a checkpoint of the registered
        models after every epoch.
        """
        for _ in monit.loop(self.epochs):
            # Train the model
            self.train()
            # New line in the console
            tracker.new_line()
            # Save the model. Without this call the model registered via
            # `experiment.add_pytorch_models` is never written to disk.
            experiment.save_checkpoint()

#

def main():
    """
    Entry point: create the `unet` experiment, apply configuration
    overrides, initialise the configs, register the model for checkpointing,
    and run the training loop inside the started experiment.
    """
    experiment.create(name='unet')

    conf = Configs()
    # Put values in this dictionary to override the configuration defaults;
    # it is left empty so every default is used.
    overrides = {}
    experiment.configs(conf, overrides)

    # Build the dataset, model, data loader and optimizer
    conf.init()

    # Register the model with the experiment so it can be saved and loaded
    models = {'model': conf.model}
    experiment.add_pytorch_models(models)

    # Start the experiment and run the training loop inside it
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()

labml.ai