Back to Annotated Deep Learning Paper Implementations

CIFAR10 Experiment for Group Normalization

docs/normalization/group_norm/experiment.html

latest · 1.9 KB
Original Source

home / normalization / group_norm

View code on Github

#

CIFAR10 Experiment for Group Normalization

12importtorch.nnasnn1314fromlabmlimportexperiment15fromlabml.configsimportoption16fromlabml\_nn.experiments.cifar10importCIFAR10Configs,CIFAR10VGGModel

#

VGG model for CIFAR-10 classification

This derives from the generic VGG-style architecture.

class Model(CIFAR10VGGModel):
    """
    VGG model for CIFAR-10 classification with group normalization.

    Derives from the generic VGG-style ``CIFAR10VGGModel``; only the
    convolution block is overridden so that a group normalization layer
    sits between each convolution and its ReLU.
    """

    def conv_block(self, in_channels, out_channels) -> nn.Module:
        """
        Build one block: 3x3 convolution -> group normalization -> ReLU.

        :param in_channels: number of input channels to the convolution
        :param out_channels: number of output channels of the convolution;
            these are the channels the group norm layer normalizes, so
            ``out_channels`` must be divisible by ``self.groups``
            (``torch.nn.GroupNorm`` raises otherwise)
        """
        return nn.Sequential(
            # padding=1 with a 3x3 kernel keeps the spatial size unchanged
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            # Group norm replaces the base architecture's normalization.
            # `fnorm` is not imported in this file, so use torch's
            # GroupNorm(num_groups, num_channels) via the imported `nn`.
            # NOTE(review): if this experiment is meant to exercise
            # labml_nn's own GroupNorm implementation, import and use that here.
            nn.GroupNorm(self.groups, out_channels),
            nn.ReLU(inplace=True),
        )

    def __init__(self, groups: int = 32):
        """
        :param groups: number of groups the channels of each convolution
            block are split into for group normalization
        """
        # Set before the base constructor runs: presumably the base class
        # calls `conv_block` while building its layers, and `conv_block`
        # reads `self.groups`.
        self.groups = groups
        # Per-stage output channels of the VGG-style stack
        super().__init__([[64, 64], [128, 128], [256, 256, 256], [512, 512, 512], [512, 512, 512]])

#

class Configs(CIFAR10Configs):
    """
    Configurations for the CIFAR-10 group normalization experiment.

    Extends ``CIFAR10Configs`` with one extra option that controls the
    group normalization layers.
    """

    # Number of groups to split channels into for group normalization
    groups: int = 16

#

Create model

@option(Configs.model)
def model(c: Configs):
    """
    Provide ``Configs.model``.

    Builds a group-normalized VGG ``Model`` with ``c.groups`` groups and
    moves it to ``c.device``.
    """
    net = Model(c.groups)
    return net.to(c.device)

#

def main():
    """Set up and run the CIFAR-10 group normalization experiment."""
    # Create experiment
    experiment.create(name='cifar10', comment='group norm')

    # Create configurations
    conf = Configs()

    # Load configurations: override the optimizer type and learning rate
    overrides = {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,
    }
    experiment.configs(conf, overrides)

    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()

#

# Only launch training when executed as a script, not when imported
if __name__ == '__main__':
    main()

labml.ai