# Extracted from docs/optimizers/mnist_experiment.html (labml-nn: MNIST experiment for configurable optimizers)
from typing import Any

import torch
import torch.nn as nn
import torch.utils.data

from labml import experiment, tracker
from labml.configs import option
from labml_nn.helpers.datasets import MNISTConfigs
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.helpers.metrics import Accuracy
from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
from labml_nn.optimizers.configs import OptimizerConfigs
class Model(nn.Module):
    """
    ## Simple CNN for MNIST classification

    Two convolution + max-pool stages followed by two fully connected
    layers. Expects `(batch, 1, 28, 28)` images and returns
    `(batch, 10)` class logits.
    """

    def __init__(self):
        super().__init__()
        # 28x28 -> conv(kernel 5) -> 24x24 -> pool(2) -> 12x12, 20 channels
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.pool1 = nn.MaxPool2d(2)
        # 12x12 -> conv(kernel 5) -> 8x8 -> pool(2) -> 4x4, 50 channels
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.pool2 = nn.MaxPool2d(2)
        # Flattened feature size per sample: 4 * 4 * 50 = 16 * 50
        self.fc1 = nn.Linear(16 * 50, 500)
        self.fc2 = nn.Linear(500, 10)
        self.activation = nn.ReLU()

    def forward(self, x):
        x = self.activation(self.conv1(x))
        x = self.pool1(x)
        x = self.activation(self.conv2(x))
        x = self.pool2(x)
        # Flatten per sample. Keeping the batch dimension explicit
        # (instead of the original `view(-1, 16 * 50)`) means a shape
        # mismatch fails loudly at `fc1` rather than silently folding
        # samples into each other.
        x = self.activation(self.fc1(x.view(x.shape[0], -1)))
        return self.fc2(x)
class Configs(MNISTConfigs, TrainValidConfigs):
    """
    ## Configurable experiment

    Combines the MNIST data configurations with the train/validation
    loop configurations, and wires in a configurable optimizer and model.
    """

    # Created by the `_optimizer` configuration option below
    optimizer: torch.optim.Adam
    # Created by the `model` configuration option below
    # (fix: this annotation was duplicated twice in the original)
    model: nn.Module
    device: torch.device = DeviceConfigs()
    epochs: int = 10

    is_save_models = True
    inner_iterations = 10

    accuracy_func = Accuracy()
    loss_func = nn.CrossEntropyLoss()

    def init(self):
        """Set up trackers and register stateful metric modules."""
        tracker.set_queue("loss.*", 20, True)
        tracker.set_scalar("accuracy.*", True)
        self.state_modules = [self.accuracy_func]

    def step(self, batch: Any, batch_idx: BatchIndex):
        """
        Run a single training/validation step.

        `batch` is a `(data, target)` pair; `batch_idx` tells us where
        we are within the epoch. (fix: annotation was the builtin
        `any` function, not `typing.Any`.)
        """
        # Move inputs and targets to the configured device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Advance the global step counter only while training
        if self.mode.is_train:
            tracker.add_global_step(len(data))

        # Forward pass
        output = self.model(data)
        # Loss and accuracy
        loss = self.loss_func(output, target)
        self.accuracy_func(output, target)
        tracker.add("loss.", loss)

        # Optimize only in training mode
        if self.mode.is_train:
            loss.backward()
            self.optimizer.step()
            # Log parameter and gradient norms once per epoch,
            # on the last batch
            if batch_idx.is_last:
                tracker.add('model', self.model)
                tracker.add('optimizer', (self.optimizer, {'model': self.model}))
            # Clear the gradients
            self.optimizer.zero_grad()

        # Save logs
        tracker.save()
# Create a configurable optimizer: the optimizer type and its hyper-parameters
# can be selected through the experiment configurations.
@option(Configs.model)
def model(c: Configs):
    """Create the model and move it to the configured device."""
    return Model().to(c.device)


@option(Configs.optimizer)
def _optimizer(c: Configs):
    """
    Create a configurable optimizer.

    Returns an `OptimizerConfigs` instance so the optimizer type and
    hyper-parameters can be chosen via experiment configurations
    (see the overrides in `main`).
    """
    opt_conf = OptimizerConfigs()
    # Attach the model parameters the optimizer should update
    opt_conf.parameters = c.model.parameters()
    return opt_conf
def main():
    """Create and run the MNIST experiment with a configurable optimizer."""
    configs = Configs()
    configs.inner_iterations = 10

    experiment.create(name='mnist_ada_belief')

    # Configuration overrides, including which optimizer to use
    overrides = {
        'inner_iterations': 10,
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 1.5e-4,
    }
    experiment.configs(configs, overrides)
    experiment.add_pytorch_models({'model': configs.model})

    # Start the experiment and run the training loop
    with experiment.start():
        configs.run()


if __name__ == '__main__':
    main()