
Train a ConvMixer on CIFAR 10


This script trains a ConvMixer on the CIFAR-10 dataset.

This is not an attempt to reproduce the results of the paper. The paper uses the image augmentations from PyTorch Image Models (timm) for training. We have skipped those for simplicity, which causes our validation accuracy to drop.

from labml import experiment
from labml.configs import option
from labml_nn.experiments.cifar10 import CIFAR10Configs

Configurations

We use CIFAR10Configs, which defines all the dataset-related configurations, the optimizer, and a training loop.

class Configs(CIFAR10Configs):

Size of a patch, p

    patch_size: int = 2

Number of channels in patch embeddings, h

    d_model: int = 256

Number of ConvMixer layers or depth, d

    n_layers: int = 8

Kernel size of the depth-wise convolution, k

    kernel_size: int = 7

Number of classes in the task

    n_classes: int = 10
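With these defaults it is easy to sanity-check the shapes: a patch size of 2 on 32×32 CIFAR-10 images gives a 16×16 grid of patch embeddings, and the depth-wise 7×7 convolution keeps that grid size, assuming "same" padding as in the paper. The arithmetic in plain Python:

```python
# Default hyperparameters from the Configs class above
image_size = 32   # CIFAR-10 images are 32x32
patch_size = 2    # p
kernel_size = 7   # k, depth-wise convolution

# Patch embedding is a convolution with stride = patch_size,
# so the spatial grid shrinks by a factor of p
grid = image_size // patch_size  # 16

# With "same" padding (k // 2 on each side), the depth-wise conv
# leaves the grid size unchanged through every ConvMixer layer
padding = kernel_size // 2
out_size = grid + 2 * padding - kernel_size + 1  # 16

# Total spatial locations mixed by each depth-wise conv
n_patches = grid * grid  # 256
```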

Create model

@option(Configs.model)
def _conv_mixer(c: Configs):

    from labml_nn.conv_mixer import ConvMixerLayer, ConvMixer, ClassificationHead, PatchEmbeddings

Create ConvMixer

    return ConvMixer(ConvMixerLayer(c.d_model, c.kernel_size), c.n_layers,
                     PatchEmbeddings(c.d_model, c.patch_size, 3),
                     ClassificationHead(c.d_model, c.n_classes)).to(c.device)
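The labml modules used above are defined elsewhere in the repository; for readers who want the whole architecture on one page, here is a compact self-contained sketch following the paper's definition (patch-embedding conv, then depth blocks of a residual depth-wise conv followed by a point-wise conv, each with GELU and BatchNorm). This is an illustration, not the labml_nn implementation:

```python
import torch
import torch.nn as nn

class Residual(nn.Module):
    """Adds the input back to the output of the wrapped module."""
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x

def conv_mixer(d_model=256, depth=8, kernel_size=7, patch_size=2, n_classes=10):
    return nn.Sequential(
        # Patch embedding: a conv with stride = patch_size
        nn.Conv2d(3, d_model, kernel_size=patch_size, stride=patch_size),
        nn.GELU(),
        nn.BatchNorm2d(d_model),
        *[nn.Sequential(
            # Depth-wise conv (groups = channels) mixes spatial locations
            Residual(nn.Sequential(
                nn.Conv2d(d_model, d_model, kernel_size,
                          groups=d_model, padding=kernel_size // 2),
                nn.GELU(),
                nn.BatchNorm2d(d_model))),
            # Point-wise (1x1) conv mixes channels
            nn.Conv2d(d_model, d_model, kernel_size=1),
            nn.GELU(),
            nn.BatchNorm2d(d_model))
          for _ in range(depth)],
        # Global average pooling and a linear classification head
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(d_model, n_classes))
```

A forward pass with a batch of CIFAR-sized images, `conv_mixer()(torch.randn(4, 3, 32, 32))`, produces a `(4, 10)` tensor of class logits.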

def main():

Create experiment

    experiment.create(name='ConvMixer', comment='cifar10')

Create configurations

    conf = Configs()

Load configurations

    experiment.configs(conf, {

Optimizer

        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,

Training epochs and batch size

        'epochs': 150,
        'train_batch_size': 64,

Simple image augmentations

        'train_dataset': 'cifar10_train_augmented',

Do not augment images for validation

        'valid_dataset': 'cifar10_valid_no_augment',
    })

Set model for saving/loading

    experiment.add_pytorch_models({'model': conf.model})

Start the experiment and run the training loop

    with experiment.start():
        conf.run()

if __name__ == '__main__':
    main()
