[View code on GitHub](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/resnet/__init__.py)
This is a PyTorch implementation of the paper Deep Residual Learning for Image Recognition.
ResNets train layers as residual functions to overcome the degradation problem: the accuracy of deep neural networks degrades when the number of layers becomes very high. Accuracy increases as the number of layers increases, then saturates, and then starts to degrade.
The paper argues that deeper models should perform at least as well as shallower models because the extra layers can just learn to perform an identity mapping.
If $H(x)$ is the mapping that needs to be learned by a few layers, they train the residual function
$$F(x) = H(x) - x$$
instead. The original function then becomes $F(x) + x$.
In this case, learning the identity mapping for $H(x)$ is equivalent to learning $F(x)$ to be $0$, which is easier to learn.
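As a rough illustration of why a zero residual gives the identity, the sketch below wraps an arbitrary function in a skip connection using plain Python lists. The names `with_skip` and `zero_residual` are made up for this example and are not part of the implementation.

```python
from typing import Callable, List

Vector = List[float]


def with_skip(residual_fn: Callable[[Vector], Vector]) -> Callable[[Vector], Vector]:
    """Return the mapping x -> F(x) + x for a given residual function F."""

    def mapping(x: Vector) -> Vector:
        fx = residual_fn(x)
        return [f + xi for f, xi in zip(fx, x)]

    return mapping


def zero_residual(x: Vector) -> Vector:
    """Hypothetical residual function whose output is all zeros."""
    return [0.0 for _ in x]


x = [1.5, -2.0, 0.25]
h = with_skip(zero_residual)
print(h(x) == x)  # F(x) = 0 everywhere, so F(x) + x is x itself
```

With the skip connection in place, the layers only have to push their output towards zero to reproduce the input, rather than reconstruct it.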
In the parameterized form this can be written as
$$F(x, \{W_i\}) + x$$
and when the feature map sizes of $F(x, \{W_i\})$ and $x$ are different, the paper suggests doing a linear projection with learned weights $W_s$:
$$F(x, \{W_i\}) + W_s x$$
The paper experimented with zero padding instead of linear projections and found linear projections to work better. When the feature map sizes match, they found identity mapping to be better than linear projections.
$F$ should have more than one layer; otherwise the sum $F(x, \{W_i\}) + W_s x$ has no non-linearity and behaves like a single linear layer.
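To spell out the single-layer case, assume the residual function is one linear layer with weights $W_1$ and no activation inside it. The sum then collapses into a single linear map:

```latex
F(x, \{W_1\}) + W_s x = W_1 x + W_s x = (W_1 + W_s)\,x
```

With two or more layers there is a non-linearity between the weight matrices, so the residual branch can no longer be folded into the shortcut like this.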
Here is the training code for training a ResNet on CIFAR-10.
from typing import List, Optional

import torch
from torch import nn
This does the $W_s x$ projection described above.
class ShortcutProjection(nn.Module):
- in_channels is the number of channels in $x$
- out_channels is the number of channels in $F(x, \{W_i\})$
- stride is the stride length in the convolution operation for $F$. We do the same stride on the shortcut connection, to match the feature-map size.
    def __init__(self, in_channels: int, out_channels: int, stride: int):
        super().__init__()
Convolution layer for linear projection $W_s x$
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
The paper suggests adding batch normalization after each convolution operation
        self.bn = nn.BatchNorm2d(out_channels)
    def forward(self, x: torch.Tensor):
Convolution and batch normalization
        return self.bn(self.conv(x))
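The reason the shortcut reuses the stride of the residual path can be checked with the standard convolution output-size formula (dilation 1). The sketch below is plain Python; `conv_out_size` is a helper introduced only for this example. It assumes the residual path uses a 3×3 convolution with padding 1, and the shortcut uses a 1×1 convolution with no padding.

```python
def conv_out_size(size: int, kernel_size: int, stride: int, padding: int = 0) -> int:
    """Spatial output size of a convolution (dilation 1), using floor division."""
    return (size + 2 * padding - kernel_size) // stride + 1


for size in (32, 7):
    # 3x3 convolution with padding 1, as assumed for the residual path F
    residual_path = conv_out_size(size, kernel_size=3, stride=2, padding=1)
    # 1x1 convolution without padding, as used by the shortcut projection
    shortcut_path = conv_out_size(size, kernel_size=1, stride=2)
    print(size, residual_path, shortcut_path)
```

For both the even and the odd input size the two paths produce the same spatial size, so the outputs can be added element-wise.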
This implements the residual block described in the paper. It has two 3×3 convolution layers.
The first convolution layer maps from in_channels to out_channels, where out_channels is higher than in_channels when we reduce the feature map size with a stride length greater than 1.
The second convolution layer maps from out_channels to out_channels and always has a stride length of 1.
Both convolution layers are followed by batch normalization.
class ResidualBlock(nn.Module):
- in_channels is the number of channels in $x$
- out_channels is the number of output channels
- stride is the stride length in the convolution operation.
    def __init__(self, in_channels: int, out_channels: int, stride: int):
        super().__init__()
First 3×3 convolution layer, this maps to out_channels
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
Batch normalization after the first convolution
        self.bn1 = nn.BatchNorm2d(out_channels)
First activation function (ReLU)
        self.act1 = nn.ReLU()
Second 3×3 convolution layer
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
Batch normalization after the second convolution
        self.bn2 = nn.BatchNorm2d(out_channels)
Shortcut connection should be a projection if the stride length is not 1 or if the number of channels changes
        if stride != 1 or in_channels != out_channels:
Projection $W_s x$
            self.shortcut = ShortcutProjection(in_channels, out_channels, stride)
        else:
Identity $x$
            self.shortcut = nn.Identity()
Second activation function (ReLU) (after adding the shortcut)
        self.act2 = nn.ReLU()
x is the input of shape [batch_size, in_channels, height, width]
    def forward(self, x: torch.Tensor):
Get the shortcut connection
        shortcut = self.shortcut(x)
First convolution and activation
        x = self.act1(self.bn1(self.conv1(x)))
Second convolution
        x = self.bn2(self.conv2(x))
Activation function after adding the shortcut
        return self.act2(x + shortcut)
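The order in the last line matters: the activation is applied to the sum, not to the residual branch alone. A toy scalar sketch in plain Python shows that the two orderings can differ. The values are made up and stand in for the outputs of the second batch normalization and the shortcut.

```python
def relu(v: float) -> float:
    """Rectified linear unit for a single scalar."""
    return max(0.0, v)


residual = -3.0  # hypothetical output of the second batch normalization
shortcut = 2.0   # hypothetical output of the shortcut connection

after_add = relu(residual + shortcut)   # ordering used by the block
before_add = relu(residual) + shortcut  # alternative ordering, for comparison

print(after_add, before_add)
```

Here the block's ordering clips the negative sum to zero, while the alternative would pass the shortcut value through unchanged.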
This implements the bottleneck block described in the paper. It has 1×1, 3×3, and 1×1 convolution layers.
The first convolution layer maps from in_channels to bottleneck_channels with a 1×1 convolution, where bottleneck_channels is lower than in_channels.
The second 3×3 convolution layer maps from bottleneck_channels to bottleneck_channels. This can have a stride length greater than 1 when we want to compress the feature map size.
The third, final 1×1 convolution layer maps to out_channels. out_channels is higher than in_channels if the stride length is greater than 1; otherwise, out_channels is equal to in_channels.
bottleneck_channels is less than in_channels and the 3×3 convolution is performed on this shrunk space (hence the bottleneck). The two 1×1 convolutions decrease and then increase the number of channels.
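A quick weight count gives a feel for what the bottleneck saves. The sketch below is plain Python; it counts convolution weights only, ignoring biases and batch normalization parameters. The channel sizes 256 and 64 are chosen for illustration, and `conv_weights` is a helper introduced for this example.

```python
def conv_weights(in_channels: int, out_channels: int, kernel_size: int) -> int:
    """Number of weights in a 2D convolution, ignoring the bias term."""
    return in_channels * out_channels * kernel_size * kernel_size


# Hypothetical channel sizes chosen for illustration
channels, bottleneck_channels = 256, 64

# Two 3x3 convolutions at full width (the plain residual block)
plain = 2 * conv_weights(channels, channels, 3)

# 1x1 reduce, 3x3 at the bottleneck width, 1x1 expand
bottleneck = (conv_weights(channels, bottleneck_channels, 1)
              + conv_weights(bottleneck_channels, bottleneck_channels, 3)
              + conv_weights(bottleneck_channels, channels, 1))

print(plain, bottleneck)
```

Under these assumptions the plain block has 1,179,648 convolution weights and the bottleneck block has 69,632, because the expensive 3×3 convolution runs on the narrower channel count.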
class BottleneckResidualBlock(nn.Module):
- in_channels is the number of channels in $x$
- bottleneck_channels is the number of channels for the 3×3 convolution
- out_channels is the number of output channels
- stride is the stride length in the 3×3 convolution operation.
    def __init__(self, in_channels: int, bottleneck_channels: int, out_channels: int, stride: int):
        super().__init__()
First 1×1 convolution layer, this maps to bottleneck_channels
        self.conv1 = nn.Conv2d(in_channels, bottleneck_channels, kernel_size=1, stride=1)
Batch normalization after the first convolution
        self.bn1 = nn.BatchNorm2d(bottleneck_channels)
First activation function (ReLU)
        self.act1 = nn.ReLU()
Second 3×3 convolution layer
        self.conv2 = nn.Conv2d(bottleneck_channels, bottleneck_channels, kernel_size=3, stride=stride, padding=1)
Batch normalization after the second convolution
        self.bn2 = nn.BatchNorm2d(bottleneck_channels)
Second activation function (ReLU)
        self.act2 = nn.ReLU()
Third 1×1 convolution layer, this maps to out_channels.
        self.conv3 = nn.Conv2d(bottleneck_channels, out_channels, kernel_size=1, stride=1)
Batch normalization after the third convolution
        self.bn3 = nn.BatchNorm2d(out_channels)
Shortcut connection should be a projection if the stride length is not 1 or if the number of channels changes
        if stride != 1 or in_channels != out_channels:
Projection $W_s x$
            self.shortcut = ShortcutProjection(in_channels, out_channels, stride)
        else:
Identity $x$
            self.shortcut = nn.Identity()
Third activation function (ReLU) (after adding the shortcut)
        self.act3 = nn.ReLU()
x is the input of shape [batch_size, in_channels, height, width]
    def forward(self, x: torch.Tensor):
Get the shortcut connection
        shortcut = self.shortcut(x)
First convolution and activation
        x = self.act1(self.bn1(self.conv1(x)))
Second convolution and activation
        x = self.act2(self.bn2(self.conv2(x)))
Third convolution
        x = self.bn3(self.conv3(x))
Activation function after adding the shortcut
        return self.act3(x + shortcut)
This is the base of the ResNet model without the final linear layer and softmax for classification.
The ResNet is made of stacked residual blocks or bottleneck residual blocks. The feature map size is halved after a few blocks with a block of stride length 2. The number of channels is increased when the feature map size is reduced. Finally, the feature map is average pooled to get a vector representation.
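The stacking rule described above can be sketched as a small planning function that lists the channel counts and stride of each block. `plan_blocks` is a hypothetical helper written for this illustration, not the model's construction code. It assumes the block that halves the feature map is the first block of each new feature map size, and that the very first group keeps stride 1.

```python
from typing import List, Tuple


def plan_blocks(n_blocks: List[int], n_channels: List[int]) -> List[Tuple[int, int, int]]:
    """List (in_channels, out_channels, stride) per block, following the description above."""
    assert len(n_blocks) == len(n_channels)
    plan = []
    prev_channels = n_channels[0]
    for i, channels in enumerate(n_channels):
        # Halve the feature map at the start of every group except the first
        stride = 1 if i == 0 else 2
        plan.append((prev_channels, channels, stride))
        # Remaining blocks of the group keep both channels and feature map size
        plan.extend((channels, channels, 1) for _ in range(n_blocks[i] - 1))
        prev_channels = channels
    return plan


for block in plan_blocks(n_blocks=[2, 2], n_channels=[16, 32]):
    print(block)
```

Only the block that opens the second group changes the channel count and uses stride 2. Every other block maps a feature map to one of the same shape.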
class ResNetBase(nn.Module):
- n_blocks is a list of the number of blocks for each feature map size.
- n_channels is the number of channels for each feature map size.
- bottlenecks is the number of channels in the bottlenecks. If this is None, residual blocks are used.
- img_channels is the number of channels in the input.
- first_kernel_size is the kernel size of the initial convolution layer.
    def __init__(self, n_blocks: List[int], n_channels: List[int],
                 bottlenecks: Optional[List[int]] = None,
                 img_channels: int = 3, first_kernel_size: int = 7):
        super().__init__()
Number of blocks and number of channels for each feature map size
        assert len(n_blocks) == len(n_channels)
If bottleneck residual blocks are used, the number of channels in bottlenecks should be provided for each feature map size
        assert bottlenecks is None or len(bottlenecks) == len(n_channels)
Initial convolution layer maps from img_channels to the number of channels in the first residual block (n_channels[0])
        self.conv = nn.Conv2d(img_channels, n_channels[0],
                              kernel_size=first_kernel_size, stride=2, padding=first_kernel_size // 2)
Batch norm after initial convolution
        self.bn = nn.BatchNorm2d(n_channels[0])
List of blocks
        blocks = []
Number of channels from previous layer (or block)
        prev_channels = n_channels[0]
Loop through each feature map size
        for i, channels in enumerate(n_channels):
The first block for the new feature map size will have a stride length of 2, except for the very first block
            stride = 1 if len(blocks) == 0 else 2

            if bottlenecks is None:
Residual blocks that map from prev_channels to channels
                blocks.append(ResidualBlock(prev_channels, channels, stride=stride))
            else:
Bottleneck residual blocks that map from prev_channels to channels
                blocks.append(BottleneckResidualBlock(prev_channels, bottlenecks[i], channels,
                                                      stride=stride))
Change the number of channels
            prev_channels = channels
Add rest of the blocks - no change in feature map size or channels
            for _ in range(n_blocks[i] - 1):
                if bottlenecks is None:
                    blocks.append(ResidualBlock(channels, channels, stride=1))
                else:
                    blocks.append(BottleneckResidualBlock(channels, bottlenecks[i], channels, stride=1))
Stack the blocks
        self.blocks = nn.Sequential(*blocks)
x has shape [batch_size, img_channels, height, width]
    def forward(self, x: torch.Tensor):
Initial convolution and batch normalization
        x = self.bn(self.conv(x))
Residual (or bottleneck) blocks
        x = self.blocks(x)
Change x from shape [batch_size, channels, h, w] to [batch_size, channels, h * w]
        x = x.view(x.shape[0], x.shape[1], -1)
Global average pooling
        return x.mean(dim=-1)
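The last two steps amount to averaging every feature map down to a single number per channel. As a library-free sketch of the same computation, the function below uses nested Python lists in place of tensors. `global_average_pool` is a name introduced only for this example.

```python
from typing import List


def global_average_pool(x: List[List[List[List[float]]]]) -> List[List[float]]:
    """Average each [h, w] feature map down to one number per channel."""
    pooled = []
    for sample in x:                      # loop over the batch
        row = []
        for feature_map in sample:        # loop over channels
            flat = [v for line in feature_map for v in line]  # [h, w] -> [h * w]
            row.append(sum(flat) / len(flat))
        pooled.append(row)
    return pooled


# One sample, two channels, 2x2 feature maps
x = [[[[1.0, 2.0], [3.0, 4.0]],
      [[0.0, 0.0], [0.0, 8.0]]]]
print(global_average_pool(x))
```

The result has shape [batch_size, channels], which is the vector representation a classification layer would take as input.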