labml_nn/capsule_networks/mnist.ipynb

Training a Capsule Network to classify MNIST digits

This is an experiment to train a Capsule Network to classify MNIST digits using PyTorch.

Install the labml-nn package

!pip install labml-nn

Imports

import torch

from labml import experiment
from labml_nn.capsule_networks.mnist import Configs

Create an experiment

experiment.create(name="capsule_networks")

Initialize Capsule Network configurations

conf = Configs()

Set the experiment configurations, passing a dictionary of values that override the defaults

experiment.configs(conf, {'optimizer.optimizer': 'Adam',
                         'optimizer.learning_rate': 1e-3,
                         'inner_iterations': 5})
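The dictionary above uses dotted keys such as 'optimizer.learning_rate' to reach nested options. The following is a minimal sketch of that idea in plain Python, not labml's actual implementation: `apply_overrides` and `demo_conf` are hypothetical names, and the default values in `demo_conf` are invented for illustration.

```python
from types import SimpleNamespace


def apply_overrides(target, overrides):
    """Set each dotted key in `overrides` on the nested object `target`."""
    for dotted_key, value in overrides.items():
        *parents, leaf = dotted_key.split('.')
        obj = target
        for name in parents:
            # Walk down to the object that owns the final attribute.
            obj = getattr(obj, name)
        setattr(obj, leaf, value)
    return target


# Hypothetical stand-in for a configurations object; defaults are made up.
demo_conf = SimpleNamespace(
    optimizer=SimpleNamespace(optimizer='SGD', learning_rate=1e-2),
    inner_iterations=10,
)

apply_overrides(demo_conf, {'optimizer.optimizer': 'Adam',
                            'optimizer.learning_rate': 1e-3,
                            'inner_iterations': 5})

print(demo_conf.optimizer.optimizer,
      demo_conf.optimizer.learning_rate,
      demo_conf.inner_iterations)
```

Keys without a dot, such as 'inner_iterations', are set directly on the top-level object, while dotted keys are resolved one attribute at a time.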

Register the PyTorch model with the experiment so that it can be saved and loaded

experiment.add_pytorch_models({'model': conf.model})
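For readers unfamiliar with what saving and loading a model involves, here is a rough sketch using plain PyTorch only. It does not show how labml stores checkpoints internally; `demo_model` and `restored` are hypothetical stand-ins for `conf.model`, and an in-memory buffer replaces a checkpoint file.

```python
import io

import torch
from torch import nn

# Hypothetical stand-in model; the experiment itself registers conf.model.
demo_model = nn.Linear(4, 2)

# Serialize the parameters to an in-memory buffer instead of a file.
buffer = io.BytesIO()
torch.save(demo_model.state_dict(), buffer)

# Restore the parameters into a freshly initialized model of the same shape.
buffer.seek(0)
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load(buffer))

# Check that every restored tensor matches the original exactly.
same = all(
    torch.equal(a, b)
    for a, b in zip(demo_model.state_dict().values(),
                    restored.state_dict().values())
)
print(same)
```

Saving the `state_dict` rather than the whole module keeps the checkpoint independent of the class definition, at the cost of having to rebuild the model before loading.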

Start the experiment and run the training loop.

with experiment.start():
    conf.run()