
Classify MNIST digits with Capsule Networks


This is annotated PyTorch code for classifying MNIST digits with a capsule network.

This implements the experiment described in the paper Dynamic Routing Between Capsules.

```python
from typing import Any

import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from labml import experiment, tracker
from labml.configs import option
from labml_nn.capsule_networks import Squash, Router, MarginLoss
from labml_nn.helpers.datasets import MNISTConfigs
from labml_nn.helpers.metrics import AccuracyDirect
from labml_nn.helpers.trainer import SimpleTrainValidConfigs, BatchIndex
```

Model for classifying MNIST digits

```python
class MNISTCapsuleNetworkModel(nn.Module):
    def __init__(self):
        super().__init__()
```

The first convolution layer has 256 convolution kernels of size 9×9, applied with a stride of 1.

```python
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9, stride=1)
```

The second layer (Primary Capsules) is a convolutional capsule layer with 32 channels of convolutional 8D capsules (8 features per capsule). That is, each primary capsule contains 8 convolutional units with a 9 × 9 kernel and a stride of 2. To implement this, we create a convolutional layer with 32×8 channels, then reshape and permute its output to get capsules of 8 features each.

```python
        self.conv2 = nn.Conv2d(in_channels=256, out_channels=32 * 8, kernel_size=9, stride=2, padding=0)
        self.squash = Squash()
```
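The 6×6 grid of primary capsules follows from the usual convolution output-size formula, floor((n + 2p - k) / s) + 1 for input size n, kernel size k, stride s and padding p. The plain-Python sketch below works it through for the two layers defined above; conv_output_size is a hypothetical helper, not a PyTorch or labml function.

```python
def conv_output_size(n: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a 2D convolution along one axis (no dilation).

    Hypothetical helper for illustration; not part of PyTorch or labml.
    """
    return (n + 2 * padding - kernel) // stride + 1


# First layer: 9x9 kernels, stride 1, on a 28x28 MNIST image
conv1_size = conv_output_size(28, kernel=9, stride=1)
# Second layer (Primary Capsules): 9x9 kernels, stride 2, no padding
conv2_size = conv_output_size(conv1_size, kernel=9, stride=2)
# 32 capsule channels at every position of the conv2_size x conv2_size grid
n_primary_capsules = 32 * conv2_size * conv2_size

print(conv1_size, conv2_size, n_primary_capsules)  # 20 6 1152
```

These match the 20×20 and 6×6 shapes noted in forward, and 1152 = 32 * 6 * 6 is the number of input capsules handed to Router.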

The routing layer takes the 32×6×6 primary capsules (32 capsule channels on the 6×6 output grid of the second convolution) and produces 10 capsules. Each primary capsule has 8 features, while the output capsules (Digit Capsules) have 16 features. The routing algorithm runs for 3 iterations.

```python
        self.digit_capsules = Router(32 * 6 * 6, 10, 8, 16, 3)
```
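Router is imported from labml_nn.capsule_networks and is not shown on this page. To make "iterates 3 times" concrete, here is a minimal pure-Python sketch of routing-by-agreement as described in the paper. It assumes the prediction vectors u_hat[i][j] (input capsule i's prediction for output capsule j) are already given; in the paper they come from learned weight matrices applied to each primary capsule. So this is an illustration of the iteration, not a drop-in replacement for Router, and the function names are hypothetical.

```python
import math
from typing import List

Vector = List[float]


def squash(s: Vector) -> Vector:
    """Keep the direction of s and map its length to |s|^2 / (1 + |s|^2)."""
    sq_norm = sum(x * x for x in s)
    if sq_norm == 0.0:
        return [0.0 for _ in s]
    scale = sq_norm / (1.0 + sq_norm) / math.sqrt(sq_norm)
    return [scale * x for x in s]


def softmax(xs: Vector) -> Vector:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def route(u_hat: List[List[Vector]], iterations: int = 3):
    """Hypothetical sketch of routing-by-agreement (assumes u_hat is precomputed).

    Returns the output capsules v and the coupling coefficients c of the last iteration.
    """
    n_in, n_out, dim = len(u_hat), len(u_hat[0]), len(u_hat[0][0])
    # Routing logits b_ij start at zero, so the first iteration couples uniformly
    b = [[0.0] * n_out for _ in range(n_in)]
    for _ in range(iterations):
        # Coupling coefficients: softmax of b_i over the output capsules j
        c = [softmax(b_i) for b_i in b]
        # s_j = sum_i c_ij * u_hat_ij, then v_j = squash(s_j)
        v = [
            squash([sum(c[i][j] * u_hat[i][j][d] for i in range(n_in)) for d in range(dim)])
            for j in range(n_out)
        ]
        # Agreement: increase b_ij by the dot product of u_hat_ij and v_j
        for i in range(n_in):
            for j in range(n_out):
                b[i][j] += sum(u_hat[i][j][d] * v[j][d] for d in range(dim))
    return v, c


# Two input capsules that agree about output capsule 0 and disagree about output capsule 1
u_hat = [
    [[1.0, 0.0], [0.0, 1.0]],
    [[1.0, 0.0], [0.0, -1.0]],
]
v, c = route(u_hat, iterations=3)
print(c)  # both input capsules now send most of their weight to output capsule 0
```

Each extra iteration shifts coupling toward the output capsules whose squashed output agrees with the predictions, which is why the number of iterations is a parameter of the router.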

This is the decoder described in the paper. It takes the outputs of the 10 digit capsules, each with 16 features, and reconstructs the image. It passes them through linear layers of sizes 512 and 1024 with ReLU activations, followed by a linear layer of size 784 (28 × 28) with a sigmoid activation.

```python
        self.decoder = nn.Sequential(
            nn.Linear(16 * 10, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 784),
            nn.Sigmoid()
        )
```
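As a quick check on the decoder's size, the worked arithmetic below counts its weights and biases from the layer sizes 160, 512, 1024 and 784 (160 = 10 capsules × 16 features). It is plain Python rather than a call into PyTorch, and linear_params is a hypothetical helper; the ReLU and sigmoid layers have no parameters.

```python
def linear_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of a fully connected layer (like nn.Linear with bias=True)."""
    return n_in * n_out + n_out


# Layer sizes of the decoder: 10 digit capsules x 16 features in, a 28 x 28 image out
sizes = [16 * 10, 512, 1024, 28 * 28]
per_layer = [linear_params(a, b) for a, b in zip(sizes, sizes[1:])]
total = sum(per_layer)
print(per_layer, total)  # [82432, 525312, 803600] 1411344
```

For the actual module, sum(p.numel() for p in model.decoder.parameters()) should give the same total.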

data are the MNIST images, with shape [batch_size, 1, 28, 28]

```python
    def forward(self, data: torch.Tensor):
```

Pass through the first convolution layer. Output of this layer has shape [batch_size, 256, 20, 20]

```python
        x = F.relu(self.conv1(data))
```

Pass through the second convolution layer. Output of this layer has shape [batch_size, 32 * 8, 6, 6]. Note that this layer has a stride of 2.

```python
        x = self.conv2(x)
```

Reshape and permute to get the capsules. The result has shape [batch_size, 32 * 6 * 6, 8]: 1152 primary capsules with 8 features each.

```python
        caps = x.view(x.shape[0], 8, 32 * 6 * 6).permute(0, 2, 1)
```
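The view followed by permute determines which output channels of conv2 make up each capsule. The sketch below reproduces that index arithmetic in plain Python for a single image, assuming the contiguous row-major layout that view relies on; capsule_index is a hypothetical helper. It shows that the element at channel c and grid position (y, x) lands in capsule (c mod 32) × 36 + 6y + x as feature c // 32, so each capsule's 8 features come from channels r, r + 32, ..., r + 224 at the same spatial position.

```python
CHANNELS, GRID, FEATURES = 32 * 8, 6, 8
N_CAPSULES = 32 * GRID * GRID  # 1152


def capsule_index(channel: int, y: int, x: int):
    """Map an element of conv2's [256, 6, 6] output (one image) to (capsule, feature)
    after .view(8, 1152) followed by swapping the two dimensions.

    Hypothetical helper that mimics the tensor indexing with integer arithmetic.
    """
    # Flat offset of (channel, y, x) in the contiguous [256, 6, 6] tensor
    flat = (channel * GRID + y) * GRID + x
    # .view(8, 1152) splits the flat offset into (feature, capsule) ...
    feature, capsule = divmod(flat, N_CAPSULES)
    # ... and the permute swaps them, so the element ends up at caps[capsule, feature]
    return capsule, feature


# Which channels feed the 8 features of capsule 0 (channel group 0, grid position (0, 0))?
sources = sorted(
    (capsule_index(c, 0, 0)[1], c) for c in range(CHANNELS) if capsule_index(c, 0, 0)[0] == 0
)
print(sources)  # [(0, 0), (1, 32), (2, 64), (3, 96), (4, 128), (5, 160), (6, 192), (7, 224)]
```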

Squash the capsules

```python
        caps = self.squash(caps)
```
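Squash is imported from labml_nn.capsule_networks. The paper defines it as v = (‖s‖² / (1 + ‖s‖²)) · s / ‖s‖, which keeps a vector's direction and maps its length into [0, 1). A minimal pure-Python sketch of that formula for a single capsule follows; the library version works on whole tensors and may differ in details such as numerical safeguards.

```python
import math
from typing import List


def squash(s: List[float]) -> List[float]:
    """Sketch of the paper's squash: keep the direction, map |s| to |s|^2 / (1 + |s|^2)."""
    sq_norm = sum(x * x for x in s)
    if sq_norm == 0.0:
        return [0.0 for _ in s]  # the zero vector stays zero
    scale = sq_norm / (1.0 + sq_norm) / math.sqrt(sq_norm)
    return [scale * x for x in s]


v = squash([3.0, 4.0])  # |s| = 5, so |v| = 25 / 26
print(v, math.hypot(*v))
```

Long vectors end up with length just under 1 and short vectors shrink toward 0, which is what lets capsule length act as a probability later on.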

Take them through the router to get the digit capsules. This has shape [batch_size, 10, 16].

```python
        caps = self.digit_capsules(caps)
```

Get masks for reconstruction

```python
        with torch.no_grad():
```

The prediction of the capsule network is the capsule with the greatest length. The code compares squared lengths, which gives the same argmax without taking a square root.

```python
            pred = (caps ** 2).sum(-1).argmax(-1)
```

Create a mask to mask out all the other capsules

```python
            mask = torch.eye(10, device=data.device)[pred]
```

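Mask the digit capsules to keep only the capsule that made the prediction, and take the result through the decoder to get the reconstruction. Note that the paper masks with the correct label during training, whereas this code always uses the predicted capsule.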

```python
        reconstructions = self.decoder((caps * mask[:, :, None]).view(x.shape[0], -1))
```
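The prediction and masking steps can be traced on a toy example. The sketch below uses plain Python lists in place of tensors, 3 capsules of 2 features instead of 10 of 16, and made-up values; mask_for_decoder is a hypothetical helper. It picks the capsule with the largest squared length, builds the matching one-hot row (what torch.eye(10)[pred] selects for each example), and flattens the masked capsules capsule by capsule, as .view(batch_size, -1) does before the decoder.

```python
from typing import List, Tuple


def mask_for_decoder(caps: List[List[float]]) -> Tuple[int, List[float]]:
    """Return the predicted capsule index and the flattened, masked capsule outputs."""
    # Squared length of each capsule; its argmax equals the argmax of the length
    sq_lengths = [sum(x * x for x in cap) for cap in caps]
    pred = max(range(len(caps)), key=lambda k: sq_lengths[k])
    # One-hot mask: row `pred` of an identity matrix
    mask = [1.0 if k == pred else 0.0 for k in range(len(caps))]
    # Zero out every other capsule, then flatten capsule-major like .view(batch_size, -1)
    flat = [m * x for m, cap in zip(mask, caps) for x in cap]
    return pred, flat


caps = [[0.1, 0.2], [0.6, 0.3], [0.0, 0.1]]
pred, flat = mask_for_decoder(caps)
print(pred, flat)  # 1 [0.0, 0.0, 0.6, 0.3, 0.0, 0.0]
```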

Reshape the reconstruction to match the image dimensions

```python
        reconstructions = reconstructions.view(-1, 1, 28, 28)

        return caps, reconstructions, pred
```

Configurations with MNIST data and Train & Validation setup

```python
class Configs(MNISTConfigs, SimpleTrainValidConfigs):
    epochs: int = 10
    model: nn.Module = 'capsule_network_model'
    reconstruction_loss = nn.MSELoss()
    margin_loss = MarginLoss(n_labels=10)
    accuracy = AccuracyDirect()
```
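MarginLoss is imported from labml_nn.capsule_networks and is not shown here. The paper's margin loss for digit capsule k is L_k = T_k max(0, m⁺ - ‖v_k‖)² + λ (1 - T_k) max(0, ‖v_k‖ - m⁻)², where T_k = 1 for the correct class, m⁺ = 0.9, m⁻ = 0.1 and λ = 0.5, summed over the digit capsules. Below is a minimal single-example sketch of that formula in plain Python; the defaults are the paper's values, and the library class may differ in details such as how it averages over a batch.

```python
import math
from typing import List


def margin_loss(lengths: List[float], label: int,
                m_pos: float = 0.9, m_neg: float = 0.1, lambda_: float = 0.5) -> float:
    """Sketch of the paper's margin loss for one example, given capsule lengths |v_k|."""
    loss = 0.0
    for k, length in enumerate(lengths):
        if k == label:
            # Correct class: penalize a capsule shorter than m_pos
            loss += max(0.0, m_pos - length) ** 2
        else:
            # Other classes: penalize (down-weighted by lambda_) a capsule longer than m_neg
            loss += lambda_ * max(0.0, length - m_neg) ** 2
    return loss


# Only the third capsule contributes here: 0.5 * (0.5 - 0.1) ** 2
print(margin_loss([0.95, 0.05, 0.5], label=0))
```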

```python
    def init(self):
```

Print losses and accuracy to screen

```python
        tracker.set_scalar('loss.*', True)
        tracker.set_scalar('accuracy.*', True)
```

Register the accuracy metric as a state module so that it is computed over each epoch, separately for training and validation.

```python
        self.state_modules = [self.accuracy]
```

This method gets called by the trainer

```python
    def step(self, batch: Any, batch_idx: BatchIndex):
```

Set the model mode

```python
        self.model.train(self.mode.is_train)
```

Get the images and labels and move them to the model's device

```python
        data, target = batch[0].to(self.device), batch[1].to(self.device)
```

In training mode, increment the global step by the number of samples in the batch

```python
        if self.mode.is_train:
            tracker.add_global_step(len(data))
```

Run the model

```python
        caps, reconstructions, pred = self.model(data)
```

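Calculate the total loss: the margin loss plus the reconstruction loss, scaled down by 0.0005 so that it does not dominate the margin loss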

```python
        loss = self.margin_loss(caps, target) + 0.0005 * self.reconstruction_loss(reconstructions, data)
        tracker.add("loss.", loss)
```
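In the paper, the reconstruction loss is the sum of squared differences between the decoder outputs and the pixel intensities, scaled by 0.0005. nn.MSELoss uses reduction='mean' by default, so it averages over all batch_size × 784 pixels instead. The worked arithmetic below (plain Python, one 28 × 28 image, made-up random pixel values) shows the relationship: 0.0005 × MSE equals 0.0005 / 784 times the summed squared error.

```python
import random

random.seed(0)
n_pixels = 28 * 28
# Made-up image and reconstruction with values in [0, 1), for illustration only
image = [random.random() for _ in range(n_pixels)]
recon = [random.random() for _ in range(n_pixels)]

sum_sq = sum((r - x) ** 2 for r, x in zip(recon, image))  # sum of squared differences
mse = sum_sq / n_pixels  # what nn.MSELoss computes with reduction='mean' for one image

weighted_as_in_code = 0.0005 * mse
weighted_as_in_paper = 0.0005 * sum_sq
print(weighted_as_in_paper / weighted_as_in_code)  # 784, up to floating-point rounding
```

In other words, with the mean reduction the reconstruction term carries a much smaller relative weight than the paper's formulation.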

Update the accuracy metric

```python
        self.accuracy(pred, target)

        if self.mode.is_train:
            loss.backward()

            self.optimizer.step()
```

Log parameters and gradients

```python
            if batch_idx.is_last:
                tracker.add('model', self.model)
            self.optimizer.zero_grad()

        tracker.save()
```

Set the model

```python
@option(Configs.model)
def capsule_network_model(c: Configs):
    return MNISTCapsuleNetworkModel().to(c.device)
```

Run the experiment

```python
def main():
    experiment.create(name='capsule_network_mnist')
    conf = Configs()
    experiment.configs(conf, {'optimizer.optimizer': 'Adam',
                              'optimizer.learning_rate': 1e-3})

    experiment.add_pytorch_models({'model': conf.model})

    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()
```
