docs/experiments/mnist.html
from typing import Any

import torch.nn as nn
import torch.utils.data

from labml import tracker
from labml.configs import option
from labml_nn.helpers.datasets import MNISTConfigs as MNISTDatasetConfigs
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.helpers.metrics import Accuracy
from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
from labml_nn.optimizers.configs import OptimizerConfigs
class MNISTConfigs(MNISTDatasetConfigs, TrainValidConfigs):
    """
    ## Trainer configurations

    Configurations for training a classification model on MNIST.
    Dataset settings are inherited from `MNISTDatasetConfigs` and the
    train/validation configuration from `TrainValidConfigs`.
    `init` and `step` are hooks, presumably invoked by the
    `TrainValidConfigs` loop (not visible here; confirm).
    """

    # Optimizer; the default is supplied by the `_optimizer` option
    # registered in this module
    optimizer: torch.optim.Adam
    # Training device
    device: torch.device = DeviceConfigs()

    # Classification model
    model: nn.Module
    # Number of epochs to train for
    epochs: int = 10

    # Number of times to switch between training and validation within an epoch
    inner_iterations = 10

    # Accuracy metric
    accuracy = Accuracy()
    # Loss function; `nn.CrossEntropyLoss` expects unnormalized class
    # scores (logits) from the model
    loss_func = nn.CrossEntropyLoss()

    def init(self):
        """
        ### Initialization

        Configure tracker indicators and register the accuracy metric
        as a state module.
        """
        # Track every `loss.*` and `accuracy.*` indicator as a scalar
        # (the `True` flag presumably enables console printing; confirm
        # against labml's `tracker.set_scalar`)
        tracker.set_scalar("loss.*", True)
        tracker.set_scalar("accuracy.*", True)
        # Register accuracy as a state module. The name is somewhat
        # misleading: state modules exist to keep state separate between
        # training and validation (originally for RNNs). Here it keeps the
        # accuracy statistics separate for training and validation.
        self.state_modules = [self.accuracy]

    def step(self, batch: Any, batch_idx: BatchIndex):
        """
        ### Training or validation step

        Run the model on one batch and log loss and accuracy. In training
        mode, also back-propagate and take an optimizer step.

        * `batch` is indexable; `batch[0]` holds the inputs and `batch[1]`
          the targets. Both are moved to `self.device`.
        * `batch_idx` is the position of this batch. When
          `batch_idx.is_last` is true, the model is logged to the tracker.
        """
        # Training/Evaluation mode
        self.model.train(self.mode.is_train)

        # Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Update global step (number of samples processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))

        # Get model outputs
        output = self.model(data)

        # Calculate and log loss
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

        # Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()

        # Train the model
        if self.mode.is_train:
            # Calculate gradients
            loss.backward()
            # Take optimizer step
            self.optimizer.step()
            # Log the model parameters and gradients on the last batch of
            # every epoch; this happens before `zero_grad` so the gradients
            # are still populated
            if batch_idx.is_last:
                tracker.add('model', self.model)
            # Clear the gradients
            self.optimizer.zero_grad()

        # Save the tracked metrics
        tracker.save()
@option(MNISTConfigs.optimizer)
def _optimizer(c: MNISTConfigs):
    """
    ### Default optimizer configurations

    Build an `OptimizerConfigs` that uses the `'Adam'` optimizer over the
    parameters of `c.model`.
    """
    opt_conf = OptimizerConfigs()
    # `parameters()` returns an iterator; presumably `OptimizerConfigs`
    # consumes it once when building the optimizer (confirm)
    opt_conf.parameters = c.model.parameters()
    opt_conf.optimizer = 'Adam'
    return opt_conf