
# Batch-Channel Normalization


[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/normalization/batch_channel_norm/__init__.py)


This is a PyTorch implementation of Batch-Channel Normalization from the paper Micro-Batch Training with Batch-Channel Normalization and Weight Standardization. We also have an annotated implementation of Weight Standardization.

Batch-Channel Normalization performs batch normalization followed by a channel normalization (similar to Group Normalization). When the batch size is small, the running mean and variance are used for batch normalization.

Here is the training code for a VGG network that uses weight standardization to classify CIFAR-10 data.

```python
import torch
from torch import nn

from labml_nn.normalization.batch_norm import BatchNorm
```

## Batch-Channel Normalization

This first performs a batch normalization, either a normal batch norm or a batch norm with estimated mean and variance (exponential mean/variance over multiple batches). Then a channel normalization is performed.

```python
class BatchChannelNorm(nn.Module):
```

  • `channels` is the number of features in the input
  • `groups` is the number of groups the features are divided into
  • `eps` is $\epsilon$, used in $\sqrt{\mathrm{Var}[x^{(k)}] + \epsilon}$ for numerical stability
  • `momentum` is the momentum in taking the exponential moving average
  • `estimate` is whether to use running mean and variance for batch norm

```python
    def __init__(self, channels: int, groups: int,
                 eps: float = 1e-5, momentum: float = 0.1, estimate: bool = True):
```

```python
        super().__init__()
```

Use estimated batch norm or normal batch norm.

```python
        if estimate:
            self.batch_norm = EstimatedBatchNorm(channels,
                                                 eps=eps, momentum=momentum)
        else:
            self.batch_norm = BatchNorm(channels,
                                        eps=eps, momentum=momentum)
```

Channel normalization

```python
        self.channel_norm = ChannelNorm(channels, groups, eps)
```

```python
    def forward(self, x):
        x = self.batch_norm(x)
        return self.channel_norm(x)
```

## Estimated Batch Normalization

When the input $X \in \mathbb{R}^{B \times C \times H \times W}$ is a batch of image representations, where $B$ is the batch size, $C$ is the number of channels, $H$ is the height and $W$ is the width, and $\gamma \in \mathbb{R}^C$, $\beta \in \mathbb{R}^C$,

$$\dot{X}_{\cdot, C, \cdot, \cdot} = \gamma_C \frac{X_{\cdot, C, \cdot, \cdot} - \hat{\mu}_C}{\hat{\sigma}_C} + \beta_C$$

where,

$$\begin{aligned}
\hat{\mu}_C &\longleftarrow (1 - r)\hat{\mu}_C + r \frac{1}{BHW} \sum_{b,h,w} X_{b,c,h,w} \\
\hat{\sigma}^2_C &\longleftarrow (1 - r)\hat{\sigma}^2_C + r \frac{1}{BHW} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C\big)^2
\end{aligned}$$

are the running mean and variance, and $r$ is the momentum for calculating the exponential moving average.
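
For intuition, here is a minimal numeric sketch of one EMA update step; the values and the two-channel setup are made up for illustration:

```python
import torch

r = 0.1  # momentum
exp_mean = torch.tensor([0.0, 0.0])     # running mean for two channels
batch_mean = torch.tensor([1.0, -2.0])  # mean of the current batch

# Keep 90% of the old estimate and mix in 10% of the new batch statistic
exp_mean = (1 - r) * exp_mean + r * batch_mean
print(exp_mean)  # tensor([ 0.1000, -0.2000])
```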

```python
class EstimatedBatchNorm(nn.Module):
```

  • `channels` is the number of features in the input
  • `eps` is $\epsilon$, used in $\sqrt{\mathrm{Var}[x^{(k)}] + \epsilon}$ for numerical stability
  • `momentum` is the momentum in taking the exponential moving average
  • `affine` is whether to scale and shift the normalized value

```python
    def __init__(self, channels: int,
                 eps: float = 1e-5, momentum: float = 0.1, affine: bool = True):
```

```python
        super().__init__()

        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.channels = channels
```

Channel-wise transformation parameters

```python
        if self.affine:
            self.scale = nn.Parameter(torch.ones(channels))
            self.shift = nn.Parameter(torch.zeros(channels))
```

Tensors for $\hat{\mu}_C$ and $\hat{\sigma}^2_C$

```python
        self.register_buffer('exp_mean', torch.zeros(channels))
        self.register_buffer('exp_var', torch.ones(channels))
```

`x` is a tensor of shape `[batch_size, channels, *]`. `*` denotes any number of (possibly 0) dimensions. For example, in an image (2D) convolution this will be `[batch_size, channels, height, width]`.

```python
    def forward(self, x: torch.Tensor):
```

Keep the old shape

```python
        x_shape = x.shape
```

Get the batch size

```python
        batch_size = x_shape[0]
```

Sanity check to make sure the number of features is correct

```python
        assert self.channels == x.shape[1]
```

Reshape into `[batch_size, channels, n]`

```python
        x = x.view(batch_size, self.channels, -1)
```

Update $\hat{\mu}_C$ and $\hat{\sigma}^2_C$ in training mode only

```python
        if self.training:
```

No backpropagation through $\hat{\mu}_C$ and $\hat{\sigma}^2_C$

```python
            with torch.no_grad():
```

Calculate the mean across the first and last dimensions; $\frac{1}{BHW} \sum_{b,h,w} X_{b,c,h,w}$

```python
                mean = x.mean(dim=[0, 2])
```

Calculate the squared mean across the first and last dimensions; $\frac{1}{BHW} \sum_{b,h,w} X^2_{b,c,h,w}$

```python
                mean_x2 = (x ** 2).mean(dim=[0, 2])
```

Variance for each feature, $\frac{1}{BHW} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C\big)^2$

```python
                var = mean_x2 - mean ** 2
```

Update the exponential moving averages

$$\begin{aligned}
\hat{\mu}_C &\longleftarrow (1 - r)\hat{\mu}_C + r \frac{1}{BHW} \sum_{b,h,w} X_{b,c,h,w} \\
\hat{\sigma}^2_C &\longleftarrow (1 - r)\hat{\sigma}^2_C + r \frac{1}{BHW} \sum_{b,h,w} \big(X_{b,c,h,w} - \hat{\mu}_C\big)^2
\end{aligned}$$

```python
                self.exp_mean = (1 - self.momentum) * self.exp_mean + self.momentum * mean
                self.exp_var = (1 - self.momentum) * self.exp_var + self.momentum * var
```

Normalize $\frac{X_{\cdot, C, \cdot, \cdot} - \hat{\mu}_C}{\hat{\sigma}_C}$

```python
        x_norm = (x - self.exp_mean.view(1, -1, 1)) / torch.sqrt(self.exp_var + self.eps).view(1, -1, 1)
```

Scale and shift $\gamma_C \frac{X_{\cdot, C, \cdot, \cdot} - \hat{\mu}_C}{\hat{\sigma}_C} + \beta_C$

```python
        if self.affine:
            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
```

Reshape to the original shape and return

```python
        return x_norm.view(x_shape)
```
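
Note that, unlike standard batch norm, `EstimatedBatchNorm` normalizes with the running estimates $\hat{\mu}_C$ and $\hat{\sigma}^2_C$ even in training mode; the current batch is only used to update those estimates. A minimal sketch of exercising it, with illustrative shapes:

```python
import torch

# Assumes EstimatedBatchNorm from above is in scope
ebn = EstimatedBatchNorm(channels=3)

x = torch.randn(4, 3, 8, 8)  # [batch_size, channels, height, width]
ebn.train()
y = ebn(x)  # normalized with running stats; running stats get updated

print(y.shape)             # torch.Size([4, 3, 8, 8])
print(ebn.exp_mean.shape)  # torch.Size([3])
```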

## Channel Normalization

This is similar to Group Normalization, but the affine transform is done group-wise.

```python
class ChannelNorm(nn.Module):
```

  • `groups` is the number of groups the features are divided into
  • `channels` is the number of features in the input
  • `eps` is $\epsilon$, used in $\sqrt{\mathrm{Var}[x^{(k)}] + \epsilon}$ for numerical stability
  • `affine` is whether to scale and shift the normalized value

```python
    def __init__(self, channels, groups,
                 eps: float = 1e-5, affine: bool = True):
```

```python
        super().__init__()
        self.channels = channels
        self.groups = groups
        self.eps = eps
        self.affine = affine
```

Parameters for the affine transformation.

Note that these transforms are per group, unlike in group norm where they are per channel.

```python
        if self.affine:
            self.scale = nn.Parameter(torch.ones(groups))
            self.shift = nn.Parameter(torch.zeros(groups))
```

`x` is a tensor of shape `[batch_size, channels, *]`. `*` denotes any number of (possibly 0) dimensions. For example, in an image (2D) convolution this will be `[batch_size, channels, height, width]`.

```python
    def forward(self, x: torch.Tensor):
```

Keep the original shape

```python
        x_shape = x.shape
```

Get the batch size

```python
        batch_size = x_shape[0]
```

Sanity check to make sure the number of features is the same

```python
        assert self.channels == x.shape[1]
```

Reshape into `[batch_size, groups, n]`

```python
        x = x.view(batch_size, self.groups, -1)
```

Calculate the mean across the last dimension; i.e. the mean for each sample and channel group, $\mathbb{E}[x_{(i_N, i_G)}]$

```python
        mean = x.mean(dim=[-1], keepdim=True)
```

Calculate the squared mean across the last dimension; i.e. the mean of $x^2$ for each sample and channel group, $\mathbb{E}[x^2_{(i_N, i_G)}]$

```python
        mean_x2 = (x ** 2).mean(dim=[-1], keepdim=True)
```

Variance for each sample and feature group, $\mathrm{Var}[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]^2$

```python
        var = mean_x2 - mean ** 2
```

Normalize $\hat{x}_{(i_N, i_G)} = \frac{x_{(i_N, i_G)} - \mathbb{E}[x_{(i_N, i_G)}]}{\sqrt{\mathrm{Var}[x_{(i_N, i_G)}] + \epsilon}}$

```python
        x_norm = (x - mean) / torch.sqrt(var + self.eps)
```

Scale and shift group-wise, $y_{i_G} = \gamma_{i_G} \hat{x}_{i_G} + \beta_{i_G}$

```python
        if self.affine:
            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
```

Reshape to the original shape and return

```python
        return x_norm.view(x_shape)
```
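
A minimal usage sketch with illustrative shapes; note that `scale` and `shift` have one entry per group, whereas `nn.GroupNorm` keeps per-channel affine parameters:

```python
import torch

# Assumes ChannelNorm from above is in scope
cn = ChannelNorm(channels=6, groups=2)

x = torch.randn(4, 6, 8, 8)  # [batch_size, channels, height, width]
y = cn(x)

print(y.shape)         # torch.Size([4, 6, 8, 8])
print(cn.scale.shape)  # torch.Size([2]) - one scale per group
```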
