Instance Normalization

[View code on GitHub](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/normalization/instance_norm/__init__.py)

This is a PyTorch implementation of the paper Instance Normalization: The Missing Ingredient for Fast Stylization.

Instance normalization was introduced to improve style transfer. It is based on the observation that stylization should not depend on the contrast of the content image. The "contrast normalization" is

$$y_{t,i,j,k} = \frac{x_{t,i,j,k}}{\sum_{l=1}^{H} \sum_{m=1}^{W} x_{t,i,l,m}}$$

where $x$ is a batch of images, indexed by image $t$, feature channel $i$, and spatial position $(j, k)$; $H$ and $W$ are the image height and width.
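As a rough illustration of the formula above, the sketch below divides every value by the spatial sum of its own image and channel. The function name `contrast_normalize` and the strictly positive random input are assumptions made for this example; they are not part of the paper or of this implementation.

```python
import torch


def contrast_normalize(x: torch.Tensor) -> torch.Tensor:
    """Divide each value by the spatial sum of its own image and channel.

    `x` has shape [batch_size, channels, height, width].
    """
    # Sum over the spatial positions (l, m) for every image t and channel i
    spatial_sum = x.sum(dim=(2, 3), keepdim=True)
    return x / spatial_sum


torch.manual_seed(0)
# Strictly positive "images" so that the spatial sums are non-zero
x = torch.rand(2, 3, 4, 4) + 0.1
y = contrast_normalize(x)

# Rescaling the first image leaves its normalized output unchanged
x_scaled = x.clone()
x_scaled[0] *= 5.0
y_scaled = contrast_normalize(x_scaled)
```

Each `[height, width]` slice of `y` sums to one, and `y_scaled` matches `y` up to floating point error, which is the sense in which the output does not depend on the contrast of the input.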

Since it is hard for a convolutional network to learn "contrast normalization" on its own, the paper introduces instance normalization, a layer that performs this normalization explicitly.

Here's a CIFAR 10 classification model that uses instance normalization.

    import torch
    from torch import nn

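Instance Normalization Layer

Instance normalization layer $\text{IN}$ normalizes the input $X$ as follows.

The input $X \in \mathbb{R}^{B \times C \times H \times W}$ is a batch of image representations, where $B$ is the batch size, $C$ is the number of channels, $H$ is the height and $W$ is the width. $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$ are the scale and shift parameters. The affine transformation with $\gamma$ and $\beta$ is optional.

$$\text{IN}(X) = \gamma \frac{X - \underset{H, W}{\mathbb{E}}[X]}{\sqrt{\underset{H, W}{\mathrm{Var}}[X] + \epsilon}} + \beta$$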
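The formula can be checked numerically with a minimal sketch that applies it directly, without the affine part, and compares the result with PyTorch's built-in `nn.InstanceNorm2d`. The helper name `instance_norm_reference` is hypothetical, and the use of the biased variance (dividing by $H \times W$) is an assumption that matches the mean-of-squares computation used in this implementation.

```python
import torch
from torch import nn


def instance_norm_reference(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Apply the IN formula (no affine part) to a [B, C, H, W] tensor."""
    # Mean and biased variance over H and W, separately for every image and channel
    mean = x.mean(dim=(2, 3), keepdim=True)
    var = x.var(dim=(2, 3), unbiased=False, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps)


torch.manual_seed(0)
x = torch.randn(2, 6, 8, 8)

manual = instance_norm_reference(x)
# Built-in layer with the affine transformation switched off
builtin = nn.InstanceNorm2d(6, eps=1e-5, affine=False)(x)

max_abs_diff = (manual - builtin).abs().max().item()
print(max_abs_diff)
```

The two results should agree up to floating point error.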

    class InstanceNorm(nn.Module):

  • channels is the number of features in the input
  • eps is $\epsilon$, used in $\sqrt{\mathrm{Var}[X] + \epsilon}$ for numerical stability
  • affine is whether to scale and shift the normalized value

        def __init__(self, channels: int, *,
                     eps: float = 1e-5, affine: bool = True):

            super().__init__()

            self.channels = channels

            self.eps = eps
            self.affine = affine

Create parameters for $\gamma$ and $\beta$ for scale and shift

            if self.affine:
                self.scale = nn.Parameter(torch.ones(channels))
                self.shift = nn.Parameter(torch.zeros(channels))

x is a tensor of shape [batch_size, channels, *], where * denotes any number of (possibly 0) dimensions. For example, in an image (2D) convolution this will be [batch_size, channels, height, width].

        def forward(self, x: torch.Tensor):

Keep the original shape

            x_shape = x.shape

Get the batch size

            batch_size = x_shape[0]

Sanity check to make sure the number of channels in the input matches self.channels

            assert self.channels == x.shape[1]

Reshape into [batch_size, channels, n]

            x = x.view(batch_size, self.channels, -1)
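To make the reshape step concrete, here is a small sketch showing that flattening the trailing dimensions gives the same `[batch_size, channels, n]` layout for sequences, images, and volumes. The input sizes and dictionary keys are arbitrary choices for illustration.

```python
import torch

batch_size, channels = 2, 6

# Inputs with 1, 2 and 3 trailing dimensions
inputs = {
    "sequence": torch.zeros(batch_size, channels, 10),
    "image": torch.zeros(batch_size, channels, 4, 5),
    "volume": torch.zeros(batch_size, channels, 2, 3, 4),
}

# After the reshape every input is [batch_size, channels, n],
# where n is the product of the trailing dimensions
flattened_shapes = {
    name: tuple(x.view(batch_size, channels, -1).shape)
    for name, x in inputs.items()
}
print(flattened_shapes)
```

Note that `view` needs a memory layout compatible with the new shape. For inputs that are not contiguous, `reshape` is a common alternative, at the cost of a possible copy.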

Calculate the mean across the last dimension, i.e. the mean of each feature channel of each instance, $\mathbb{E}[x_{t,i}]$

            mean = x.mean(dim=[-1], keepdim=True)

Calculate the mean of the squared values across the last dimension, i.e. $\mathbb{E}[x_{t,i}^2]$ for each feature channel of each instance

            mean_x2 = (x ** 2).mean(dim=[-1], keepdim=True)

Variance for each feature $\mathrm{Var}[x_{t,i}] = \mathbb{E}[x_{t,i}^2] - \mathbb{E}[x_{t,i}]^2$

            var = mean_x2 - mean ** 2
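The variance identity used here can be compared with the direct definition in a short sketch. The tensor sizes and variable names below are arbitrary and only serve the illustration.

```python
import torch

torch.manual_seed(0)
# Already flattened input of shape [batch_size, channels, n]
x = torch.randn(2, 6, 32)

mean = x.mean(dim=-1, keepdim=True)
mean_x2 = (x ** 2).mean(dim=-1, keepdim=True)

# Variance from the identity Var[x] = E[x^2] - E[x]^2
var_identity = mean_x2 - mean ** 2
# Biased variance computed directly from the deviations around the mean
var_direct = ((x - mean) ** 2).mean(dim=-1, keepdim=True)

max_abs_diff = (var_identity - var_direct).abs().max().item()
print(max_abs_diff)
```

For well-scaled inputs like these the two agree closely. In floating point the identity can lose precision when the mean is large compared with the spread of the values, so the direct form may be the more robust choice for such inputs.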

Normalize $$\hat{x}_{t,i} = \frac{x_{t,i} - \mathbb{E}[x_{t,i}]}{\sqrt{\mathrm{Var}[x_{t,i}] + \epsilon}}$$

            x_norm = (x - mean) / torch.sqrt(var + self.eps)
            x_norm = x_norm.view(batch_size, self.channels, -1)
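The role of $\epsilon$ in the normalization step can be seen with a constant input, whose variance is zero. This is a minimal sketch using standalone variables rather than the module itself; the input value and sizes are chosen only so that the arithmetic is exact.

```python
import torch

eps = 1e-5
# A constant input has zero variance in every (instance, channel) slice
x = torch.full((2, 6, 8), 3.0)

mean = x.mean(dim=-1, keepdim=True)
var = (x ** 2).mean(dim=-1, keepdim=True) - mean ** 2

# With eps the denominator is sqrt(eps) and the output is zero
with_eps = (x - mean) / torch.sqrt(var + eps)
# Without eps this is 0 / 0, which gives NaN
without_eps = (x - mean) / torch.sqrt(var)
```

Here `with_eps` is all zeros while `without_eps` is all NaN.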

Scale and shift $$y_{t,i} = \gamma_i \hat{x}_{t,i} + \beta_i$$

            if self.affine:
                x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
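The `view(1, -1, 1)` calls turn the per-channel parameters into shape `[1, channels, 1]` so that they broadcast over the batch and the flattened positions. A small sketch with hand-picked, hypothetical parameter values:

```python
import torch

batch_size, channels, n = 2, 3, 4
# Stand-in for the normalized input, all ones for readability
x_norm = torch.ones(batch_size, channels, n)

# Hypothetical per-channel parameters, chosen by hand for illustration
scale = torch.tensor([1.0, 2.0, 3.0])
shift = torch.tensor([0.0, 10.0, 20.0])

# [channels] -> [1, channels, 1], broadcast over batch and positions
y = scale.view(1, -1, 1) * x_norm + shift.view(1, -1, 1)

# One value per channel: scale[i] * 1 + shift[i]
print(y[0, :, 0])
```

Every position of channel `i` receives the same `scale[i]` and `shift[i]`, regardless of which image in the batch it belongs to.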

Reshape to original and return

            return x_norm.view(x_shape)

Simple test

    def _test():

        from labml.logger import inspect

        x = torch.zeros([2, 6, 2, 4])
        inspect(x.shape)
        instance_norm = InstanceNorm(6)

        x = instance_norm(x)
        inspect(x.shape)

    if __name__ == '__main__':
        _test()
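The simple test above only inspects shapes on an all-zero input. A slightly stronger check, sketched below, uses random input and looks at the statistics of the output. To keep the sketch self-contained it uses PyTorch's `torch.nn.functional.instance_norm` as a stand-in for the layer defined here, which is an assumption made for this example.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Random input with a non-zero mean and non-unit variance
x = torch.randn(4, 6, 8, 8) * 3.0 + 2.0

# Stand-in for the layer above: functional instance norm without affine
y = F.instance_norm(x, eps=1e-5)

# Every (instance, channel) slice should have roughly zero mean and unit variance
slice_mean = y.mean(dim=(2, 3))
slice_var = y.var(dim=(2, 3), unbiased=False)

# Normalizing the first image on its own should match its result within the batch
y_single = F.instance_norm(x[:1], eps=1e-5)
batch_independent = torch.allclose(y[:1], y_single, atol=1e-5)
print(batch_independent)
```

The last check reflects the main difference from batch normalization: the statistics are computed per instance, so the output for one image does not depend on the other images in the batch.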
