This script trains a ConvMixer on the CIFAR-10 dataset.
This is not an attempt to reproduce the results of the paper. The paper trains with the image augmentations available in PyTorch Image Models (timm). For simplicity we do not use them, which lowers our validation accuracy.
from labml import experiment
from labml.configs import option
from labml_nn.experiments.cifar10 import CIFAR10Configs
We use CIFAR10Configs, which defines the dataset-related configurations, the optimizer, and a training loop.
class Configs(CIFAR10Configs):
Size of a patch, p
    patch_size: int = 2
Number of channels in patch embeddings, h
    d_model: int = 256
Number of ConvMixer layers or depth, d
    n_layers: int = 8
Kernel size of the depth-wise convolution, k
    kernel_size: int = 7
Number of classes in the task
    n_classes: int = 10
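As a quick sanity check on the values above, the following sketch works out the feature-map shape these settings imply. It assumes 32×32 RGB CIFAR-10 inputs, a patch embedding implemented as a convolution whose kernel size and stride both equal p, and a depth-wise convolution padded by (k - 1) // 2 so that it preserves spatial size. The helper names are illustrative and are not part of labml_nn.

```python
def conv_out_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output length of a convolution along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def feature_map_shape(image_size: int = 32, patch_size: int = 2,
                      d_model: int = 256, kernel_size: int = 7):
    """Shape (channels, height, width) seen by each ConvMixer layer (hypothetical helper)."""
    # Patch embedding: non-overlapping patches, i.e. kernel == stride == patch_size
    side = conv_out_size(image_size, patch_size, stride=patch_size)
    # Assumed "same" padding on the depth-wise convolution keeps the spatial size
    same = conv_out_size(side, kernel_size, padding=(kernel_size - 1) // 2)
    assert same == side
    return d_model, side, side


print(feature_map_shape())
```

Under these assumptions, every layer operates on a 256-channel map of 16×16 patches; doubling the patch size would halve each spatial side.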
@option(Configs.model)
def _conv_mixer(c: Configs):
    from labml_nn.conv_mixer import ConvMixerLayer, ConvMixer, ClassificationHead, PatchEmbeddings
Create ConvMixer
    return ConvMixer(ConvMixerLayer(c.d_model, c.kernel_size), c.n_layers,
                     PatchEmbeddings(c.d_model, c.patch_size, 3),
                     ClassificationHead(c.d_model, c.n_classes)).to(c.device)
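To get a feel for how large the model built above is, here is a rough parameter-count sketch. It does not inspect labml_nn. It assumes the layout described in the ConvMixer paper: a convolutional patch embedding, then per layer a depth-wise convolution and a point-wise (1×1) convolution, each with a bias and followed by batch normalization, and finally a linear classification head on pooled features. The function name and this exact layout are assumptions, so the real module may differ slightly.

```python
def conv_mixer_param_count(d_model: int = 256, kernel_size: int = 7, n_layers: int = 8,
                           patch_size: int = 2, in_channels: int = 3,
                           n_classes: int = 10) -> int:
    """Rough trainable-parameter count under the assumptions stated above."""
    batch_norm = 2 * d_model  # per-channel scale and shift
    # Patch embedding: p x p convolution from in_channels to d_model, bias, batch norm
    patch_emb = in_channels * patch_size ** 2 * d_model + d_model + batch_norm
    # Depth-wise convolution: one k x k filter per channel, bias, batch norm
    depth_wise = d_model * kernel_size ** 2 + d_model + batch_norm
    # Point-wise (1 x 1) convolution mixing channels, bias, batch norm
    point_wise = d_model * d_model + d_model + batch_norm
    # Classification head: linear layer from d_model features to n_classes logits
    head = d_model * n_classes + n_classes
    return patch_emb + n_layers * (depth_wise + point_wise) + head


print(conv_mixer_param_count())
```

Under these assumptions the default configuration comes to roughly 0.64 million parameters, most of them in the point-wise convolutions.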
def main():
Create experiment
    experiment.create(name='ConvMixer', comment='cifar10')
Create configurations
    conf = Configs()
Load configurations
    experiment.configs(conf, {
Optimizer
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,
Training epochs and batch size
        'epochs': 150,
        'train_batch_size': 64,
Simple image augmentations
        'train_dataset': 'cifar10_train_augmented',
Do not augment images for validation
        'valid_dataset': 'cifar10_valid_no_augment',
    })
Set model for saving/loading
    experiment.add_pytorch_models({'model': conf.model})
Start the experiment and run the training loop
    with experiment.start():
        conf.run()
if __name__ == '__main__':
    main()
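The training settings passed to experiment.configs (150 epochs, batch size 64) fix the length of the run. The sketch below estimates the number of optimizer steps, using the 50,000 images in the standard CIFAR-10 training split. Whether the final partial batch is kept or dropped depends on the data loader; the script does not show it, so it is exposed here as a hypothetical drop_last flag.

```python
import math


def training_steps(n_samples: int = 50_000, batch_size: int = 64, epochs: int = 150,
                   drop_last: bool = False):
    """Return (steps per epoch, total optimizer steps)."""
    if drop_last:
        # Discard the final, smaller batch of each epoch
        per_epoch = n_samples // batch_size
    else:
        # Keep the final, smaller batch
        per_epoch = math.ceil(n_samples / batch_size)
    return per_epoch, per_epoch * epochs


print(training_steps())
```

This kind of estimate is handy when choosing a learning-rate schedule or judging how long a run will take before starting it.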