Back to Annotated Deep Learning Paper Implementations

MNIST Experiment for Batch Normalization

docs/normalization/batch_norm/mnist.html

latest · 2.6 KB
Original Source

home › normalization › batch_norm

View code on Github

#

MNIST Experiment for Batch Normalization

import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data

from labml import experiment
from labml.configs import option
from labml_nn.experiments.mnist import MNISTConfigs
from labml_nn.normalization.batch_norm import BatchNorm

#

Model definition

class Model(nn.Module):
    """Simple CNN for MNIST classification using the annotated `BatchNorm`.

    Architecture: conv(1->20, 5x5) -> BN -> ReLU -> max-pool(2) ->
    conv(20->50, 5x5) -> BN -> ReLU -> max-pool(2) -> flatten ->
    fc(800->500) -> BN -> ReLU -> fc(500->10).
    """

    def __init__(self):
        super().__init__()
        # Note that we omit the bias parameter: batch normalization's
        # learned shift makes a preceding bias redundant.
        self.conv1 = nn.Conv2d(1, 20, 5, 1, bias=False)
        # Batch normalization with 20 channels (output of convolution layer).
        # The input to this layer will have shape [batch_size, 20, height(24), width(24)]
        self.bn1 = BatchNorm(20)
        self.conv2 = nn.Conv2d(20, 50, 5, 1, bias=False)
        # Batch normalization with 50 channels.
        # The input to this layer will have shape [batch_size, 50, height(8), width(8)]
        self.bn2 = BatchNorm(50)
        self.fc1 = nn.Linear(4 * 4 * 50, 500, bias=False)
        # Batch normalization with 500 channels (output of fully connected layer).
        # The input to this layer will have shape [batch_size, 500]
        self.bn3 = BatchNorm(500)
        # Final classifier head (keeps its bias; no BN follows it)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x: torch.Tensor):
        """Compute class logits for a batch of 1x28x28 MNIST images."""
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2, 2)
        # Flatten the 50x4x4 feature maps for the fully connected layers
        x = x.view(-1, 4 * 4 * 50)
        x = F.relu(self.bn3(self.fc1(x)))
        return self.fc2(x)

#

Create model

We use MNISTConfigs configurations and set a new function to calculate the model.

@option(MNISTConfigs.model)
def model(c: MNISTConfigs):
    """Create the model and move it to the configured device.

    Registered as the calculator for `MNISTConfigs.model` so the shared
    MNIST experiment configuration uses this batch-norm network.
    """
    return Model().to(c.device)

#

def main():
    """Run the MNIST batch-normalization experiment."""
    # Create experiment
    experiment.create(name='mnist_batch_norm')
    # Create configurations
    conf = MNISTConfigs()
    # Load configurations
    experiment.configs(conf, {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 0.001,
    })
    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()

#

if __name__ == '__main__':
    main()

labml.ai