Train a small model on CIFAR-10
This trains a small model on CIFAR-10 from scratch, so that we can measure how much the same model benefits from distillation.
import torch.nn as nn

from labml import experiment, logger
from labml.configs import option
from labml_nn.experiments.cifar10 import CIFAR10Configs, CIFAR10VGGModel
from labml_nn.normalization.batch_norm import BatchNorm
We use CIFAR10Configs, which defines all the dataset-related configurations, the optimizer, and the training loop.
class Configs(CIFAR10Configs):
    pass
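Since everything is inherited, the class body is empty. In fact, a minimal run with all defaults (a sketch using the same labml API as main() below) would be just:

conf = Configs()
experiment.create(name='cifar10')
experiment.configs(conf, {})
with experiment.start():
    conf.run()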
This derives from the generic VGG-style architecture implemented by CIFAR10VGGModel.
class SmallModel(CIFAR10VGGModel):
Create a convolution block (convolution, batch normalization, and activation)
    def conv_block(self, in_channels, out_channels) -> nn.Module:
        return nn.Sequential(
Convolution layer
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
Batch normalization
            BatchNorm(out_channels, track_running_stats=False),
ReLU activation
            nn.ReLU(inplace=True),
        )
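As a quick shape check (a hypothetical snippet; PyTorch's built-in nn.BatchNorm2d stands in for labml's BatchNorm so it runs standalone), padding=1 keeps the 32×32 spatial size:

import torch
import torch.nn as nn

# Stand-in for conv_block(3, 32)
block = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=3, padding=1),
    nn.BatchNorm2d(32, track_running_stats=False),
    nn.ReLU(inplace=True),
)
x = torch.zeros(1, 3, 32, 32)  # one CIFAR-10 sized image
print(block(x).shape)  # torch.Size([1, 32, 32, 32])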
    def __init__(self):
Create a model with the given convolution sizes (channels)
        super().__init__([[32, 32], [64, 64], [128], [128], [128]])
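Each inner list is one group of convolution blocks; following the VGG pattern, the base class pools between groups, so the 32×32 input shrinks spatially as the channel count grows. A rough sketch of how such a spec could be expanded (illustrative only; CIFAR10VGGModel is the actual implementation, and expand is a hypothetical helper):

import torch.nn as nn

def expand(blocks):
    layers, in_channels = [], 3
    for group in blocks:
        for out_channels in group:
            # one conv block per listed channel count
            layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
            layers.append(nn.ReLU(inplace=True))
            in_channels = out_channels
        # 2x2 max-pooling between groups halves the spatial size
        layers.append(nn.MaxPool2d(2))
    return nn.Sequential(*layers)

expand([[32, 32], [64, 64], [128], [128], [128]])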
@option(Configs.model)
def _small_model(c: Configs):
    return SmallModel().to(c.device)
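The @option decorator registers _small_model as the calculator for Configs.model, so it runs lazily the first time conf.model is accessed. Several options can be registered for the same config and selected by name; a hypothetical example:

# Hypothetical second option; it could be selected with
# experiment.configs(conf, {'model': '_another_model'}).
@option(Configs.model)
def _another_model(c: Configs):
    return SmallModel().to(c.device)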
def main():
Create experiment
    experiment.create(name='cifar10', comment='small model')
Create configurations
    conf = Configs()
Load configurations
    experiment.configs(conf, {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,
    })
Set model for saving/loading
    experiment.add_pytorch_models({'model': conf.model})
Print number of parameters in the model
    logger.inspect(params=(sum(p.numel() for p in conf.model.parameters() if p.requires_grad)))
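As a worked example (hand arithmetic, not output from the run), the very first convolution contributes in_channels × out_channels × 3 × 3 weights plus out_channels biases:

first_conv_params = 3 * 32 * 3 * 3 + 32  # = 896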
Start the experiment and run the training loop
    with experiment.start():
        conf.run()
if __name__ == '__main__':
    main()
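Assuming the module lives at labml_nn/distillation/small.py, matching the labml_nn imports above, it should be launchable as a script:

python -m labml_nn.distillation.small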