docs/normalization/group_norm/experiment.html
12importtorch.nnasnn1314fromlabmlimportexperiment15fromlabml.configsimportoption16fromlabml\_nn.experiments.cifar10importCIFAR10Configs,CIFAR10VGGModel
This model derives from the generic VGG-style architecture.
class Model(CIFAR10VGGModel):
    """
    VGG-style model for CIFAR-10 classification that uses Group Normalization
    after every convolution.

    Derives from the generic VGG architecture `CIFAR10VGGModel`, which builds
    the network by repeatedly calling `conv_block`.
    """

    def conv_block(self, in_channels, out_channels) -> nn.Module:
        """
        Create a convolution block: 3x3 convolution (padding 1), followed by
        group normalization over `self.groups` groups, and an in-place ReLU.

        `in_channels` / `out_channels` are the input/output channel counts of
        the convolution.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            # Fix: the original referenced the undefined name `fnorm`;
            # the group normalization layer comes from `torch.nn`.
            nn.GroupNorm(self.groups, out_channels),
            nn.ReLU(inplace=True),
        )

    def __init__(self, groups: int = 32):
        # `self.groups` must be assigned *before* calling `super().__init__`,
        # because the base constructor invokes `conv_block`, which reads it.
        self.groups = groups
        # Channel layout of the five VGG stages.
        super().__init__([[64, 64], [128, 128], [256, 256, 256],
                          [512, 512, 512], [512, 512, 512]])
class Configs(CIFAR10Configs):
    """
    Configurations for the CIFAR-10 group-normalization experiment,
    extending the shared CIFAR-10 training configurations.
    """

    # Number of groups used by each group-normalization layer
    groups: int = 16
@option(Configs.model)
def model(c: Configs):
    """
    Build the group-norm VGG model with `c.groups` groups per
    normalization layer, and move it to the configured device.
    """
    net = Model(c.groups)
    return net.to(c.device)
def main():
    """Set up and run the CIFAR-10 group-normalization training experiment."""
    # Create the experiment
    experiment.create(name='cifar10', comment='group norm')
    # Create the configurations
    conf = Configs()
    # Load configuration overrides for the optimizer
    overrides = {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,
    }
    experiment.configs(conf, overrides)
    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()
# Script entry point: run the experiment when executed directly.
if __name__ == '__main__':
    main()