examples/jit-train/README.md
To run this example, you will have to unzip the MNIST data
files in data/.
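
As a minimal sketch of the unzip step, the helper below decompresses every `.gz` archive in a directory. It assumes the MNIST files are gzip-compressed (for example `train-images-idx3-ubyte.gz`), which the text above does not state explicitly. `gunzip_all` is a hypothetical helper name and not part of this example's code.

```python
import gzip
import shutil
from pathlib import Path


def gunzip_all(data_dir="data"):
    """Decompress every *.gz file in data_dir, writing the result next to it.

    Returns the names of the files that were written.
    """
    extracted = []
    for archive in sorted(Path(data_dir).glob("*.gz")):
        # "train-images-idx3-ubyte.gz" -> "train-images-idx3-ubyte"
        target = archive.with_suffix("")
        with gzip.open(archive, "rb") as src, open(target, "wb") as dst:
            shutil.copyfileobj(src, dst)
        extracted.append(target.name)
    return extracted
```

Calling `gunzip_all("data")` from the directory that contains `data/` would leave the archives in place and write the decompressed files beside them.
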
This tutorial shows how to train, from Rust, a model that was defined in Python. It uses a demo model that has not been pre-trained. We start by defining and serializing the model using the Python API. The resulting model file is then loaded from Rust and trained on the MNIST dataset.
There are various ways to create the Torch Script, as detailed in the original tutorial.
Here we will use scripting. The following Python script defines the model and compiles it with `torch.jit.script`, which works from the module's source code and does not need to run the model on an example input. The compiled module is then saved as a Torch Script file.
```python
import torch
from torch.nn import Module


class DemoModule(Module):
    def __init__(self):
        super().__init__()
        self.batch_norm = torch.nn.BatchNorm2d(1)
        # 5x5 kernels with padding 2 keep the 28x28 spatial size unchanged.
        self.conv1 = torch.nn.Conv2d(1, 8, kernel_size=(5, 5), padding=(2, 2))
        self.conv2 = torch.nn.Conv2d(8, 16, kernel_size=(5, 5), padding=(2, 2))
        self.flatten = torch.nn.Flatten()
        self.dropout = torch.nn.Dropout()
        # 16 channels of 28x28 pixels come out of the second convolution.
        self.linear1 = torch.nn.Linear(16 * 28 * 28, 100)
        self.linear2 = torch.nn.Linear(100, 10)

    def forward(self, x):
        x = self.batch_norm(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.flatten(x)
        x = self.dropout(x)
        x = self.linear1(x)
        return self.linear2(x)


# Compile the module via scripting, then serialize it to disk.
script_module = torch.jit.script(DemoModule())
script_module.save("model.pt")
```
The last line creates the model.pt Torch Script file which includes both the model
architecture and the weight values.
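
To make the difference between scripting and tracing more concrete, here is a small sketch that is separate from the demo model. `Gate` is a hypothetical toy module introduced only for illustration. It contains a data-dependent branch, which scripting preserves and tracing does not.

```python
import warnings

import torch


class Gate(torch.nn.Module):
    """Toy module whose result depends on a data-dependent branch."""

    def forward(self, x):
        if x.sum() > 0:
            return x * 2
        return -x


# Scripting compiles the Python source, so both branches are kept.
scripted = torch.jit.script(Gate())

# Tracing records the operations executed for one example input, so only
# the branch taken for that input (here: `x * 2`) ends up in the graph.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # hide the TracerWarning about the branch
    traced = torch.jit.trace(Gate(), torch.ones(3))

negative = -torch.ones(3)
scripted_out = scripted(negative)  # follows the `-x` branch
traced_out = traced(negative)      # replays the recorded `x * 2`
```

Here `scripted_out` is a tensor of ones, while `traced_out` is a tensor of `-2` values, because the traced graph replays the multiplication recorded for the positive example input.
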
The model.pt file can then be loaded and trained from Rust.
```rust
// Imports this snippet relies on (from the `anyhow` and `tch` crates).
use anyhow::Result;
use tch::nn::{Adam, ModuleT, OptimizerConfig, VarStore};
use tch::vision::dataset::Dataset;
use tch::{Device, TrainableCModule};

fn train_and_save_model(dataset: &Dataset, device: Device) -> Result<()> {
    let vs = VarStore::new(device);
    // Load the scripted module and register its parameters in the VarStore.
    let mut trainable = TrainableCModule::load("model.pt", vs.root())?;
    trainable.set_train();
    let initial_acc = trainable.batch_accuracy_for_logits(
        &dataset.test_images,
        &dataset.test_labels,
        vs.device(),
        1024,
    );
    println!("Initial accuracy: {:5.2}%", 100. * initial_acc);

    let mut opt = Adam::default().build(&vs, 1e-4)?;
    // `1..20` is an exclusive range, so this runs epochs 1 through 19.
    for epoch in 1..20 {
        // Each epoch uses only 50 shuffled batches of 128 images.
        for (images, labels) in dataset
            .train_iter(128)
            .shuffle()
            .to_device(vs.device())
            .take(50)
        {
            let loss = trainable
                .forward_t(&images, true)
                .cross_entropy_for_logits(&labels);
            opt.backward_step(&loss);
        }
        let test_accuracy = trainable.batch_accuracy_for_logits(
            &dataset.test_images,
            &dataset.test_labels,
            vs.device(),
            1024,
        );
        println!("epoch: {:4} test acc: {:5.2}%", epoch, 100. * test_accuracy);
    }
    trainable.save("trained_model.pt")?;
    Ok(())
}
```
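
The training loop above minimizes a cross-entropy loss computed from the model's raw outputs (logits). As a rough illustration, the plain-Python sketch below shows the usual definition of that loss: the mean, over the batch, of the negative log-softmax value at the true class. `mean_cross_entropy` is a hypothetical stand-in written for this explanation and is not the `tch` implementation.

```python
import math


def mean_cross_entropy(logits, labels):
    """Mean negative log-likelihood of the true classes under softmax(logits).

    logits: list of rows, one list of floats per sample.
    labels: list of integer class indices, one per sample.
    """
    total = 0.0
    for row, label in zip(logits, labels):
        # Log-sum-exp with the row maximum subtracted for numerical stability.
        m = max(row)
        log_norm = m + math.log(sum(math.exp(v - m) for v in row))
        # -log softmax(row)[label] == log_norm - row[label]
        total += log_norm - row[label]
    return total / len(labels)
```

With ten classes, a model that assigns identical logits to every class gets a loss of `math.log(10)`, roughly 2.30, whatever the true label is.
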
Once training has finished, the updated Torch Script is saved to the file trained_model.pt.
This file can later be used to execute the module from either Python or Rust. The Rust version looks as follows.
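```rust
// Requires anyhow::Result, tch::{CModule, Device}, tch::nn::ModuleT and
// tch::vision::dataset::Dataset to be in scope.
fn load_trained_and_test_acc(dataset: &Dataset, device: Device) -> Result<()> {
    let mut module = CModule::load_on_device("trained_model.pt", device)?;
    // Switch to evaluation mode before measuring accuracy.
    module.set_eval();
    let accuracy =
        module.batch_accuracy_for_logits(&dataset.test_images, &dataset.test_labels, device, 1024);
    println!("Updated accuracy: {:5.2}%", 100. * accuracy);
    Ok(())
}
```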
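
Both Rust functions report accuracy computed from logits in batches of 1024. The plain-Python sketch below illustrates one way such a metric can be computed. The predicted class is the index of the largest logit, and per-batch accuracies are combined weighted by batch size. This is an assumption about the general technique, not a transcription of the `tch` code, and `predict` is a hypothetical callable standing in for the model.

```python
def accuracy_for_logits(logits, labels):
    """Fraction of rows whose largest logit sits at the true class index."""
    hits = 0
    for row, label in zip(logits, labels):
        predicted = max(range(len(row)), key=row.__getitem__)
        hits += predicted == label
    return hits / len(labels)


def batch_accuracy(predict, inputs, labels, batch_size):
    """Accuracy over the whole set, evaluated batch_size samples at a time."""
    correct = 0.0
    for start in range(0, len(inputs), batch_size):
        batch_x = inputs[start:start + batch_size]
        batch_y = labels[start:start + batch_size]
        # Weight each batch by its size so a short final batch counts correctly.
        correct += accuracy_for_logits(predict(batch_x), batch_y) * len(batch_y)
    return correct / len(labels)
```

Evaluating in batches keeps memory use bounded by the batch size instead of the size of the whole test set.
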
Below is the typical output that this example would produce.
```
Initial accuracy: 13.11%
epoch:    1 test acc: 87.11%
epoch:    2 test acc: 90.16%
epoch:    3 test acc: 90.31%
epoch:    4 test acc: 90.02%
epoch:    5 test acc: 90.66%
epoch:    6 test acc: 91.00%
epoch:    7 test acc: 91.48%
epoch:    8 test acc: 91.10%
epoch:    9 test acc: 91.13%
epoch:   10 test acc: 91.35%
epoch:   11 test acc: 91.04%
epoch:   12 test acc: 91.43%
epoch:   13 test acc: 91.34%
epoch:   14 test acc: 91.05%
epoch:   15 test acc: 91.23%
epoch:   16 test acc: 91.30%
epoch:   17 test acc: 91.45%
epoch:   18 test acc: 91.70%
epoch:   19 test acc: 91.44%
Updated accuracy: 91.93%
```