[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/normalization/batch_norm/__init__.py)
This is a PyTorch implementation of Batch Normalization from the paper "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift".
The paper defines Internal Covariate Shift as the change in the distribution of network activations due to the change in network parameters during training. For example, say there are two layers $l_1$ and $l_2$. At the beginning of training, the outputs of $l_1$ (the inputs to $l_2$) could be distributed as $\mathcal{N}(0.5, 1)$. After some training steps, they could move to $\mathcal{N}(0.6, 1.5)$. This is internal covariate shift.
Internal covariate shift adversely affects training speed because the later layers ($l_2$ in the above example) have to adapt to this shifted distribution.
By stabilizing the distribution, batch normalization minimizes the internal covariate shift.
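Here is a minimal plain-Python sketch of this effect, with no PyTorch involved. It draws one mini-batch from $\mathcal{N}(0.5, 1)$ and another from $\mathcal{N}(0.6, 1.5)$, then normalizes each one with its own mini-batch statistics. The helper `normalize` and the random seed are illustrative assumptions and are not part of the implementation below. Note that `random.gauss` takes a standard deviation, so the variance 1.5 is passed as `math.sqrt(1.5)`.

```python
import math
import random
import statistics


def normalize(batch, eps=1e-5):
    """Hypothetical helper: shift and scale a 1-D batch to zero mean, unit variance."""
    mean = statistics.fmean(batch)
    var = statistics.pvariance(batch, mu=mean)  # population (biased) variance
    return [(v - mean) / math.sqrt(var + eps) for v in batch]


random.seed(0)
# Outputs of l1 early in training, and after some parameter updates
early = [random.gauss(0.5, 1.0) for _ in range(1000)]
later = [random.gauss(0.6, math.sqrt(1.5)) for _ in range(1000)]

for name, batch in [("early", early), ("later", later)]:
    out = normalize(batch)
    print(name, round(statistics.fmean(out), 3), round(statistics.pvariance(out), 3))
```

Up to the tiny effect of `eps`, both normalized batches have zero mean and unit variance. So $l_2$ sees the same input statistics before and after the shift.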
It is known that whitening improves training speed and convergence. Whitening means linearly transforming the inputs so that they have zero mean and unit variance and are uncorrelated.
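For intuition, here is a small plain-Python sketch that whitens two correlated features. It assumes Cholesky whitening, which is one of several valid whitening transforms. It is used here only because the 2×2 case has a closed form. The data and the helper names `whiten_2d` and `covariance` are illustrative.

```python
import math


def whiten_2d(xs):
    """Hypothetical helper: whiten (x1, x2) pairs using the Cholesky factor of their covariance."""
    n = len(xs)
    m1 = sum(a for a, _ in xs) / n
    m2 = sum(b for _, b in xs) / n
    # Population covariance matrix [[s11, s12], [s12, s22]]
    s11 = sum((a - m1) ** 2 for a, _ in xs) / n
    s22 = sum((b - m2) ** 2 for _, b in xs) / n
    s12 = sum((a - m1) * (b - m2) for a, b in xs) / n
    # Lower-triangular L = [[l11, 0], [l21, l22]] with L @ L.T == covariance
    l11 = math.sqrt(s11)
    l21 = s12 / l11
    l22 = math.sqrt(s22 - l21 ** 2)
    out = []
    for a, b in xs:
        # Solve L @ z = (x - mean) by forward substitution
        z1 = (a - m1) / l11
        z2 = ((b - m2) - l21 * z1) / l22
        out.append((z1, z2))
    return out


def covariance(zs):
    """Hypothetical helper: (var1, cov12, var2) of a list of pairs."""
    n = len(zs)
    m1 = sum(a for a, _ in zs) / n
    m2 = sum(b for _, b in zs) / n
    c11 = sum((a - m1) ** 2 for a, _ in zs) / n
    c22 = sum((b - m2) ** 2 for _, b in zs) / n
    c12 = sum((a - m1) * (b - m2) for a, b in zs) / n
    return c11, c12, c22


# The second feature is strongly correlated with the first
data = [(1.0, 2.1), (2.0, 3.9), (3.0, 6.2), (4.0, 7.8), (5.0, 10.1)]
print(covariance(whiten_2d(data)))  # approximately (1.0, 0.0, 1.0)
```

Every output coordinate depends on the full covariance estimate. This coupling is what makes full whitening costly for wide layers.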
Normalizing outside the gradient computation using pre-computed (detached) means and variances doesn't work. For instance (ignoring variance), let $\hat{x} = x - \mathbb{E}[x]$, where $x = u + b$, $b$ is a trained bias, and $\mathbb{E}[x]$ is computed outside the gradient computation (a pre-computed constant).
Note that $b$ has no effect on $\hat{x}$: when $b$ is updated, the recomputed $\mathbb{E}[x]$ shifts by the same amount. Therefore, $b$ will increase or decrease based on $\frac{\partial \mathcal{L}}{\partial \hat{x}}$ and keep on growing indefinitely with each training update. The paper notes that similar explosions happen with variances.
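The argument can be checked numerically. The plain-Python sketch below uses a hypothetical squared-error loss and targets chosen only for illustration. At every step it recomputes $\mathbb{E}[x]$ but treats it as a constant when taking the gradient with respect to $b$. The loss never changes, while $b$ moves by the same amount at every update.

```python
def mean(v):
    return sum(v) / len(v)


u = [1.0, 2.0, 3.0]        # outputs of the layer before the bias
target = [0.0, 1.0, 2.0]   # hypothetical regression targets for x_hat
b = 0.0                    # trained bias
lr = 0.1

history = []
for step in range(10):
    x = [ui + b for ui in u]
    e_x = mean(x)                     # "pre-computed" mean, treated as a constant
    x_hat = [xi - e_x for xi in x]
    loss = sum((xh - t) ** 2 for xh, t in zip(x_hat, target))
    # dL/db = sum_i dL/dx_hat_i * dx_hat_i/db, where dx_hat_i/db = 1 because e_x is detached
    grad_b = sum(2 * (xh - t) for xh, t in zip(x_hat, target))
    b -= lr * grad_b
    history.append((round(loss, 6), round(b, 6)))

print(history[0], history[-1])
```

Suppose $\mathbb{E}[x]$ is instead computed inside the gradient computation, as batch normalization does. Then $\partial \hat{x}_i / \partial b = 1 - 1 = 0$, and the bias stops drifting.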
Whitening is computationally expensive because you need to de-correlate the features, and the gradients must flow through the full whitening calculation.
The paper introduces a simplified version, which they call Batch Normalization. The first simplification is that it normalizes each feature independently to have zero mean and unit variance:
$$\hat{x}^{(k)} = \frac{x^{(k)} - \mathbb{E}[x^{(k)}]}{\sqrt{\operatorname{Var}[x^{(k)}]}}$$
where $x = (x^{(1)} \dots x^{(d)})$ is the $d$-dimensional input.
The second simplification is to use estimates of the mean $\mathbb{E}[x^{(k)}]$ and variance $\operatorname{Var}[x^{(k)}]$ from the mini-batch for normalization, instead of calculating the mean and variance across the whole dataset.
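Combining the two simplifications gives the plain-Python sketch below. It uses a hypothetical helper, not the module defined later. The helper normalizes a mini-batch of $d$-dimensional inputs one feature at a time, using only that mini-batch's mean and (biased) variance.

```python
import math


def batch_normalize(batch, eps=1e-5):
    """Hypothetical helper: normalize each feature (column) of a list of d-dimensional rows."""
    n, d = len(batch), len(batch[0])
    out = [[0.0] * d for _ in range(n)]
    for k in range(d):
        column = [row[k] for row in batch]
        mean_k = sum(column) / n                              # E[x^(k)] over the mini-batch
        var_k = sum((v - mean_k) ** 2 for v in column) / n    # Var[x^(k)] over the mini-batch
        for i, v in enumerate(column):
            out[i][k] = (v - mean_k) / math.sqrt(var_k + eps)
    return out


# A mini-batch of 4 inputs with d = 2 features on very different scales
batch = [[1.0, 100.0], [2.0, 200.0], [3.0, 300.0], [4.0, 400.0]]
for row in batch_normalize(batch):
    print([round(v, 3) for v in row])
```

The second feature is 100 times larger than the first, yet both end up with almost the same normalized values. This is because each feature is normalized with its own statistics.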
Normalizing each feature to zero mean and unit variance could affect what the layer can represent. For example, the paper points out that if the inputs to a sigmoid are normalized, most of them will fall within the $[-1, 1]$ range, where the sigmoid is nearly linear. To overcome this, each feature is scaled and shifted by two trained parameters $\gamma^{(k)}$ and $\beta^{(k)}$:
$$y^{(k)} = \gamma^{(k)} \hat{x}^{(k)} + \beta^{(k)}$$
where $y^{(k)}$ is the output of the batch normalization layer.
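The trained parameters can undo the normalization when the network needs that. As the paper points out, choosing $\gamma^{(k)} = \sqrt{\operatorname{Var}[x^{(k)}] + \epsilon}$ and $\beta^{(k)} = \mathbb{E}[x^{(k)}]$ gives back the original activations. Below is a minimal plain-Python check for a single feature. The values and the helper `scale_and_shift` are illustrative.

```python
import math

x = [0.2, 1.5, -0.7, 3.1]          # activations of one feature over a mini-batch
eps = 1e-5
mean = sum(x) / len(x)
var = sum((v - mean) ** 2 for v in x) / len(x)
x_hat = [(v - mean) / math.sqrt(var + eps) for v in x]


def scale_and_shift(x_hat, gamma, beta):
    """Hypothetical helper: y = gamma * x_hat + beta for a single feature."""
    return [gamma * v + beta for v in x_hat]


# With gamma = sqrt(Var + eps) and beta = E[x], the layer reproduces its input
y = scale_and_shift(x_hat, gamma=math.sqrt(var + eps), beta=mean)
print([round(v, 6) for v in y])
```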
Note that when applying batch normalization after a linear transform like $Wu + b$, the bias parameter $b$ gets cancelled by the normalization. So you can and should omit the bias parameter in linear transforms right before batch normalization.
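Here is a quick plain-Python check of the cancellation for one feature, with illustrative values. Adding any constant bias $b$ to every pre-activation leaves the normalized output unchanged, apart from floating-point rounding.

```python
import math


def normalize(values, eps=1e-5):
    """Hypothetical helper: normalize one feature over the mini-batch."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / math.sqrt(var + eps) for v in values]


wu = [0.3, -1.2, 2.4, 0.9]   # W u for one output feature across a mini-batch
b = 5.0                      # the bias of the linear layer
with_bias = normalize([v + b for v in wu])
without_bias = normalize(wu)
max_diff = max(abs(p - q) for p, q in zip(with_bias, without_bias))
print(max_diff)  # tiny: only floating-point rounding remains
```

In PyTorch, you drop the bias by passing `bias=False` to the preceding `nn.Linear` or `nn.Conv2d` layer.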
Batch normalization also makes backpropagation invariant to the scale of the weights. Empirically, it improves generalization, so it has a regularization effect too.
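The scale invariance can also be checked numerically. This plain-Python sketch computes one output feature as $\text{BN}(w \cdot u)$. It uses toy data, a hypothetical linear loss on the normalized output, and finite-difference gradients. Multiplying the weights by $a = 10$ leaves the loss unchanged and divides the gradient with respect to the weights by $a$. This is the property the paper derives: larger weights receive smaller gradients. Here $\epsilon$ is set to $0$ so the invariance is exact. With $\epsilon > 0$ it holds only approximately.

```python
import math


def normalize_feature(values, eps=0.0):
    """Hypothetical helper: normalize one feature over the mini-batch (eps=0 keeps invariance exact)."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / math.sqrt(var + eps) for v in values]


# A mini-batch of 2-dimensional inputs u, and fixed coefficients c for a toy loss
u = [(1.0, 0.5), (-0.3, 2.0), (0.8, -1.1), (2.2, 0.4)]
c = [1.0, -2.0, 0.5, 3.0]


def loss(w1, w2):
    # One output feature y = BN(w . u); toy loss L = sum_i c_i * y_i
    pre = [w1 * u1 + w2 * u2 for u1, u2 in u]
    return sum(ci * yi for ci, yi in zip(c, normalize_feature(pre)))


def grad(w1, w2, h=1e-6):
    # Central finite differences with respect to each weight
    g1 = (loss(w1 + h, w2) - loss(w1 - h, w2)) / (2 * h)
    g2 = (loss(w1, w2 + h) - loss(w1, w2 - h)) / (2 * h)
    return g1, g2


w = (0.7, -0.4)
a = 10.0
print(loss(*w), loss(a * w[0], a * w[1]))     # equal up to rounding
g, g_scaled = grad(*w), grad(a * w[0], a * w[1])
print(g, [a * gs for gs in g_scaled])          # a * (gradient at a*w) matches the gradient at w
```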
We need to know $\mathbb{E}[x^{(k)}]$ and $\operatorname{Var}[x^{(k)}]$ in order to perform the normalization. So during inference, you have two options. You can go through the whole dataset (or part of it) to find the mean and variance. Or you can use an estimate calculated during training. The usual practice is to calculate an exponential moving average of the mean and variance during the training phase and use that for inference.
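Below is a minimal plain-Python sketch of that exponential moving average. It uses the same update rule and the same default `momentum` of 0.1 as the module below. PyTorch's `nn.BatchNorm2d` uses the same convention for `momentum`. The per-batch means are made-up values that alternate around 2.0, and `update_running` is a hypothetical helper.

```python
def update_running(running, batch_stat, momentum=0.1):
    """Hypothetical helper: one exponential-moving-average step."""
    return (1 - momentum) * running + momentum * batch_stat


running_mean = 0.0  # starts at zero, like the exp_mean buffer
# Made-up per-batch means that alternate around 2.0
batch_means = [2.0 + 0.1 * (-1) ** i for i in range(100)]
for m in batch_means:
    running_mean = update_running(running_mean, m)
print(round(running_mean, 3))
```

The running estimate settles close to 2.0 and smooths out most of the batch-to-batch fluctuation. That smoothed value is what inference uses in place of a fresh mini-batch mean.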
Here's the training code and a notebook for training a CNN classifier that uses batch normalization on the MNIST dataset.
```python
import torch
from torch import nn
```
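The batch normalization layer $\text{BN}$ normalizes the input $X$ as follows:

When the input $X \in \mathbb{R}^{B \times C \times H \times W}$ is a batch of image representations, $B$ is the batch size, $C$ is the number of channels, $H$ is the height, and $W$ is the width. The parameters are $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$.
$$\text{BN}(X) = \gamma \frac{X - \underset{B, H, W}{\mathbb{E}}[X]}{\sqrt{\underset{B, H, W}{\operatorname{Var}}[X] + \epsilon}} + \beta$$

When the input $X \in \mathbb{R}^{B \times C}$ is a batch of embeddings, $B$ is the batch size and $C$ is the number of features. The parameters are $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$.
$$\text{BN}(X) = \gamma \frac{X - \underset{B}{\mathbb{E}}[X]}{\sqrt{\underset{B}{\operatorname{Var}}[X] + \epsilon}} + \beta$$

When the input $X \in \mathbb{R}^{B \times C \times L}$ is a batch of sequence embeddings, $B$ is the batch size, $C$ is the number of features, and $L$ is the length of the sequence. The parameters are $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$.
$$\text{BN}(X) = \gamma \frac{X - \underset{B, L}{\mathbb{E}}[X]}{\sqrt{\underset{B, L}{\operatorname{Var}}[X] + \epsilon}} + \beta$$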
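The three cases differ only in which axes are averaged over, which is every axis except the channel axis $C$. The plain-Python sketch below covers the sequence case $X \in \mathbb{R}^{B \times C \times L}$. It uses nested lists indexed as `x[b][c][l]`, which is an illustrative layout, and a hypothetical helper `channel_stats`. The helper computes one mean and one variance per channel by pooling over both $B$ and $L$. The image case would pool over $B$, $H$ and $W$ in the same way.

```python
def channel_stats(x):
    """Hypothetical helper: per-channel mean and biased variance of x[b][c][l]."""
    batch, channels = len(x), len(x[0])
    means, variances = [], []
    for c in range(channels):
        values = [v for b in range(batch) for v in x[b][c]]   # all B * L values of channel c
        m = sum(values) / len(values)
        means.append(m)
        variances.append(sum((v - m) ** 2 for v in values) / len(values))
    return means, variances


# B = 2 sequences, C = 2 channels, L = 3 steps
x = [
    [[1.0, 2.0, 3.0], [10.0, 10.0, 10.0]],
    [[4.0, 5.0, 6.0], [20.0, 20.0, 20.0]],
]
print(channel_stats(x))  # means [3.5, 15.0], variances [35/12, 25.0]
```

The module below does the same pooling. It reshapes the input to `[batch_size, channels, n]` and averages over dimensions 0 and 2.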
```python
class BatchNorm(nn.Module):
```
* `channels` is the number of features in the input
* `eps` is $\epsilon$, used in $\sqrt{\operatorname{Var}[x^{(k)}] + \epsilon}$ for numerical stability
* `momentum` is the momentum in taking the exponential moving average
* `affine` is whether to scale and shift the normalized value
* `track_running_stats` is whether to calculate the moving averages of mean and variance

We've tried to use the same names for arguments as the PyTorch `BatchNorm` implementation.
```python
    def __init__(self, channels: int, *,
                 eps: float = 1e-5, momentum: float = 0.1,
                 affine: bool = True, track_running_stats: bool = True):
        super().__init__()

        self.channels = channels

        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
```
Create parameters $\gamma$ and $\beta$ for scale and shift
```python
        if self.affine:
            self.scale = nn.Parameter(torch.ones(channels))
            self.shift = nn.Parameter(torch.zeros(channels))
```
Create buffers to store the exponential moving averages of the mean $\mathbb{E}[x^{(k)}]$ and variance $\operatorname{Var}[x^{(k)}]$
```python
        if self.track_running_stats:
            self.register_buffer('exp_mean', torch.zeros(channels))
            self.register_buffer('exp_var', torch.ones(channels))
```
`x` is a tensor of shape `[batch_size, channels, *]`. `*` denotes any number of (possibly 0) dimensions. For example, in an image (2D) convolution this will be `[batch_size, channels, height, width]`.
```python
    def forward(self, x: torch.Tensor):
```
Keep the original shape
```python
        x_shape = x.shape
```
Get the batch size
```python
        batch_size = x_shape[0]
```
Sanity check to make sure the number of features is the same
```python
        assert self.channels == x.shape[1]
```
Reshape into [batch_size, channels, n]
```python
        # reshape (unlike view) also accepts non-contiguous inputs
        x = x.reshape(batch_size, self.channels, -1)
```
We will calculate the mini-batch mean and variance if we are in training mode or if we have not tracked exponential moving averages
```python
        if self.training or not self.track_running_stats:
```
Calculate the mean across the first and last dimensions, i.e. the mean of each feature $\mathbb{E}[x^{(k)}]$
```python
            mean = x.mean(dim=[0, 2])
```
Calculate the mean of the squares across the first and last dimensions, i.e. $\mathbb{E}[(x^{(k)})^2]$ for each feature
```python
            mean_x2 = (x ** 2).mean(dim=[0, 2])
```
Variance for each feature $\operatorname{Var}[x^{(k)}] = \mathbb{E}[(x^{(k)})^2] - \mathbb{E}[x^{(k)}]^2$
```python
            var = mean_x2 - mean ** 2
```
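The identity $\operatorname{Var}[x] = \mathbb{E}[x^2] - \mathbb{E}[x]^2$ is exact in real arithmetic, and it lets the module compute both statistics in one pass. In floating point, however, the subtraction can cancel badly when the mean is large compared with the spread. The plain-Python illustration below uses float64 and made-up values. float32, the usual PyTorch default, runs into this at much smaller offsets. Both helper functions are hypothetical.

```python
def one_pass_var(xs):
    # E[x^2] - E[x]^2, the same form as the line above
    n = len(xs)
    mean = sum(xs) / n
    mean_x2 = sum(v * v for v in xs) / n
    return mean_x2 - mean * mean


def two_pass_var(xs):
    # Subtract the mean first, then average the squared deviations
    n = len(xs)
    mean = sum(xs) / n
    return sum((v - mean) ** 2 for v in xs) / n


xs = [1e8 + 1, 1e8 + 2, 1e8 + 3]   # true (biased) variance is 2/3
print(two_pass_var(xs))             # 0.6666666666666666
print(one_pass_var(xs))             # 0.0: the spread is lost entirely
```

Inside `forward`, the two-pass alternative would be `((x - mean.view(1, -1, 1)) ** 2).mean(dim=[0, 2])`, at the cost of a second pass over `x`.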
Update exponential moving averages
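```python
            if self.training and self.track_running_stats:
                # Update the running averages outside autograd so the buffers
                # don't keep a reference to this step's computation graph
                with torch.no_grad():
                    self.exp_mean = (1 - self.momentum) * self.exp_mean + self.momentum * mean
                    self.exp_var = (1 - self.momentum) * self.exp_var + self.momentum * var
```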
Use exponential moving averages as estimates
```python
        else:
            mean = self.exp_mean
            var = self.exp_var
```
Normalize $\hat{x}^{(k)} = \frac{x^{(k)} - \mathbb{E}[x^{(k)}]}{\sqrt{\operatorname{Var}[x^{(k)}] + \epsilon}}$
```python
        x_norm = (x - mean.view(1, -1, 1)) / torch.sqrt(var + self.eps).view(1, -1, 1)
```
Scale and shift $y^{(k)} = \gamma^{(k)} \hat{x}^{(k)} + \beta^{(k)}$
```python
        if self.affine:
            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
```
Reshape to original and return
```python
        return x_norm.view(x_shape)
```
Simple test
```python
def _test():
    from labml.logger import inspect

    x = torch.zeros([2, 3, 2, 4])
    inspect(x.shape)
    bn = BatchNorm(3)

    x = bn(x)
    inspect(x.shape)
    inspect(bn.exp_var.shape)
```
```python
if __name__ == '__main__':
    _test()
```