
Train a large model on CIFAR 10

This trains a large model on CIFAR 10, to be used as the teacher for distillation.

import torch.nn as nn

from labml import experiment, logger
from labml.configs import option
from labml_nn.experiments.cifar10 import CIFAR10Configs, CIFAR10VGGModel
from labml_nn.normalization.batch_norm import BatchNorm


Configurations

We use CIFAR10Configs, which defines all the dataset-related configurations, the optimizer, and the training loop.

class Configs(CIFAR10Configs):
    pass
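The class body is empty because everything needed for training is inherited. As a rough, hypothetical sketch of what the base class takes care of (the actual implementation is in labml_nn/experiments/cifar10.py and differs in detail; the function name, batch size, and transform below are assumptions for illustration), the equivalent hand-written setup would look something like this:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


def train_cifar10(model: nn.Module, epochs: int = 20, lr: float = 2.5e-4, device: str = 'cuda'):
    # CIFAR-10 training data and loader
    train_set = datasets.CIFAR10('./data', train=True, download=True, transform=transforms.ToTensor())
    train_loader = DataLoader(train_set, batch_size=64, shuffle=True)

    # Optimizer and loss
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_func = nn.CrossEntropyLoss()

    # Training loop
    for _ in range(epochs):
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            loss = loss_func(model(data), target)
            loss.backward()
            optimizer.step()

CIFAR10Configs also handles validation, logging, and experiment tracking, which this sketch omits.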


VGG-style model for CIFAR-10 classification

This derives from the generic VGG-style architecture, CIFAR10VGGModel; it only specifies the convolution block and the channel sizes. A rough sketch of what the base class does follows the code below.

class LargeModel(CIFAR10VGGModel):


Create a convolution layer and the activations

    def conv_block(self, in_channels, out_channels) -> nn.Module:
        return nn.Sequential(
            # Dropout
            nn.Dropout(0.1),
            # Convolution layer
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            # Batch normalization
            BatchNorm(out_channels, track_running_stats=False),
            # ReLU activation
            nn.ReLU(inplace=True),
        )
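Since the convolution uses a 3x3 kernel with padding 1 and stride 1, each block changes only the channel count and keeps the 32x32 spatial size of CIFAR-10 images. A quick, hypothetical shape check (assuming the full LargeModel defined in this file is importable; it builds the whole model just to grab one block):

import torch

block = LargeModel().conv_block(3, 64)
x = torch.rand(1, 3, 32, 32)   # one fake CIFAR-10 image
print(block(x).shape)          # torch.Size([1, 64, 32, 32])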


    def __init__(self):
        # Create a model with the given convolution sizes (channels)
        super().__init__([[64, 64], [128, 128], [256, 256, 256], [512, 512, 512], [512, 512, 512]])
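The five channel groups above (2, 2, 3, 3, 3 convolutions with 64 to 512 channels) mirror the convolutional configuration of VGG-16. For orientation, here is a rough, hypothetical sketch of how a generic VGG-style base class like CIFAR10VGGModel can assemble a network from such channel lists; the class name and details below are assumptions for illustration, not the labml_nn code (which lives in labml_nn/experiments/cifar10.py):

import torch
import torch.nn as nn


class VGGStyleSketch(nn.Module):
    """Hypothetical illustration of a VGG-style builder, not the actual base class."""

    def __init__(self, blocks):
        super().__init__()
        layers = []
        in_channels = 3
        for block in blocks:
            # A group of convolution blocks that keep the spatial size
            for out_channels in block:
                layers.append(nn.Sequential(
                    nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                    nn.ReLU(inplace=True),
                ))
                in_channels = out_channels
            # Halve the spatial size after each group: 32 -> 16 -> 8 -> 4 -> 2 -> 1
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        self.layers = nn.Sequential(*layers)
        # After five poolings a 32x32 input is 1x1, so the classifier sees `in_channels` features
        self.fc = nn.Linear(in_channels, 10)

    def forward(self, x: torch.Tensor):
        x = self.layers(x)
        x = x.view(x.shape[0], -1)
        return self.fc(x)

In the actual LargeModel, each convolution block additionally includes dropout and batch normalization, as defined in conv_block above.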


Create model

@option(Configs.model)
def _large_model(c: Configs):
    return LargeModel().to(c.device)


def main():


Create experiment

    experiment.create(name='cifar10', comment='large model')


Create configurations

    conf = Configs()


Load configurations

    experiment.configs(conf, {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,
        'is_save_models': True,
        'epochs': 20,
    })


Set model for saving/loading

    experiment.add_pytorch_models({'model': conf.model})


Print number of parameters in the model

    logger.inspect(params=(sum(p.numel() for p in conf.model.parameters() if p.requires_grad)))


Start the experiment and run the training loop

    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()
