example/multi-task/multi-task-learning.ipynb
This is a simple example to show how to use mxnet for multi-task learning.
The network is jointly going to learn whether a number is odd or even and to actually recognize the digit.
For example, the digit 4 produces the label pair (4, even) and the digit 5 produces (5, odd).
In this example we don't expect the tasks to contribute to each other much, but for example multi-task learning has been successfully applied to the domain of image captioning. In A Multi-task Learning Approach for Image Captioning by Wei Zhao, Benyou Wang, Jianbo Ye, Min Yang, Zhou Zhao, Ruotian Luo, Yu Qiao, they train a network to jointly classify images and generate text captions
import logging
import random
import time
import matplotlib.pyplot as plt
import mxnet as mx
from mxnet import gluon, np, npx, autograd
import numpy as onp
batch_size = 128  # samples per mini-batch
epochs = 5  # full passes over the training set
ctx = mx.gpu() if mx.device.num_gpus() > 0 else mx.cpu()  # prefer GPU when one is available
lr = 0.01  # learning rate for the Adam trainer below
We get the traditional MNIST dataset and add a new label to the existing one. For each digit we return a new label that stands for Odd or Even.
# Raw MNIST datasets: each sample is (image, digit label).
# The odd/even label is added later by `transform`.
train_dataset = gluon.data.vision.MNIST(train=True)
test_dataset = gluon.data.vision.MNIST(train=False)
def transform(x, y):
    """Prepare one MNIST sample for the two-task network.

    Returns (image, digit_label, parity_label): the image moved to
    channel-first layout and scaled into [0, 1], the original digit
    label, and the digit's parity (0 = even, 1 = odd), both as float32.
    """
    image = x.transpose((2, 0, 1)).astype('float32') / 255.
    digit_label = onp.float32(y)
    parity_label = onp.float32(y % 2)  # odd or even
    return image, digit_label, parity_label
We assign the transform to the original dataset
# Apply the per-sample transform, producing (image, digit, parity) triples
train_dataset_t = train_dataset.transform(transform)
test_dataset_t = test_dataset.transform(transform)
We load the datasets into DataLoaders
# last_batch='rollover' keeps every batch at exactly `batch_size` samples:
# leftover samples from one epoch are carried over to the next epoch's first batch
train_data = gluon.data.DataLoader(train_dataset_t, shuffle=True, last_batch='rollover', batch_size=batch_size, num_workers=5)
test_data = gluon.data.DataLoader(test_dataset_t, shuffle=False, last_batch='rollover', batch_size=batch_size, num_workers=5)
print("Input shape: {}, Target Labels: {}".format(train_dataset[0][0].shape, train_dataset_t[0][1:]))
The output of the featurization is passed to two different output layers
class MultiTaskNetwork(gluon.HybridBlock):
    """Shared trunk feeding two task heads.

    Both tasks consume the same learned features: `output1` classifies
    the digit (10 classes) and `output2` scores odd/even (single logit).
    """

    def __init__(self):
        super(MultiTaskNetwork, self).__init__()
        # Shared feature extractor used by both tasks
        self.shared = gluon.nn.HybridSequential()
        for units in (128, 64, 10):
            self.shared.add(gluon.nn.Dense(units, activation='relu'))
        self.output1 = gluon.nn.Dense(10)  # digit recognition head
        self.output2 = gluon.nn.Dense(1)   # odd/even head

    def forward(self, x):
        features = self.shared(x)
        # One forward pass through the trunk, two task-specific outputs
        return self.output1(features), self.output2(features)
