
MNIST Experiment

```python
import torch.nn as nn
import torch.utils.data

from labml import tracker
from labml.configs import option
from labml_nn.helpers.datasets import MNISTConfigs as MNISTDatasetConfigs
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.helpers.metrics import Accuracy
from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
from labml_nn.optimizers.configs import OptimizerConfigs
```

Trainer configurations

```python
class MNISTConfigs(MNISTDatasetConfigs, TrainValidConfigs):
```

Optimizer

```python
    optimizer: torch.optim.Adam
```

Training device; DeviceConfigs picks a GPU if one is available.

```python
    device: torch.device = DeviceConfigs()
```

Classification model

```python
    model: nn.Module
```
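The model itself is supplied by the experiment that uses these configurations. As a rough sketch only: any `nn.Module` that maps a batch of 1×28×28 MNIST images to 10 class logits can be registered as an option, mirroring the optimizer option at the end of this file. The `Net` class and `_model` function below are hypothetical names, not part of this file; the sketch reuses the imports above.

```python
# Hypothetical model option; not part of this file.
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Flatten(),        # (batch, 1, 28, 28) -> (batch, 784)
            nn.Linear(784, 128),
            nn.ReLU(),
            nn.Linear(128, 10),  # raw logits for CrossEntropyLoss
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)


@option(MNISTConfigs.model)
def _model(c: MNISTConfigs):
    # Move the model to the configured device
    return Net().to(c.device)
```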

Number of epochs to train for

```python
    epochs: int = 10
```

Number of times to switch between training and validation within an epoch; with a value of 10, each epoch is split into ten segments that alternate between training and validation batches.

```python
    inner_iterations = 10
```

Accuracy function

```python
    accuracy = Accuracy()
```

Loss function. `nn.CrossEntropyLoss` expects raw logits, so the model should not apply a final softmax.

```python
    loss_func = nn.CrossEntropyLoss()
```

Initialization

```python
    def init(self):
```

Set tracker configurations

```python
        tracker.set_scalar("loss.*", True)
        tracker.set_scalar("accuracy.*", True)
```

Add accuracy as a state module. The name is probably confusing, since state modules are meant to store state between training and validation for RNNs. Here it keeps the accuracy metric's statistics separate for training and validation.

```python
        self.state_modules = [self.accuracy]
```
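For reference, a minimal sketch of the per-batch computation the `Accuracy` metric performs; the helper itself also accumulates counts across batches and reports them through the tracker. `batch_accuracy` is a hypothetical name, not part of the library.

```python
def batch_accuracy(output: torch.Tensor, target: torch.Tensor) -> float:
    # Pick the class with the highest logit for each sample,
    # then measure the fraction that matches the labels.
    pred = output.argmax(dim=-1)
    return pred.eq(target).float().mean().item()
```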

Training or validation step

```python
    def step(self, batch: any, batch_idx: BatchIndex):
```

Set training/evaluation mode; this toggles layers such as dropout and batch normalization.

```python
        self.model.train(self.mode.is_train)
```

Move data to the device

```python
        data, target = batch[0].to(self.device), batch[1].to(self.device)
```

Update global step (number of samples processed) when in training mode

```python
        if self.mode.is_train:
            tracker.add_global_step(len(data))
```

Get model outputs

```python
        output = self.model(data)
```

Calculate and log loss

```python
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)
```

Calculate and log accuracy

```python
        self.accuracy(output, target)
        self.accuracy.track()
```

Train the model

```python
        if self.mode.is_train:
```

Calculate gradients

```python
            loss.backward()
```

Take optimizer step

```python
            self.optimizer.step()
```

Log the model parameters and gradients on the last batch of every epoch

```python
            if batch_idx.is_last:
                tracker.add('model', self.model)
```

Clear the gradients

```python
            self.optimizer.zero_grad()
```

Save the tracked metrics

```python
        tracker.save()
```

Default optimizer configurations

```python
@option(MNISTConfigs.optimizer)
def _optimizer(c: MNISTConfigs):
    opt_conf = OptimizerConfigs()
    opt_conf.parameters = c.model.parameters()
    opt_conf.optimizer = 'Adam'
    return opt_conf
```
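This file only defines the configurations; a separate script creates and runs the experiment. A minimal sketch, assuming labml's standard experiment API and that a `model` option has been registered; the config overrides shown are illustrative, not required.

```python
from labml import experiment


def main():
    # Create the experiment and the configurations
    experiment.create(name='mnist')
    conf = MNISTConfigs()
    # Override configuration options, e.g. the optimizer type
    # and learning rate handled by OptimizerConfigs
    experiment.configs(conf, {'optimizer.optimizer': 'Adam',
                              'optimizer.learning_rate': 1e-3})
    # Start the experiment and run the training/validation loop
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()
```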
