This tutorial follows the lines of the PyTorch Transfer Learning Tutorial.
We will use transfer learning to leverage a pretrained ResNet model on a small dataset. This dataset is made of images of ants and bees that we want to classify. There are roughly 240 training images and 150 validation images, each of size 224x224. The dataset is so small that training even the simplest convolutional neural network on it from scratch would be very difficult.
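Instead, the original tutorial proposes two alternatives to train the classifier.
- Finetuning the pretrained model: start from a ResNet-18 pretrained on ImageNet, replace the last layer with a binary classifier and train the resulting model as usual.
- Using the pretrained model as a feature extractor: the pretrained weights are frozen, the outputs of the last layer before the final classifier are stored, and a binary classifier is trained on these features.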
We will focus on the second alternative. First we need to get the code building and running, and we also have to download the dataset and the pretrained weights.
Run the following commands to download the latest tch-rs version and run the tests. This installs the CPU version of libtorch if necessary.
git clone https://github.com/LaurentMazare/tch-rs.git
cd tch-rs
cargo test
The ants and bees dataset can be downloaded here. You can download the weights for a ResNet-18 network pretrained on ImageNet, resnet18.ot.
Once this is done and the dataset has been extracted we can build and run the code with:
cargo run --example transfer-learning resnet18.ot hymenoptera_data
Let us now have a look at the code from main.rs.
The dataset is loaded via some helper functions.
let dataset = imagenet::load_from_dir(dataset_dir)?;
println!("{:?}", dataset);
The println! macro prints the dimensions of the tensors that have
been created. For training the tensor has shape 243x3x224x224, this
corresponds to 243 images of height and width both 224 with 3 channels
(PyTorch uses the NCHW ordering for image data). The testing image
tensor has dimensions 152x3x224x224 so there are 152 images with the
same size as used in training.
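To make the NCHW ordering concrete, here is a minimal std-only sketch of how such a shape maps to a flat, contiguous buffer. The `Nchw` struct and its methods are hypothetical names used for illustration only; they are not part of tch-rs, and the batch size in `main` is made up.

```rust
/// Shape of a batch of images in NCHW order: (images, channels, height, width).
/// Hypothetical helper for illustration, not a tch-rs type.
#[derive(Debug, Clone, Copy)]
struct Nchw {
    n: usize,
    c: usize,
    h: usize,
    w: usize,
}

impl Nchw {
    /// Total number of scalar values stored in the tensor.
    fn numel(&self) -> usize {
        self.n * self.c * self.h * self.w
    }

    /// Flat offset of pixel (row, col) in channel `ch` of image `img`,
    /// assuming a contiguous row-major buffer.
    fn offset(&self, img: usize, ch: usize, row: usize, col: usize) -> usize {
        ((img * self.c + ch) * self.h + row) * self.w + col
    }
}

fn main() {
    // Made-up batch of 4 RGB images of size 224x224.
    let shape = Nchw { n: 4, c: 3, h: 224, w: 224 };
    println!("elements: {}", shape.numel());
    // The first value of image 1 comes right after all 3 channels of image 0.
    println!("offset of image 1: {}", shape.offset(1, 0, 0, 0));
}
```

With this layout all pixels of one channel are stored next to each other, and all channels of one image are stored before the next image starts.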
The pixel data from the dataset is converted to features by running a pre-trained ResNet model.
let mut vs = tch::nn::VarStore::new(tch::Device::Cpu);
let net = resnet::resnet18_no_final_layer(&vs.root());
vs.load(weights)?;
let train_images = tch::no_grad(|| dataset.train_images.apply_t(&net, false));
let test_images = tch::no_grad(|| dataset.test_images.apply_t(&net, false));
This snippet performs the following steps:
- A variable store vs is created to hold the network weights.
- resnet18_no_final_layer builds a ResNet-18 without its final classification layer and registers its variables in vs.
- vs.load(weights) loads the weights stored in the given file and copies their values to the corresponding tensors. Tensors are named in the serialized file in a way that matches the names used when creating the ResNet model.
- apply_t performs a forward pass on the model and returns the resulting tensor; its second argument (false here) disables training mode. In this case the result is a vector of 512 values per sample.
- The no_grad closure informs PyTorch that there is no need to keep a graph of the forward pass as we do not plan on asking for gradients.

Now that we have precomputed the output of the ResNet model on our training and testing images, we will train a linear binary classifier to recognize ants vs bees.
We start by defining a model, for this we need a variable store to hold the trainable variables.
let vs = tch::nn::VarStore::new(tch::Device::Cpu);
let linear = nn::linear(vs.root(), 512, dataset.labels, Default::default());
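As a rough illustration of what such a linear layer computes, the sketch below implements logits = weights * x + bias with plain vectors. The `Linear` struct here is a hypothetical stand-in written with std only, not the tch-rs type, and the weights are set by hand, whereas nn::linear picks its own initial values.

```rust
/// A dense layer computing `weights * x + bias`.
/// Hypothetical std-only stand-in, not the tch-rs implementation.
struct Linear {
    weights: Vec<Vec<f64>>, // one row of input weights per output class
    bias: Vec<f64>,         // one bias per output class
}

impl Linear {
    /// Forward pass for a single feature vector, returning one logit per class.
    fn forward(&self, x: &[f64]) -> Vec<f64> {
        self.weights
            .iter()
            .zip(&self.bias)
            .map(|(row, b)| row.iter().zip(x).map(|(w, v)| w * v).sum::<f64>() + b)
            .collect()
    }
}

fn main() {
    // Tiny made-up layer with 2 input features and 2 classes.
    // In the tutorial the input size is 512 and the output size is dataset.labels.
    let layer = Linear {
        weights: vec![vec![1.0, 2.0], vec![0.0, -1.0]],
        bias: vec![0.5, 0.0],
    };
    let logits = layer.forward(&[3.0, 4.0]);
    println!("{:?}", logits);
}
```

The only trainable values are the entries of the weight matrix and the bias vector, which is why a small variable store is enough for this classifier.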
We will use stochastic gradient descent to minimize the cross-entropy loss
on the classification task. To do this we create an SGD optimizer and then
iterate on the training dataset. After each epoch the accuracy is computed
on the testing set and printed.
// The optimizer is declared mutable because backward_step updates its state.
let mut sgd = nn::Sgd::default().build(&vs, 1e-3)?;
for epoch_idx in 1..1001 {
    let predicted = train_images.apply(&linear);
    let loss = predicted.cross_entropy_for_logits(&dataset.train_labels);
    sgd.backward_step(&loss);
    let test_accuracy = test_images
        .apply(&linear)
        .accuracy_for_logits(&dataset.test_labels);
    println!("{} {:.2}%", epoch_idx, 100. * f64::from(test_accuracy));
}
On each training step the model output is computed through a forward pass. The cross-entropy loss is then evaluated on the resulting logits using the training labels. The backward pass then evaluates gradients for the trainable variables of our model and these variables are updated by the optimizer.
let predicted = train_images.apply(&linear);
let loss = predicted.cross_entropy_for_logits(&dataset.train_labels);
sgd.backward_step(&loss);
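The three lines above hide a fair amount of arithmetic. The following std-only sketch spells out one such step for a single sample: softmax, cross-entropy, the gradient with respect to the logits, and the parameter update. The function names are hypothetical and this is not how tch-rs implements it; in particular the library works on whole batches and obtains gradients by automatic differentiation rather than a hand-written formula.

```rust
/// Softmax over a slice of logits, shifted by the max for numerical stability.
fn softmax(logits: &[f64]) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|l| (l - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

/// Cross-entropy loss of one sample: minus the log-probability of the true class.
fn cross_entropy(logits: &[f64], label: usize) -> f64 {
    -softmax(logits)[label].ln()
}

/// One SGD step on a single sample for logits = w * x + b.
/// Returns the loss measured before the update.
fn sgd_step(w: &mut [Vec<f64>], b: &mut [f64], x: &[f64], label: usize, lr: f64) -> f64 {
    // Forward pass.
    let logits: Vec<f64> = w
        .iter()
        .zip(b.iter())
        .map(|(row, bias)| row.iter().zip(x).map(|(wi, xi)| wi * xi).sum::<f64>() + bias)
        .collect();
    let loss = cross_entropy(&logits, label);
    // Backward pass: d loss / d logit_k = softmax_k - 1 if k is the label, else softmax_k.
    let probs = softmax(&logits);
    for (k, p) in probs.iter().enumerate() {
        let target = if k == label { 1.0 } else { 0.0 };
        let grad_logit = p - target;
        // Parameter update: move against the gradient, scaled by the learning rate.
        for (wi, xi) in w[k].iter_mut().zip(x) {
            *wi -= lr * grad_logit * xi;
        }
        b[k] -= lr * grad_logit;
    }
    loss
}

fn main() {
    // Made-up 2-feature, 2-class problem with zero-initialized parameters.
    let mut w = vec![vec![0.0; 2]; 2];
    let mut b = vec![0.0; 2];
    let x = [1.0, -0.5];
    for step in 0..3 {
        let loss = sgd_step(&mut w, &mut b, &x, 0, 0.5);
        println!("step {} loss {:.4}", step, loss);
    }
}
```

Repeating the step on the same sample should make the printed loss shrink, which is the behaviour the training loop relies on.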
After each epoch the accuracy is evaluated on the testing set and printed out.
let test_accuracy = test_images
    .apply(&linear)
    .accuracy_for_logits(&dataset.test_labels);
println!("{} {:.2}%", epoch_idx, 100. * f64::from(test_accuracy));
This should result in a 94.5% accuracy on the testing set.
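For readers wondering what an accuracy computed from logits amounts to, here is a minimal std-only sketch: the predicted class is the index of the largest logit, and the accuracy is the fraction of samples where that index matches the label. The `argmax` and `accuracy` functions are hypothetical helpers written for this illustration, and the data in `main` is made up.

```rust
/// Index of the largest logit, i.e. the predicted class.
/// Ties are resolved in favour of the lowest index.
fn argmax(logits: &[f64]) -> usize {
    let mut best = 0;
    for (i, v) in logits.iter().enumerate() {
        if *v > logits[best] {
            best = i;
        }
    }
    best
}

/// Fraction of samples whose predicted class equals the label.
/// Returns 0.0 for an empty set of labels to avoid dividing by zero.
fn accuracy(logits: &[Vec<f64>], labels: &[usize]) -> f64 {
    if labels.is_empty() {
        return 0.0;
    }
    let correct = logits
        .iter()
        .zip(labels)
        .filter(|(row, label)| argmax(row) == **label)
        .count();
    correct as f64 / labels.len() as f64
}

fn main() {
    // Four made-up samples with two classes each.
    let logits = vec![
        vec![2.0, 1.0],
        vec![0.1, 0.9],
        vec![3.0, -1.0],
        vec![0.0, 5.0],
    ];
    let labels = [0, 1, 1, 1];
    println!("{:.2}%", 100.0 * accuracy(&logits, &labels));
}
```

Note that no softmax is needed here: applying softmax does not change which logit is the largest.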
The whole code for this example can be found in main.rs.