We can use two different losses, one for each output
# One loss per task: softmax cross-entropy for the 10-way digit
# classification, sigmoid binary cross-entropy for the odd/even logit
loss_digits = gluon.loss.SoftmaxCELoss()
loss_odd_even = gluon.loss.SigmoidBCELoss()
We create and initialize the network
# Seed both the MXNet and Python RNGs for reproducible runs
mx.np.random.seed(42)
random.seed(42)
net = MultiTaskNetwork()
# Xavier-initialize all parameters on the chosen device
net.initialize(mx.init.Xavier(), ctx=ctx)
net.hybridize() # hybridize for speed
trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate':lr})
We need to evaluate the accuracy of each task separately
def evaluate_accuracy(net, data_iterator):
    """Measure per-task accuracy of `net` over `data_iterator`.

    Returns a pair of metric `.get()` results, one for the digit task
    and one for the odd/even task.
    """
    digit_metric = mx.gluon.metric.Accuracy(name='digits')
    parity_metric = mx.gluon.metric.Accuracy(name='odd_even')
    for data, label_digit, label_odd_even in data_iterator:
        data = data.to_device(ctx)
        digit_target = label_digit.to_device(ctx)
        # The parity head outputs shape (N, 1); match the label shape
        parity_target = label_odd_even.to_device(ctx).reshape(-1, 1)
        digit_out, parity_out = net(data)
        digit_metric.update(digit_target, npx.softmax(digit_out))
        # Threshold the sigmoid at 0.5 to get a hard 0/1 prediction
        parity_metric.update(parity_target, npx.sigmoid(parity_out) > 0.5)
    return digit_metric.get(), parity_metric.get()
We need to balance the contribution of each loss to the overall training and do so by tuning this alpha parameter within [0,1].
# Weight of the odd/even loss in the combined objective, in [0, 1];
# the digit loss gets weight (1 - alpha)
alpha = 0.5 # Combine losses factor
for e in range(epochs):
    # Accuracies for each task
    acc_digits = mx.gluon.metric.Accuracy(name='digits')
    acc_odd_even = mx.gluon.metric.Accuracy(name='odd_even')
    # Accumulative losses
    l_digits_ = 0.
    l_odd_even_ = 0.
    for i, (data, label_digit, label_odd_even) in enumerate(train_data):
        data = data.to_device(ctx)
        label_digit = label_digit.to_device(ctx)
        # The odd/even head outputs shape (N, 1); reshape the label to match
        label_odd_even = label_odd_even.to_device(ctx).reshape(-1,1)
        with autograd.record():
            # One shared forward pass produces both task outputs
            output_digit, output_odd_even = net(data)
            l_digits = loss_digits(output_digit, label_digit)
            l_odd_even = loss_odd_even(output_odd_even, label_odd_even)
            # Combine the loss of each task
            l_combined = (1-alpha)*l_digits + alpha*l_odd_even
        # Backpropagate the combined loss so the shared trunk learns from both tasks
        l_combined.backward()
        trainer.step(data.shape[0])
        l_digits_ += l_digits.mean()
        l_odd_even_ += l_odd_even.mean()
        acc_digits.update(label_digit, npx.softmax(output_digit))
        acc_odd_even.update(label_odd_even, npx.sigmoid(output_odd_even) > 0.5)
    # (i + 1) is the number of batches seen this epoch
    print("Epoch [{}], Acc Digits {:.4f} Loss Digits {:.4f}".format(
        e, acc_digits.get()[1], l_digits_.item()/(i+1)))
    print("Epoch [{}], Acc Odd/Even {:.4f} Loss Odd/Even {:.4f}".format(
        e, acc_odd_even.get()[1], l_odd_even_.item()/(i+1)))
    print("Epoch [{}], Testing Accuracies {}".format(e, evaluate_accuracy(net, test_data)))
def get_random_data():
    """Pick a random test image, display it, and return it as a batch of one.

    Returns the transformed image with a leading batch axis, on `ctx`.
    """
    # BUG FIX: random.randint is inclusive on BOTH ends, so the previous
    # upper bound of len(test_dataset) could produce an out-of-range index
    # and raise IndexError. randrange(n) yields 0..n-1, which is what we want.
    idx = random.randrange(len(test_dataset))
    img = test_dataset[idx][0]
    data, _, _ = test_dataset_t[idx]
    data = np.expand_dims(data.to_device(ctx), axis=0)
    plt.imshow(img.squeeze().asnumpy(), cmap='gray')
    return data
# Run one random test sample through the trained network
data = get_random_data()
digit, odd_even = net(data)
# Most probable digit class for the single sample in the batch
digit = digit.argmax(axis=1)[0].asnumpy()
# Sigmoid > 0.5 means the parity label 1, i.e. "odd" (see `transform`'s y % 2)
odd_even = (npx.sigmoid(odd_even)[0] > 0.5).asnumpy()
print("Predicted digit: {}, odd: {}".format(digit, odd_even))