
CIFAR10 Experiment to try Weight Standardization and Batch-Channel Normalization

import torch.nn as nn

from labml import experiment
from labml.configs import option
from labml_nn.experiments.cifar10 import CIFAR10Configs, CIFAR10VGGModel
from labml_nn.normalization.batch_channel_norm import BatchChannelNorm
from labml_nn.normalization.weight_standardization.conv2d import Conv2d


VGG model for CIFAR-10 classification

This derives from the generic VGG-style architecture in labml_nn.experiments.cifar10.

class Model(CIFAR10VGGModel):


    def conv_block(self, in_channels, out_channels) -> nn.Module:
        # Convolution block: weight-standardized convolution, batch-channel normalization, ReLU
        return nn.Sequential(
            Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            BatchChannelNorm(out_channels, 32),
            nn.ReLU(inplace=True),
        )


    def __init__(self):
        # Channel sizes for each group of convolution blocks
        super().__init__([[64, 64], [128, 128], [256, 256, 256], [512, 512, 512], [512, 512, 512]])
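Conv2d above is the weight-standardized convolution from labml_nn, and BatchChannelNorm is the batch-channel normalization layer; the second argument (32) appears to set the number of groups used for the channel-normalization step. As a rough, minimal sketch of the two ideas (assumed simplifications, not the library's actual implementations; the class names below are hypothetical):

import torch
import torch.nn as nn
import torch.nn.functional as F


class WSConv2dSketch(nn.Conv2d):
    """Sketch of weight standardization: standardize each output-channel
    filter to zero mean and unit variance before convolving."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Statistics are taken over (in_channels, kernel_h, kernel_w) for each filter
        mean = self.weight.mean(dim=(1, 2, 3), keepdim=True)
        var = self.weight.var(dim=(1, 2, 3), keepdim=True)
        weight = (self.weight - mean) / torch.sqrt(var + 1e-5)
        return F.conv2d(x, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)


class BatchChannelNormSketch(nn.Module):
    """Sketch of batch-channel normalization: a batch normalization
    followed by a channel (group) normalization."""

    def __init__(self, channels: int, groups: int):
        super().__init__()
        self.batch_norm = nn.BatchNorm2d(channels)
        self.channel_norm = nn.GroupNorm(groups, channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.channel_norm(self.batch_norm(x))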


Create model

@option(CIFAR10Configs.model)
def _model(c: CIFAR10Configs):


    return Model().to(c.device)
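As a quick sanity check (not part of the original experiment), the model should map a batch of CIFAR-10 sized images to 10 class logits:

import torch

model = Model()
x = torch.randn(4, 3, 32, 32)  # a batch of 4 RGB 32x32 images
logits = model(x)
print(logits.shape)  # expected: torch.Size([4, 10])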


def main():


Create experiment

    experiment.create(name='cifar10', comment='weight standardization')


Create configurations

    conf = CIFAR10Configs()


Load configurations

    experiment.configs(conf, {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,
        'train_batch_size': 64,
    })


Start the experiment and run the training loop

    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()
