This is annotated PyTorch code to classify MNIST digits using capsule networks.
It implements the experiment described in the paper Dynamic Routing Between Capsules.
from typing import Any

import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from labml import experiment, tracker
from labml.configs import option
from labml_nn.capsule_networks import Squash, Router, MarginLoss
from labml_nn.helpers.datasets import MNISTConfigs
from labml_nn.helpers.metrics import AccuracyDirect
from labml_nn.helpers.trainer import SimpleTrainValidConfigs, BatchIndex
class MNISTCapsuleNetworkModel(nn.Module):
    def __init__(self):
        super().__init__()
The first convolution layer has 256 channels of 9×9 convolution kernels, with stride 1.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9, stride=1)
The second layer (Primary Capsules) is a convolutional capsule layer with 32 channels of convolutional 8D capsules (8 features per capsule). That is, each primary capsule contains 8 convolutional units with a 9×9 kernel and a stride of 2. To implement this we create a convolutional layer with 32×8 channels, then reshape and permute its output to get capsules of 8 features each.
        self.conv2 = nn.Conv2d(in_channels=256, out_channels=32 * 8, kernel_size=9, stride=2, padding=0)
        self.squash = Squash()
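As an aside (this check is not part of the original code), the spatial sizes quoted below follow from the convolution arithmetic: (28 − 9)/1 + 1 = 20, then ⌊(20 − 9)/2⌋ + 1 = 6. A quick standalone sanity check:

```python
import torch
import torch.nn as nn

# Reproduce the two convolution layers to verify the shapes quoted in the annotations
conv1 = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9, stride=1)
conv2 = nn.Conv2d(in_channels=256, out_channels=32 * 8, kernel_size=9, stride=2, padding=0)

x = torch.zeros(4, 1, 28, 28)   # a dummy batch of MNIST-sized images
h1 = conv1(x)                   # (28 - 9) / 1 + 1 = 20
h2 = conv2(h1)                  # floor((20 - 9) / 2) + 1 = 6

print(h1.shape)  # torch.Size([4, 256, 20, 20])
print(h2.shape)  # torch.Size([4, 256, 6, 6])
```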
The routing layer gets the 32×6×6 primary capsules and produces 10 capsules. Each of the primary capsules has 8 features, while the output capsules (Digit Capsules) have 16 features. The routing algorithm iterates 3 times.
        self.digit_capsules = Router(32 * 6 * 6, 10, 8, 16, 3)
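Router is imported from labml_nn, so its implementation is not shown here. A minimal sketch of the routing-by-agreement algorithm from the paper (the class name, parameter initialization, and einsum layout below are my own assumptions, not the library's code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def squash(s: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # Shrink each vector below unit length while preserving its direction
    sq_norm = (s ** 2).sum(dim=-1, keepdim=True)
    return (sq_norm / (1 + sq_norm)) * s / torch.sqrt(sq_norm + eps)


class RouterSketch(nn.Module):
    """Routing-by-agreement between two capsule layers (illustrative)."""

    def __init__(self, n_in: int, n_out: int, d_in: int, d_out: int, iterations: int):
        super().__init__()
        self.iterations = iterations
        # A learned transformation matrix per (input capsule, output capsule) pair
        self.weight = nn.Parameter(0.01 * torch.randn(n_in, n_out, d_in, d_out))

    def forward(self, u: torch.Tensor) -> torch.Tensor:
        # u: [batch, n_in, d_in] -> prediction vectors u_hat: [batch, n_in, n_out, d_out]
        u_hat = torch.einsum('bni,noij->bnoj', u, self.weight)
        # Routing logits start at zero
        b = u.new_zeros(u.shape[0], u_hat.shape[1], u_hat.shape[2])
        v = None
        for _ in range(self.iterations):
            c = F.softmax(b, dim=-1)                         # coupling coefficients
            s = torch.einsum('bno,bnoj->boj', c, u_hat)      # weighted sum of predictions
            v = squash(s)                                    # output capsules
            b = b + torch.einsum('bnoj,boj->bno', u_hat, v)  # agreement update
        return v                                             # [batch, n_out, d_out]


# The same call signature as Router(32 * 6 * 6, 10, 8, 16, 3) above
router = RouterSketch(32 * 6 * 6, 10, 8, 16, 3)
v = router(torch.randn(2, 32 * 6 * 6, 8))
print(v.shape)  # torch.Size([2, 10, 16])
```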
This is the decoder mentioned in the paper. It takes the outputs of the 10 digit capsules, each with 16 features, and reproduces the image. It goes through linear layers of sizes 512 and 1024 with ReLU activations.
        self.decoder = nn.Sequential(
            nn.Linear(16 * 10, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 784),
            nn.Sigmoid()
        )
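A standalone shape check of the same decoder stack (illustrative, not part of the original code): the 10 capsules × 16 features flatten to 160 inputs, and the 784 sigmoid outputs reshape to a 28×28 image.

```python
import torch
import torch.nn as nn

# Same layer sizes as the decoder above
decoder = nn.Sequential(
    nn.Linear(16 * 10, 512), nn.ReLU(),
    nn.Linear(512, 1024), nn.ReLU(),
    nn.Linear(1024, 784), nn.Sigmoid(),
)

flat = decoder(torch.zeros(4, 16 * 10))    # masked capsules, flattened per sample
print(flat.shape)                          # torch.Size([4, 784])
print(flat.view(-1, 1, 28, 28).shape)      # torch.Size([4, 1, 28, 28])
```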
data are the MNIST images, with shape [batch_size, 1, 28, 28]
    def forward(self, data: torch.Tensor):
Pass through the first convolution layer. Output of this layer has shape [batch_size, 256, 20, 20]
        x = F.relu(self.conv1(data))
Pass through the second convolution layer. Output of this layer has shape [batch_size, 32 * 8, 6, 6]. Note that this layer has a stride of 2.
        x = self.conv2(x)
Resize and permute to get the capsules
        caps = x.view(x.shape[0], 8, 32 * 6 * 6).permute(0, 2, 1)
Squash the capsules
        caps = self.squash(caps)
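Squash is likewise imported from labml_nn. As described in the paper, it rescales each capsule by ‖s‖²/(1 + ‖s‖²) so that short vectors shrink toward zero and long vectors approach (but never reach) unit length. A minimal sketch (the epsilon for numerical stability is my addition):

```python
import torch

def squash(s: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # squash(s) = (|s|^2 / (1 + |s|^2)) * (s / |s|)
    sq_norm = (s ** 2).sum(dim=-1, keepdim=True)
    return (sq_norm / (1 + sq_norm)) * s / torch.sqrt(sq_norm + eps)

caps = torch.randn(2, 5, 8) * 10   # some long capsule vectors
out = squash(caps)
print(out.norm(dim=-1).max())      # every length is strictly below 1
```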
Take them through the router to get digit capsules. This has shape [batch_size, 10, 16].
        caps = self.digit_capsules(caps)
Get masks for reconstruction
        with torch.no_grad():
The prediction of the capsule network is the capsule with the longest length
            pred = (caps ** 2).sum(-1).argmax(-1)
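Since the squared length is monotonic in the length, taking the argmax over summed squares picks the longest capsule without a square root. A tiny standalone demonstration (illustrative only):

```python
import torch

caps = torch.zeros(1, 10, 16)
caps[0, 4, :4] = 0.5                     # give capsule 4 the longest vector
pred = (caps ** 2).sum(-1).argmax(-1)    # same expression as in the model
print(pred)  # tensor([4])
```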
Create a mask to mask out all the other capsules
            mask = torch.eye(10, device=data.device)[pred]
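Indexing the rows of an identity matrix by the predicted labels is a compact way to build one-hot masks; a small demonstration (illustrative only):

```python
import torch

pred = torch.tensor([3, 0, 7])   # predicted digit for each of 3 samples
mask = torch.eye(10)[pred]       # pick one identity row per prediction
print(mask.shape)                # torch.Size([3, 10])
print(mask[0])                   # 1 at index 3, zeros elsewhere
```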
Mask the digit capsules to get only the capsule that made the prediction, and take it through the decoder to get the reconstruction
        reconstructions = self.decoder((caps * mask[:, :, None]).view(x.shape[0], -1))
Reshape the reconstruction to match the image dimensions
        reconstructions = reconstructions.view(-1, 1, 28, 28)

        return caps, reconstructions, pred
Configurations with MNIST data and Train & Validation setup
class Configs(MNISTConfigs, SimpleTrainValidConfigs):
    epochs: int = 10
    model: nn.Module = 'capsule_network_model'
    reconstruction_loss = nn.MSELoss()
    margin_loss = MarginLoss(n_labels=10)
    accuracy = AccuracyDirect()
    def init(self):
Print losses and accuracy to screen
        tracker.set_scalar('loss.*', True)
        tracker.set_scalar('accuracy.*', True)
We need to set the state modules so the metrics are computed and reset for each epoch, for both training and validation
        self.state_modules = [self.accuracy]
This method gets called by the trainer
    def step(self, batch: Any, batch_idx: BatchIndex):
Set the model mode
        self.model.train(self.mode.is_train)
Get the images and labels and move them to the model's device
        data, target = batch[0].to(self.device), batch[1].to(self.device)
Increment step in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))
Run the model
        caps, reconstructions, pred = self.model(data)
Calculate the total loss
        loss = self.margin_loss(caps, target) + 0.0005 * self.reconstruction_loss(reconstructions, data)
        tracker.add("loss.", loss)
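MarginLoss is imported from labml_nn, so its code is not shown here. Per the paper, the loss for digit capsule k is L_k = T_k max(0, m⁺ − ‖v_k‖)² + λ(1 − T_k) max(0, ‖v_k‖ − m⁻)² with m⁺ = 0.9, m⁻ = 0.1, λ = 0.5. A sketch under those assumptions (the class name and signature below are mine, not the library's):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MarginLossSketch(nn.Module):
    # L_k = T_k * max(0, m+ - |v_k|)^2 + lambda * (1 - T_k) * max(0, |v_k| - m-)^2
    def __init__(self, n_labels: int, m_pos: float = 0.9, m_neg: float = 0.1,
                 lambda_: float = 0.5):
        super().__init__()
        self.n_labels = n_labels
        self.m_pos, self.m_neg, self.lambda_ = m_pos, m_neg, lambda_

    def forward(self, caps: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # caps: [batch, n_labels, d]; the capsule length encodes class probability
        lengths = torch.sqrt((caps ** 2).sum(-1) + 1e-8)
        t = F.one_hot(target, num_classes=self.n_labels).float()
        loss = (t * F.relu(self.m_pos - lengths) ** 2 +
                self.lambda_ * (1 - t) * F.relu(lengths - self.m_neg) ** 2)
        return loss.sum(-1).mean()


loss_fn = MarginLossSketch(n_labels=10)
caps = torch.zeros(2, 10, 16)
target = torch.tensor([3, 7])
caps[0, 3, 0] = 0.95   # correct capsules are long, others short
caps[1, 7, 0] = 0.95
print(loss_fn(caps, target))   # zero loss: both margins are satisfied
```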
Call accuracy metric
        self.accuracy(pred, target)

        if self.mode.is_train:
            loss.backward()

            self.optimizer.step()
Log parameters and gradients
            if batch_idx.is_last:
                tracker.add('model', self.model)
            self.optimizer.zero_grad()

        tracker.save()
Set the model
@option(Configs.model)
def capsule_network_model(c: Configs):
    return MNISTCapsuleNetworkModel().to(c.device)
Run the experiment
def main():
    experiment.create(name='capsule_network_mnist')
    conf = Configs()
    experiment.configs(conf, {'optimizer.optimizer': 'Adam',
                              'optimizer.learning_rate': 1e-3})

    experiment.add_pytorch_models({'model': conf.model})

    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()