
[View code on GitHub](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/unet/__init__.py)

#

U-Net

This is a PyTorch implementation of the U-Net model from the paper U-Net: Convolutional Networks for Biomedical Image Segmentation.

U-Net consists of a contracting path and an expansive path. The contracting path is a series of convolutional layers and pooling layers, where the resolution of the feature map gets progressively reduced. The expansive path is a series of up-sampling layers and convolutional layers, where the resolution of the feature map gets progressively increased.

At every step in the expansive path, the corresponding feature map from the contracting path is concatenated with the current feature map.
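As a rough sketch of this bookkeeping, the pure-Python function below traces the channel count and spatial size at each stage. It assumes the channel schedule 64, 128, 256, 512, 1024 used by the `UNet` class further down, size-preserving (`padding=1`) convolutions, and an input whose height and width are divisible by 16; the function name `unet_stage_shapes` is hypothetical.

```python
def unet_stage_shapes(in_channels: int, size: int) -> list:
    """Trace (stage, channels, spatial size) through a 4-level U-Net.

    Assumes padding=1 convolutions (size-preserving), 2x2 max pooling,
    2x2 stride-2 up-convolutions and an input size divisible by 16.
    """
    widths = [64, 128, 256, 512]
    shapes = [("input", in_channels, size)]
    skips = []
    # Contracting path: double convolution, remember the output, then pool.
    for w in widths:
        shapes.append((f"down conv -> {w}", w, size))
        skips.append((w, size))
        size //= 2
    # Bottom of the U.
    shapes.append(("middle conv", 1024, size))
    channels = 1024
    # Expansive path: up-sample (halves channels), concatenate the skip map,
    # then a double convolution brings the channels down to w.
    for w, skip_size in reversed(skips):
        size *= 2
        assert size == skip_size, "no cropping needed when size % 16 == 0"
        shapes.append((f"concat {channels // 2}+{w}", channels // 2 + w, size))
        channels = w
        shapes.append((f"up conv -> {w}", w, size))
    return shapes


for stage in unet_stage_shapes(3, 256):
    print(stage)
```

For a 256×256 input the bottom of the U is a 1024-channel 16×16 map, and every concatenation in the expansive path doubles the channels that the following double convolution then halves again.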

Here is the training code for an experiment that trains a U-Net on the Carvana dataset.

    import torch
    import torchvision.transforms.functional
    from torch import nn

#

Two 3×3 Convolution Layers

Each step in the contracting path and the expansive path has two 3×3 convolutional layers, each followed by a ReLU activation.

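The U-Net paper uses 0 padding, so each 3×3 convolution removes one pixel from every border of the feature map. We use a padding of 1 instead, which keeps the height and width unchanged, so that the final feature map is not cropped.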
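The effect of padding follows from the standard output-size formula for a convolution, out = floor((in + 2·padding − kernel) / stride) + 1. A small sketch with hypothetical helper names applies it to a 572-pixel side, the input tile size shown in the paper's architecture figure:

```python
def conv_out_size(size: int, kernel: int = 3, padding: int = 0, stride: int = 1) -> int:
    """Spatial output size of a 2D convolution along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def double_conv_out_size(size: int, padding: int) -> int:
    """Size after two 3x3 convolutions, as in each U-Net step."""
    return conv_out_size(conv_out_size(size, padding=padding), padding=padding)


# Unpadded (paper): each 3x3 convolution removes one pixel from every border.
print(double_conv_out_size(572, padding=0))  # 568
# padding=1 (this implementation): the size is preserved.
print(double_conv_out_size(572, padding=1))  # 572
```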

    class DoubleConvolution(nn.Module):

#

  • in_channels is the number of input channels
  • out_channels is the number of output channels
        def __init__(self, in_channels: int, out_channels: int):

#

            super().__init__()

#

First 3×3 convolutional layer

            self.first = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
            self.act1 = nn.ReLU()

#

Second 3×3 convolutional layer

            self.second = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
            self.act2 = nn.ReLU()

#

        def forward(self, x: torch.Tensor):

#

Apply the two convolution layers and activations

            x = self.first(x)
            x = self.act1(x)
            x = self.second(x)
            return self.act2(x)

#

Down-sample

Each step in the contracting path down-samples the feature map with a 2×2 max pooling layer.
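A quick illustration of what `nn.MaxPool2d(2)` does to a small single-channel map (the input values are arbitrary, chosen so the result is easy to check by hand):

```python
import torch
from torch import nn

pool = nn.MaxPool2d(2)  # kernel_size=2; stride defaults to kernel_size

# One image, one channel, 4x4 values 0..15 laid out row by row
x = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)
y = pool(x)
# Each output value is the maximum of a non-overlapping 2x2 block
print(y.squeeze().tolist())  # [[5.0, 7.0], [13.0, 15.0]]

# Odd sizes are floored: the last row and column are dropped
print(pool(torch.zeros(1, 1, 5, 5)).shape)  # torch.Size([1, 1, 2, 2])
```

The flooring on odd sizes is why the expansive path can end up smaller than the matching contracting-path map, which the crop step later accounts for.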

    class DownSample(nn.Module):

#

        def __init__(self):
            super().__init__()

#

Max pooling layer

            self.pool = nn.MaxPool2d(2)

#

        def forward(self, x: torch.Tensor):
            return self.pool(x)

#

Up-sample

Each step in the expansive path up-samples the feature map with a 2×2 up-convolution.
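A shape check for the first up-sampling step, assuming a 256×256 input so that the bottom of the U is 16×16 (the input tensor is random and only its shape matters here):

```python
import torch
from torch import nn

# Same configuration as UpSample: kernel_size=2, stride=2
up = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=2, stride=2)

x = torch.randn(1, 1024, 16, 16)  # bottom of the U for a 256x256 input
y = up(x)
print(y.shape)  # torch.Size([1, 512, 32, 32])

# Per dimension: (in - 1) * stride - 2 * padding + dilation * (kernel_size - 1)
#                + output_padding + 1 = (16 - 1) * 2 - 0 + 1 + 0 + 1 = 32
```

With `kernel_size=2` and `stride=2` the kernel windows do not overlap, so each input pixel expands into its own 2×2 output block, exactly doubling the height and width while halving the channels here.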

    class UpSample(nn.Module):

#

        def __init__(self, in_channels: int, out_channels: int):
            super().__init__()

#

Up-convolution

            self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

#

        def forward(self, x: torch.Tensor):
            return self.up(x)

#

Crop and Concatenate the feature map

At every step in the expansive path, the corresponding feature map from the contracting path is concatenated with the current feature map.
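The class below crops with `torchvision.transforms.functional.center_crop`. As a dependency-free sketch of the same idea, the hypothetical helper `crop_and_concat` crops the contracting-path map by plain tensor slicing and then concatenates along the channel dimension (`dim=1`). It floors the crop offset, so for an odd size difference it may place the crop one pixel differently from torchvision, whose offset rounding may differ.

```python
import torch


def crop_and_concat(x: torch.Tensor, contracting_x: torch.Tensor) -> torch.Tensor:
    """Center-crop `contracting_x` to x's height/width, then concatenate on channels.

    Hypothetical helper; tensors are assumed to be (batch, channels, height, width).
    """
    h, w = x.shape[2], x.shape[3]
    top = (contracting_x.shape[2] - h) // 2
    left = (contracting_x.shape[3] - w) // 2
    cropped = contracting_x[:, :, top:top + h, left:left + w]
    return torch.cat([x, cropped], dim=1)


x = torch.zeros(1, 3, 4, 4)                      # up-sampled map from the expansive path
skip = torch.arange(36.0).reshape(1, 1, 6, 6)    # larger map from the contracting path
out = crop_and_concat(x, skip)
print(out.shape)              # torch.Size([1, 4, 4, 4])
print(out[0, 3, 0].tolist())  # [7.0, 8.0, 9.0, 10.0], first row of the centre 4x4
```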

    class CropAndConcat(nn.Module):

#

  • x is the current feature map in the expansive path
  • contracting_x is the corresponding feature map from the contracting path
        def forward(self, x: torch.Tensor, contracting_x: torch.Tensor):

#

Crop the feature map from the contracting path to the size of the current feature map

            contracting_x = torchvision.transforms.functional.center_crop(contracting_x, [x.shape[2], x.shape[3]])

#

Concatenate the feature maps

            x = torch.cat([x, contracting_x], dim=1)

#

            return x
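With `padding=1` convolutions, the crop only changes anything when max pooling floors an odd size, that is, when the input height or width is not divisible by 16. The pure-Python sketch below (hypothetical helper `trace_spatial_size`) follows one spatial dimension through the four levels to show where cropping happens and how large the output ends up.

```python
def trace_spatial_size(size: int, levels: int = 4) -> tuple:
    """Trace one spatial dimension through a U-Net with padding=1 convolutions.

    Returns (output_size, crops), where crops lists (skip_size, upsampled_size)
    for every level at which the contracting-path map must be cropped.
    """
    skips = []
    for _ in range(levels):
        skips.append(size)  # padding=1: the double convolution keeps the size
        size //= 2          # 2x2 max pooling floors odd sizes
    crops = []
    for skip in reversed(skips):
        size *= 2           # 2x2 stride-2 up-convolution doubles the size
        if skip != size:
            crops.append((skip, size))
    return size, crops


print(trace_spatial_size(512))  # (512, [])
print(trace_spatial_size(572))  # (560, [(71, 70), (143, 140), (286, 280), (572, 560)])
```

So under these assumptions a 572×572 input yields a 560×560 output, while inputs whose sides are multiples of 16 come out at full size with no cropping at all.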

#

U-Net

    class UNet(nn.Module):

#

  • in_channels is the number of channels in the input image
  • out_channels is the number of channels in the resulting feature map
        def __init__(self, in_channels: int, out_channels: int):

#

            super().__init__()

#

Double convolution layers for the contracting path. The number of features gets doubled at each step starting from 64.

            self.down_conv = nn.ModuleList([DoubleConvolution(i, o) for i, o in
                                            [(in_channels, 64), (64, 128), (128, 256), (256, 512)]])

#

Down-sampling layers for the contracting path

            self.down_sample = nn.ModuleList([DownSample() for _ in range(4)])

#

The two convolution layers at the lowest resolution (the bottom of the U).

            self.middle_conv = DoubleConvolution(512, 1024)

#

Up-sampling layers for the expansive path. The number of features is halved with up-sampling.

            self.up_sample = nn.ModuleList([UpSample(i, o) for i, o in
                                            [(1024, 512), (512, 256), (256, 128), (128, 64)]])

#

Double convolution layers for the expansive path. Their input is the concatenation of the current feature map and the feature map from the contracting path. Therefore, the number of input features is double the number of features from up-sampling.

            self.up_conv = nn.ModuleList([DoubleConvolution(i, o) for i, o in
                                          [(1024, 512), (512, 256), (256, 128), (128, 64)]])

#

Crop and concatenate layers for the expansive path.

            self.concat = nn.ModuleList([CropAndConcat() for _ in range(4)])

#

Final 1×1 convolution layer to produce the output

            self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)
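A 1×1 convolution applies the same linear map across channels at every pixel independently. The sketch below checks this numerically against an explicit `torch.einsum` over the channel dimension; the choice of 2 output channels (for example, two segmentation classes) is an arbitrary assumption for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)
conv = nn.Conv2d(64, 2, kernel_size=1)  # 2 output channels: illustrative choice
x = torch.randn(1, 64, 8, 8)

y_conv = conv(x)

# Same computation as a matrix multiply over channels at every (h, w) position
weight = conv.weight.view(2, 64)  # (out_channels, in_channels), dropping the 1x1 kernel dims
y_manual = torch.einsum("oc,bchw->bohw", weight, x) + conv.bias.view(1, 2, 1, 1)

print(y_conv.shape)                                 # torch.Size([1, 2, 8, 8])
print(torch.allclose(y_conv, y_manual, atol=1e-5))  # True
```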

#

  • x is the input image
        def forward(self, x: torch.Tensor):

#

A list to collect the outputs of the contracting path for later concatenation with the expansive path.

            pass_through = []

#

Contracting path

            for i in range(len(self.down_conv)):

#

Two 3×3 convolutional layers

                x = self.down_conv[i](x)

#

Collect the output

                pass_through.append(x)

#

Down-sample

                x = self.down_sample[i](x)

#

Two 3×3 convolutional layers at the bottom of the U-Net

            x = self.middle_conv(x)

#

Expansive path

            for i in range(len(self.up_conv)):

#

Up-sample

                x = self.up_sample[i](x)

#

Concatenate the output of the contracting path

                x = self.concat[i](x, pass_through.pop())
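`pass_through` works as a stack: the contracting path pushes its outputs from highest to lowest resolution, and `pop()` returns the most recently pushed (lowest-resolution) map first, which is the one that matches the first up-sampling step. A pure-Python trace with `(channels, height, width)` tuples standing in for tensors, assuming a 256×256 input:

```python
# Stand-ins for the contracting-path outputs: (channels, height, width)
pass_through = []
size = 256
for channels in [64, 128, 256, 512]:
    pass_through.append((channels, size, size))
    size //= 2

# Expansive path: each pop() returns the skip map for the current level
popped = [pass_through.pop() for _ in range(4)]
print(popped)
# [(512, 32, 32), (256, 64, 64), (128, 128, 128), (64, 256, 256)]
print(pass_through)  # []: every skip connection is used exactly once
```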

#

Two 3×3 convolutional layers

                x = self.up_conv[i](x)

#

Final 1×1 convolution layer

            x = self.final_conv(x)

#

            return x

labml.ai