[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/normalization/batch_channel_norm/__init__.py)
This is a PyTorch implementation of Batch-Channel Normalization from the paper Micro-Batch Training with Batch-Channel Normalization and Weight Standardization. We also have an annotated implementation of Weight Standardization.
Batch-Channel Normalization performs batch normalization followed by a channel normalization (similar to Group Normalization). When the batch size is small, a running mean and variance is used for batch normalization.
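As a rough sketch of the idea, the two steps can be composed from PyTorch's built-in `nn.BatchNorm2d` and `nn.GroupNorm` (this is not the implementation below, and the `SimpleBatchChannelNorm` name is ours; the classes defined later replace the batch-norm step with a running-statistics variant):

```python
import torch
from torch import nn


class SimpleBatchChannelNorm(nn.Module):
    """Hypothetical sketch: batch norm followed by group-wise channel norm."""

    def __init__(self, channels: int, groups: int):
        super().__init__()
        self.batch_norm = nn.BatchNorm2d(channels)
        self.channel_norm = nn.GroupNorm(groups, channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.channel_norm(self.batch_norm(x))


x = torch.randn(4, 64, 8, 8)          # [batch_size, channels, height, width]
bcn = SimpleBatchChannelNorm(64, 16)
print(bcn(x).shape)                   # torch.Size([4, 64, 8, 8])
```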
Here is the training code for a VGG network that uses weight standardization to classify CIFAR-10 data.
```python
import torch
from torch import nn

from labml_nn.normalization.batch_norm import BatchNorm
```
This first performs a batch normalization - either normal batch norm or a batch norm with estimated mean and variance (exponential mean/variance over multiple batches). Then channel normalization is performed.
```python
class BatchChannelNorm(nn.Module):
```
- `channels` is the number of features in the input
- `groups` is the number of groups the features are divided into
- `eps` is $\epsilon$, used in $\sqrt{\mathrm{Var}[x^{(k)}] + \epsilon}$ for numerical stability
- `momentum` is the momentum in taking the exponential moving average
- `estimate` is whether to use running mean and variance for batch norm

```python
def __init__(self, channels: int, groups: int,
             eps: float = 1e-5, momentum: float = 0.1, estimate: bool = True):
```
```python
super().__init__()
```
Use estimated batch norm or normal batch norm.
```python
if estimate:
    self.batch_norm = EstimatedBatchNorm(channels,
                                         eps=eps, momentum=momentum)
else:
    self.batch_norm = BatchNorm(channels,
                                eps=eps, momentum=momentum)
```
Channel normalization
```python
self.channel_norm = ChannelNorm(channels, groups, eps)
```
```python
def forward(self, x):
    x = self.batch_norm(x)
    return self.channel_norm(x)
```
The input $X \in \mathbb{R}^{B \times C \times H \times W}$ is a batch of image representations, where $B$ is the batch size, $C$ is the number of channels, $H$ is the height and $W$ is the width. The parameters are $\gamma \in \mathbb{R}^C$ and $\beta \in \mathbb{R}^C$.

$$\dot{X}_{\cdot, C, \cdot, \cdot} = \gamma_C \frac{X_{\cdot, C, \cdot, \cdot} - \hat{\mu}_C}{\hat{\sigma}_C} + \beta_C$$

where

$$\hat{\mu}_C \longleftarrow (1 - r)\hat{\mu}_C + r \frac{1}{B H W} \sum_{b, h, w} X_{b, c, h, w}$$

$$\hat{\sigma}^2_C \longleftarrow (1 - r)\hat{\sigma}^2_C + r \frac{1}{B H W} \sum_{b, h, w} \big(X_{b, c, h, w} - \hat{\mu}_C\big)^2$$

are the running mean and variance, and $r$ is the momentum for the exponential moving average.
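To make the momentum update concrete, here is a minimal numeric sketch of one update step (the values are made up):

```python
import torch

r = 0.1                                      # momentum
exp_mean = torch.zeros(3)                    # running mean, one entry per channel
batch_mean = torch.tensor([1.0, 2.0, 3.0])   # mean of the current batch

# exponential moving average: keep (1 - r) of the old value, take r of the new
exp_mean = (1 - r) * exp_mean + r * batch_mean
print(exp_mean)  # tensor([0.1000, 0.2000, 0.3000])
```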
```python
class EstimatedBatchNorm(nn.Module):
```
- `channels` is the number of features in the input
- `eps` is $\epsilon$, used in $\sqrt{\mathrm{Var}[x^{(k)}] + \epsilon}$ for numerical stability
- `momentum` is the momentum in taking the exponential moving average
- `affine` is whether to scale and shift the normalized value

```python
def __init__(self, channels: int,
             eps: float = 1e-5, momentum: float = 0.1, affine: bool = True):
```
```python
super().__init__()

self.eps = eps
self.momentum = momentum
self.affine = affine
self.channels = channels
```
Channel-wise transformation parameters
```python
if self.affine:
    self.scale = nn.Parameter(torch.ones(channels))
    self.shift = nn.Parameter(torch.zeros(channels))
```
Tensors for $\hat{\mu}_C$ and $\hat{\sigma}^2_C$
```python
self.register_buffer('exp_mean', torch.zeros(channels))
self.register_buffer('exp_var', torch.ones(channels))
```
`x` is a tensor of shape `[batch_size, channels, *]`. `*` denotes any number of (possibly 0) dimensions. For example, in an image (2D) convolution this will be `[batch_size, channels, height, width]`.
```python
def forward(self, x: torch.Tensor):
```
Keep old shape
```python
x_shape = x.shape
```
Get the batch size
```python
batch_size = x_shape[0]
```
Sanity check to make sure the number of features is correct
```python
assert self.channels == x.shape[1]
```
Reshape into [batch_size, channels, n]
```python
x = x.view(batch_size, self.channels, -1)
```
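For example, flattening the spatial dimensions of an image batch with `view`:

```python
import torch

x = torch.randn(2, 3, 4, 5)   # [batch_size, channels, height, width]
x = x.view(2, 3, -1)          # [batch_size, channels, n] with n = 4 * 5
print(x.shape)                # torch.Size([2, 3, 20])
```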
Update $\hat{\mu}_C$ and $\hat{\sigma}^2_C$ in training mode only
```python
if self.training:
```
No backpropagation through $\hat{\mu}_C$ and $\hat{\sigma}^2_C$
```python
with torch.no_grad():
```
Calculate the mean across the first and last dimensions; $\frac{1}{B H W} \sum_{b, h, w} X_{b, c, h, w}$
```python
mean = x.mean(dim=[0, 2])
```
Calculate the squared mean across the first and last dimensions; $\frac{1}{B H W} \sum_{b, h, w} X^2_{b, c, h, w}$
```python
mean_x2 = (x ** 2).mean(dim=[0, 2])
```
Variance for each feature, $\frac{1}{B H W} \sum_{b, h, w} \big(X_{b, c, h, w} - \hat{\mu}_C\big)^2$
```python
var = mean_x2 - mean ** 2
```
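This uses the identity $\mathrm{Var}[x] = \mathbb{E}[x^2] - \mathbb{E}[x]^2$. A quick sanity check against PyTorch's biased variance (a sketch with random data):

```python
import torch

x = torch.randn(4, 3, 16)                 # [batch_size, channels, n]
mean = x.mean(dim=[0, 2])
mean_x2 = (x ** 2).mean(dim=[0, 2])
var = mean_x2 - mean ** 2

# should match the biased (population) variance computed directly
print(torch.allclose(var, x.var(dim=[0, 2], unbiased=False), atol=1e-4))  # True
```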
Update exponential moving averages
$$\hat{\mu}_C \longleftarrow (1 - r)\hat{\mu}_C + r \frac{1}{B H W} \sum_{b, h, w} X_{b, c, h, w}$$

$$\hat{\sigma}^2_C \longleftarrow (1 - r)\hat{\sigma}^2_C + r \frac{1}{B H W} \sum_{b, h, w} \big(X_{b, c, h, w} - \hat{\mu}_C\big)^2$$
```python
self.exp_mean = (1 - self.momentum) * self.exp_mean + self.momentum * mean
self.exp_var = (1 - self.momentum) * self.exp_var + self.momentum * var
```
Normalize $\frac{X_{\cdot, C, \cdot, \cdot} - \hat{\mu}_C}{\hat{\sigma}_C}$
```python
x_norm = (x - self.exp_mean.view(1, -1, 1)) / torch.sqrt(self.exp_var + self.eps).view(1, -1, 1)
```
Scale and shift $\gamma_C \frac{X_{\cdot, C, \cdot, \cdot} - \hat{\mu}_C}{\hat{\sigma}_C} + \beta_C$
```python
if self.affine:
    x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
```
Reshape to original and return
```python
return x_norm.view(x_shape)
```
This is similar to Group Normalization, but the affine transform is done group-wise.
```python
class ChannelNorm(nn.Module):
```
- `channels` is the number of features in the input
- `groups` is the number of groups the features are divided into
- `eps` is $\epsilon$, used in $\sqrt{\mathrm{Var}[x^{(k)}] + \epsilon}$ for numerical stability
- `affine` is whether to scale and shift the normalized value

```python
def __init__(self, channels, groups,
             eps: float = 1e-5, affine: bool = True):
```
```python
super().__init__()
self.channels = channels
self.groups = groups
self.eps = eps
self.affine = affine
```
Parameters for affine transformation.
Note that these transforms are per group, unlike in group norm, where the affine transform is per channel.
```python
if self.affine:
    self.scale = nn.Parameter(torch.ones(groups))
    self.shift = nn.Parameter(torch.zeros(groups))
```
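To see the difference in parameter counts, compare with PyTorch's `nn.GroupNorm`, which keeps one scale and one shift per channel, while here there is only one of each per group (a small sketch):

```python
import torch
from torch import nn

gn = nn.GroupNorm(4, 32)             # PyTorch GroupNorm: affine parameters per channel
print(gn.weight.shape)               # torch.Size([32])

scale = nn.Parameter(torch.ones(4))  # ChannelNorm above: one scale per group
print(scale.shape)                   # torch.Size([4])
```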
`x` is a tensor of shape `[batch_size, channels, *]`. `*` denotes any number of (possibly 0) dimensions. For example, in an image (2D) convolution this will be `[batch_size, channels, height, width]`.
```python
def forward(self, x: torch.Tensor):
```
Keep the original shape
```python
x_shape = x.shape
```
Get the batch size
```python
batch_size = x_shape[0]
```
Sanity check to make sure the number of features is the same
```python
assert self.channels == x.shape[1]
```
Reshape into [batch_size, groups, n]
```python
x = x.view(batch_size, self.groups, -1)
```
Calculate the mean across the last dimension; i.e. the mean for each sample and channel group $\mathbb{E}[x_{(i_N, i_G)}]$
```python
mean = x.mean(dim=[-1], keepdim=True)
```
Calculate the squared mean across the last dimension; i.e. the squared mean for each sample and channel group $\mathbb{E}[x^2_{(i_N, i_G)}]$
```python
mean_x2 = (x ** 2).mean(dim=[-1], keepdim=True)
```
Variance for each sample and feature group $\mathrm{Var}[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}]^2$
```python
var = mean_x2 - mean ** 2
```
Normalize

$$\hat{x}_{(i_N, i_G)} = \frac{x_{(i_N, i_G)} - \mathbb{E}[x_{(i_N, i_G)}]}{\sqrt{\mathrm{Var}[x_{(i_N, i_G)}] + \epsilon}}$$
```python
x_norm = (x - mean) / torch.sqrt(var + self.eps)
```
Scale and shift group-wise

$$y_{i_G} = \gamma_{i_G} \hat{x}_{i_G} + \beta_{i_G}$$
```python
if self.affine:
    x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
```
Reshape to original and return
```python
return x_norm.view(x_shape)
```
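Without the affine step, this group-wise normalization should match PyTorch's `nn.GroupNorm` with `affine=False`. A quick sanity check (a sketch with random data):

```python
import torch
from torch import nn

x = torch.randn(2, 6, 4, 4)   # [batch_size, channels, height, width]
groups = 3

# normalize per sample and per channel group, as in ChannelNorm
g = x.view(x.shape[0], groups, -1)
mean = g.mean(dim=-1, keepdim=True)
var = (g ** 2).mean(dim=-1, keepdim=True) - mean ** 2
x_norm = ((g - mean) / torch.sqrt(var + 1e-5)).view(x.shape)

# compare against PyTorch's GroupNorm with affine disabled
ref = nn.GroupNorm(groups, 6, affine=False)(x)
print(torch.allclose(x_norm, ref, atol=1e-4))  # True
```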