This trains a large model on CIFAR 10, which is then used as the teacher for knowledge distillation.
import torch.nn as nn

from labml import experiment, logger
from labml.configs import option
from labml_nn.experiments.cifar10 import CIFAR10Configs, CIFAR10VGGModel
from labml_nn.normalization.batch_norm import BatchNorm
We use CIFAR10Configs, which defines all the dataset-related configurations, the optimizer, and the training loop.
class Configs(CIFAR10Configs):
    pass
This derives from the generic VGG-style architecture.
class LargeModel(CIFAR10VGGModel):
Create a convolution layer and the activations
    def conv_block(self, in_channels, out_channels) -> nn.Module:
        return nn.Sequential(
Dropout
            nn.Dropout(0.1),
Convolution layer
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
Batch normalization
            BatchNorm(out_channels, track_running_stats=False),
ReLU activation
            nn.ReLU(inplace=True),
        )
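Because the convolution uses kernel_size=3 with padding=1, each block keeps the 32x32 spatial size of CIFAR-10 images and only changes the channel count. A minimal sketch of that behavior, using PyTorch's own nn.BatchNorm2d in place of the annotated BatchNorm:

import torch

block = nn.Sequential(
    nn.Dropout(0.1),
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    nn.BatchNorm2d(64, track_running_stats=False),
    nn.ReLU(inplace=True),
)
x = torch.randn(8, 3, 32, 32)           # a batch of CIFAR-10 sized images
assert block(x).shape == (8, 64, 32, 32)  # spatial size preserved, channels 3 -> 64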
    def __init__(self):
Create a model with given convolution sizes (channels)
        super().__init__([[64, 64], [128, 128], [256, 256, 256], [512, 512, 512], [512, 512, 512]])
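These five channel groups follow the VGG-16 layout. The assembly itself happens in CIFAR10VGGModel; as a rough sketch of what such a VGG-style base typically does (an assumption for illustration, not the library's exact code), each group's blocks are chained and followed by a 2x2 max-pool, so five groups reduce the 32x32 input to 1x1 before a linear classifier over the 10 classes:

def make_vgg_layers(blocks, conv_block, in_channels=3):
    # Assumed expansion of a nested channel list into a VGG-style feature extractor
    layers = []
    for group in blocks:
        # Chain the convolution blocks of one group
        for out_channels in group:
            layers.append(conv_block(in_channels, out_channels))
            in_channels = out_channels
        # Halve the spatial size after each group: 32 -> 16 -> 8 -> 4 -> 2 -> 1
        layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
    return nn.Sequential(*layers)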
Create the model

@option(Configs.model)
def _large_model(c: Configs):
    return LargeModel().to(c.device)
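@option registers _large_model as the way to compute Configs.model; labml evaluates it lazily, so the model is built on the configured device the first time conf.model is accessed in main below. A toy illustration of the same pattern, assuming only labml's BaseConfigs/option API:

from labml.configs import BaseConfigs

class ToyConfigs(BaseConfigs):
    size: int = 16
    model: nn.Module

@option(ToyConfigs.model)
def _toy_model(c: ToyConfigs):
    # Runs on first access to `.model`, after options such as `size` are resolved
    return nn.Linear(c.size, 10)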
def main():
Create experiment
    experiment.create(name='cifar10', comment='large model')
Create configurations
    conf = Configs()
Load configurations
    experiment.configs(conf, {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,
        'is_save_models': True,
        'epochs': 20,
    })
Set model for saving/loading
    experiment.add_pytorch_models({'model': conf.model})
Print number of parameters in the model
    logger.inspect(params=(sum(p.numel() for p in conf.model.parameters() if p.requires_grad)))
Start the experiment and run the training loop
    with experiment.start():
        conf.run()
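Since is_save_models is set, checkpoints are written under this run's UUID. The distillation experiment later reloads this model as the teacher; a sketch of how that reload can look, assuming labml's experiment.load_configs and experiment.load APIs (patterned after the helper used by the distillation experiment):

def get_saved_model(run_uuid: str, checkpoint: int):
    # Rebuild the configs exactly as they were for the saved run
    experiment.evaluate()
    conf = Configs()
    experiment.configs(conf, experiment.load_configs(run_uuid))
    experiment.add_pytorch_models({'model': conf.model})
    # Load the saved weights into the model
    experiment.load(run_uuid, checkpoint)
    experiment.start()
    return conf.model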
if __name__ == '__main__':
    main()