
Train a Vision Transformer (ViT) on CIFAR 10


from labml import experiment
from labml.configs import option
from labml_nn.experiments.cifar10 import CIFAR10Configs
from labml_nn.transformers import TransformerConfigs

Configurations

We use CIFAR10Configs, which defines all the dataset-related configurations, the optimizer, and the training loop.

class Configs(CIFAR10Configs):

Transformer configurations, used to construct the transformer layers

    transformer: TransformerConfigs

Size of a patch

    patch_size: int = 4
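To see what this patch size means concretely, here is a small sketch (the `patch_grid` helper is hypothetical, not part of `labml_nn`): a 32×32 CIFAR-10 image with 4×4 patches gives an 8×8 grid of 64 patches, each flattening to 4·4·3 = 48 values.

```python
# Hypothetical helper (not from labml_nn): patch geometry for ViT on CIFAR-10.
def patch_grid(image_size: int, patch_size: int) -> int:
    """Patches along one side; the image size must divide evenly."""
    assert image_size % patch_size == 0
    return image_size // patch_size

side = patch_grid(32, 4)   # CIFAR-10 images are 32x32 -> 8 patches per side
n_patches = side * side    # 64 patches form the transformer's input sequence
patch_dim = 4 * 4 * 3      # each RGB patch flattens to 48 values
```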

Size of the hidden layer in the classification head

    n_hidden_classification: int = 2048

Number of classes in the task

    n_classes: int = 10
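For a rough sense of scale, and assuming the head is a standard two-layer MLP (`d_model → n_hidden → n_classes`, weights plus biases; this is back-of-the-envelope arithmetic, not `labml_nn` code), the defaults above put about a million parameters in the classification head alone:

```python
# Assumed two-layer MLP head: Linear(512 -> 2048) then Linear(2048 -> 10).
d_model, n_hidden, n_classes = 512, 2048, 10
layer1 = d_model * n_hidden + n_hidden       # 1,050,624 weights + biases
layer2 = n_hidden * n_classes + n_classes    # 20,490 weights + biases
head_params = layer1 + layer2                # 1,071,114 parameters in total
```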

Create transformer configs

@option(Configs.transformer)
def _transformer():

    return TransformerConfigs()

Create model

@option(Configs.model)
def _vit(c: Configs):

    from labml_nn.transformers.vit import VisionTransformer, LearnedPositionalEmbeddings, ClassificationHead, \
        PatchEmbeddings

Transformer size from transformer configurations

    d_model = c.transformer.d_model

Create a vision transformer

    return VisionTransformer(c.transformer.encoder_layer, c.transformer.n_layers,
                             PatchEmbeddings(d_model, c.patch_size, 3),
                             LearnedPositionalEmbeddings(d_model),
                             ClassificationHead(d_model, c.n_hidden_classification, c.n_classes)).to(c.device)
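To make the data flow concrete, here is a shape trace under the standard ViT formulation (assumed here as a sketch, not extracted from `labml_nn`): the image is split into patches, each projected to a `d_model`-sized embedding, a learnable [CLS] token is prepended, and the classification head reads that token's final representation.

```python
# Assumed shape bookkeeping for ViT on CIFAR-10 (pure arithmetic, no torch).
h = w = 32                     # CIFAR-10 image size
patch_size, d_model = 4, 512   # values from the configurations above
n_patches = (h // patch_size) * (w // patch_size)  # 64 patch embeddings
seq_len = n_patches + 1        # +1 for the prepended [CLS] token -> 65
# Encoder input/output: (seq_len, batch, d_model); the classification head
# maps the d_model-sized [CLS] vector to n_classes logits.
```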

def main():

Create experiment

    experiment.create(name='ViT', comment='cifar10')

Create configurations

    conf = Configs()

Load configurations

    experiment.configs(conf, {

Optimizer

        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,

Transformer embedding size

        'transformer.d_model': 512,

Training epochs and batch size

        'epochs': 32,
        'train_batch_size': 64,

Augment CIFAR 10 images for training

        'train_dataset': 'cifar10_train_augmented',

Do not augment CIFAR 10 images for validation

        'valid_dataset': 'cifar10_valid_no_augment',
    })
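These settings imply the following training length (back-of-the-envelope arithmetic, using the standard CIFAR-10 split of 50,000 training images):

```python
# Rough training-length arithmetic for the configuration above.
train_images, batch_size, epochs = 50_000, 64, 32
steps_per_epoch = -(-train_images // batch_size)  # ceil(50000 / 64) = 782
total_steps = steps_per_epoch * epochs            # 25,024 optimizer steps
```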

Set model for saving/loading

    experiment.add_pytorch_models({'model': conf.model})

Start the experiment and run the training loop

    with experiment.start():
        conf.run()

if __name__ == '__main__':
    main()
