
Train a small model on CIFAR 10


This trains a small model on CIFAR-10 as a baseline, to measure how much the same model benefits from distillation.
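This baseline is later compared against the same small model trained with a distillation loss from a larger teacher. As a reminder of what that loss looks like, here is a minimal pure-Python sketch of temperature-scaled soft targets (a hypothetical illustration, not this repository's implementation, which combines the soft-target loss with the usual hard-label cross-entropy):

```python
import math


def softmax(logits, temperature=1.0):
    # Temperature-scaled softmax: higher temperatures give softer distributions
    exps = [math.exp(l / temperature) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_loss(student_logits, teacher_logits, temperature=5.0):
    # Cross-entropy between the teacher's soft targets and the student's
    # temperature-scaled predictions (sketch only; temperature=5.0 is an
    # arbitrary illustrative choice)
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    return -sum(p * math.log(q) for p, q in zip(p_teacher, p_student))
```

Raising the temperature flattens both distributions, exposing the teacher's relative confidences across the wrong classes, which is the extra signal the student learns from.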

import torch.nn as nn

from labml import experiment, logger
from labml.configs import option
from labml_nn.experiments.cifar10 import CIFAR10Configs, CIFAR10VGGModel
from labml_nn.normalization.batch_norm import BatchNorm

Configurations

We use CIFAR10Configs, which defines all the dataset-related configurations, the optimizer, and the training loop.

class Configs(CIFAR10Configs):
    pass

VGG style model for CIFAR-10 classification

This derives from the generic VGG-style architecture.

class SmallModel(CIFAR10VGGModel):

    # Create a convolution layer and the activations
    def conv_block(self, in_channels, out_channels) -> nn.Module:

        return nn.Sequential(
            # Convolution layer
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            # Batch normalization
            BatchNorm(out_channels, track_running_stats=False),
            # ReLU activation
            nn.ReLU(inplace=True),
        )

    # Create a model with the given convolution sizes (channels)
    def __init__(self):
        super().__init__([[32, 32], [64, 64], [128], [128], [128]])
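The five block groups suggest five pooling stages. As a rough sanity check (a sketch assuming, as in the generic VGG template, that each block group ends with 2×2 max-pooling and that every convolution uses kernel_size=3 with padding=1, preserving spatial size), the 32×32 CIFAR-10 image is reduced to a 1×1 feature map with 128 channels before the classifier:

```python
def feature_map_sizes(blocks, input_size=32):
    # Convolutions with kernel_size=3, padding=1 preserve spatial size,
    # so only the assumed 2x2 max-pool at the end of each block group
    # changes the resolution (halving it each time).
    sizes = [input_size]
    for _ in blocks:
        sizes.append(sizes[-1] // 2)
    return sizes


blocks = [[32, 32], [64, 64], [128], [128], [128]]
print(feature_map_sizes(blocks))  # [32, 16, 8, 4, 2, 1]
```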

Create model

@option(Configs.model)
def _small_model(c: Configs):
    return SmallModel().to(c.device)
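The @option decorator registers _small_model as the function that computes Configs.model, so the model is only constructed when the configuration is first accessed. A minimal pure-Python sketch of this lazy-option pattern (a hypothetical illustration of the idea, not labml's actual implementation):

```python
class LazyConfigs:
    # Hypothetical sketch: a registry maps config names to functions
    # that compute their values on first access.
    _options = {}

    @classmethod
    def option(cls, name):
        def decorator(fn):
            cls._options[name] = fn
            return fn
        return decorator

    def __getattr__(self, name):
        # Called only when the attribute is not already set
        if name in self._options:
            value = self._options[name](self)
            setattr(self, name, value)  # cache the computed value
            return value
        raise AttributeError(name)


@LazyConfigs.option('model')
def _model(c):
    # Stand-in for building the real model
    return 'small-model-instance'


c = LazyConfigs()
print(c.model)  # computed lazily on first access, then cached
```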

def main():
    # Create experiment
    experiment.create(name='cifar10', comment='small model')

    # Create configurations
    conf = Configs()
    # Load configurations
    experiment.configs(conf, {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,
    })

    # Set model for saving/loading
    experiment.add_pytorch_models({'model': conf.model})
    # Print the number of trainable parameters in the model
    logger.inspect(params=(sum(p.numel() for p in conf.model.parameters() if p.requires_grad)))
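A back-of-the-envelope estimate of that parameter count (a sketch under stated assumptions: 3×3 convolutions with bias, affine batch norm with two parameters per channel, and a final 128→10 linear classifier — details that depend on CIFAR10VGGModel, so the figure printed by logger.inspect may differ slightly):

```python
def count_params(blocks, in_channels=3, n_classes=10):
    # Assumed layout: each conv is 3x3 with bias, followed by affine
    # batch norm; a single linear classifier maps the last channel
    # count to the number of classes.
    total = 0
    for block in blocks:
        for out_channels in block:
            total += in_channels * out_channels * 9 + out_channels  # 3x3 conv + bias
            total += 2 * out_channels                               # batch norm scale/shift
            in_channels = out_channels
    total += in_channels * n_classes + n_classes                    # linear classifier
    return total


print(count_params([[32, 32], [64, 64], [128], [128], [128]]))  # → 437034, about 0.44M
```

Under these assumptions the small model has roughly 0.44M parameters, which is what makes it a useful student for the distillation comparison.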

    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()
