Back to Mxnet

Licensed to the Apache Software Foundation (ASF) under one

example/multi-task/multi-task-learning.ipynb

1.9.17.0 KB
Original Source

Licensed to the Apache Software Foundation (ASF) under one

or more contributor license agreements. See the NOTICE file

distributed with this work for additional information

regarding copyright ownership. The ASF licenses this file

to you under the Apache License, Version 2.0 (the

"License"); you may not use this file except in compliance

with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing,

software distributed under the License is distributed on an

"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY

KIND, either express or implied. See the License for the

specific language governing permissions and limitations

under the License.

Multi-Task Learning Example

This is a simple example to show how to use mxnet for multi-task learning.

The network is jointly going to learn whether a number is odd or even and to actually recognize the digit.

For example

  • 1 : 1 and odd
  • 2 : 2 and even
  • 3 : 3 and odd

etc

In this example we don't expect the tasks to contribute to each other much, but for example multi-task learning has been successfully applied to the domain of image captioning. In A Multi-task Learning Approach for Image Captioning by Wei Zhao, Benyou Wang, Jianbo Ye, Min Yang, Zhou Zhao, Ruotian Luo, Yu Qiao, they train a network to jointly classify images and generate text captions

python
import logging
import random
import time

import matplotlib.pyplot as plt
import mxnet as mx
from mxnet import gluon, np, npx, autograd
import numpy as onp

Parameters

python
batch_size = 128  # samples per minibatch
epochs = 5  # full passes over the training set
# Train on GPU when one is available, otherwise fall back to CPU.
ctx = mx.gpu() if mx.device.num_gpus() > 0 else mx.cpu()
lr = 0.01  # learning rate for the Adam optimizer (see Trainer below)

Data

We get the traditional MNIST dataset and add a new label to the existing one. For each digit we return an additional label that indicates whether the digit is odd or even.

python
# Raw MNIST splits; each sample is an (H, W, 1) uint8 image and its digit label.
train_dataset = gluon.data.vision.MNIST(train=True)
test_dataset = gluon.data.vision.MNIST(train=False)
python
def transform(x, y):
    """Convert an MNIST sample into (image, digit label, parity label).

    The image is moved from HWC to CHW layout and rescaled to [0, 1].
    The first label is the digit itself, the second marks it as
    odd (1) or even (0); both are returned as float32 scalars.
    """
    image = x.transpose((2, 0, 1)).astype('float32') / 255.
    digit_label = onp.float32(y)
    parity_label = onp.float32(y % 2)  # odd or even
    return image, digit_label, parity_label

We assign the transform to the original dataset

python
# Lazily apply the (image, digit, parity) transform to every sample.
train_dataset_t = train_dataset.transform(transform)
test_dataset_t = test_dataset.transform(transform)

We load the datasets into DataLoaders

python
# last_batch='rollover' carries leftover samples into the next epoch so every
# emitted batch has exactly batch_size samples.
train_data = gluon.data.DataLoader(train_dataset_t, shuffle=True, last_batch='rollover', batch_size=batch_size, num_workers=5)
test_data = gluon.data.DataLoader(test_dataset_t, shuffle=False, last_batch='rollover', batch_size=batch_size, num_workers=5)
python
# Sanity check: image shape and the two labels of the first sample.
print("Input shape: {}, Target Labels: {}".format(train_dataset[0][0].shape, train_dataset_t[0][1:]))

Multi-task Network

The output of the featurization is passed to two different output layers

python
class MultiTaskNetwork(gluon.HybridBlock):
    """A network with a shared trunk and one output head per task.

    The shared feature extractor feeds two separate layers: a 10-way
    digit classifier and a single logit for the odd/even task.
    """

    def __init__(self):
        super(MultiTaskNetwork, self).__init__()

        # Shared feature extractor used by both heads.
        self.shared = gluon.nn.HybridSequential()
        for units in (128, 64, 10):
            self.shared.add(gluon.nn.Dense(units, activation='relu'))

        self.output1 = gluon.nn.Dense(10)  # digit recognition head
        self.output2 = gluon.nn.Dense(1)   # odd-or-even head

    def forward(self, x):
        features = self.shared(x)
        return self.output1(features), self.output2(features)

We can use two different losses, one for each output

python
# Cross-entropy over the 10 digit classes, and binary cross-entropy
# (on a single logit) for the odd/even task.
loss_digits = gluon.loss.SoftmaxCELoss()
loss_odd_even = gluon.loss.SigmoidBCELoss()

We create and initialize the network

python
# Fix the random seeds for reproducibility.
mx.np.random.seed(42)
random.seed(42)
python
# Build the multi-task model.
net = MultiTaskNetwork()
python
net.initialize(mx.init.Xavier(), ctx=ctx)  # Xavier init on the chosen device
net.hybridize() # hybridize for speed
python
# Adam optimizer over all of the network's parameters.
trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate':lr})

Evaluate Accuracy

We need to evaluate the accuracy of each task separately

python
def evaluate_accuracy(net, data_iterator):
    """Evaluate the accuracy of both tasks over a dataset.

    Parameters
    ----------
    net : callable returning (digit_logits, odd_even_logit) for a batch.
    data_iterator : iterable yielding (data, digit_label, odd_even_label)
        batches, e.g. a DataLoader over the transformed dataset.

    Returns
    -------
    Two (name, accuracy) tuples, one for the digit task and one for
    the odd/even task.

    NOTE(review): relies on the module-level `ctx` device.
    """
    acc_digits = mx.gluon.metric.Accuracy(name='digits')
    acc_odd_even = mx.gluon.metric.Accuracy(name='odd_even')

    # The batch index was previously tracked via enumerate() but never
    # used; iterate directly over the batches instead.
    for data, label_digit, label_odd_even in data_iterator:
        data = data.to_device(ctx)
        label_digit = label_digit.to_device(ctx)
        # Match the single-logit output shape (batch, 1).
        label_odd_even = label_odd_even.to_device(ctx).reshape(-1, 1)

        output_digit, output_odd_even = net(data)

        # Digit task: highest softmax score is the prediction.
        acc_digits.update(label_digit, npx.softmax(output_digit))
        # Parity task: sigmoid thresholded at 0.5.
        acc_odd_even.update(label_odd_even, npx.sigmoid(output_odd_even) > 0.5)
    return acc_digits.get(), acc_odd_even.get()

Training Loop

We need to balance the contribution of each loss to the overall training and do so by tuning this alpha parameter within [0,1].

python
# Weight of the odd/even loss in the combined objective; the digit loss
# gets (1 - alpha). Must lie in [0, 1].
alpha = 0.5 # Combine losses factor
python
for e in range(epochs):
    # Accuracies for each task
    acc_digits = mx.gluon.metric.Accuracy(name='digits')
    acc_odd_even = mx.gluon.metric.Accuracy(name='odd_even')
    # Accumulative losses
    l_digits_ = 0.
    l_odd_even_ = 0. 
    
    for i, (data, label_digit, label_odd_even) in enumerate(train_data):
        # Move the batch to the training device; the parity label is
        # reshaped to (batch, 1) to match the single-logit output.
        data = data.to_device(ctx)
        label_digit = label_digit.to_device(ctx)
        label_odd_even = label_odd_even.to_device(ctx).reshape(-1,1)
        
        with autograd.record():
            output_digit, output_odd_even = net(data)
            l_digits = loss_digits(output_digit, label_digit)
            l_odd_even = loss_odd_even(output_odd_even, label_odd_even)

            # Combine the loss of each task
            l_combined = (1-alpha)*l_digits + alpha*l_odd_even
            
        # Backpropagate the combined loss and update the parameters,
        # normalizing the gradient by the batch size.
        l_combined.backward()
        trainer.step(data.shape[0])
        
        # Accumulate per-batch mean losses and running accuracies.
        l_digits_ += l_digits.mean()
        l_odd_even_ += l_odd_even.mean()
        acc_digits.update(label_digit, npx.softmax(output_digit))
        acc_odd_even.update(label_odd_even, npx.sigmoid(output_odd_even) > 0.5)
        
    # (i + 1) is the number of batches processed this epoch.
    print("Epoch [{}], Acc Digits   {:.4f} Loss Digits   {:.4f}".format(
        e, acc_digits.get()[1], l_digits_.item()/(i+1)))
    print("Epoch [{}], Acc Odd/Even {:.4f} Loss Odd/Even {:.4f}".format(
        e, acc_odd_even.get()[1], l_odd_even_.item()/(i+1)))
    print("Epoch [{}], Testing Accuracies {}".format(e, evaluate_accuracy(net, test_data)))
        

Testing

python
def get_random_data():
    """Pick a random test sample, plot it, and return it as a batch of one.

    Returns
    -------
    The transformed image on `ctx` with a leading batch axis,
    ready to be fed to the network.
    """
    # Fix: random.randint(0, len(test_dataset)) is inclusive at both
    # ends, so it could produce len(test_dataset) and raise an
    # IndexError. randrange excludes the upper bound.
    idx = random.randrange(len(test_dataset))

    # Raw image for display, transformed image for inference.
    img = test_dataset[idx][0]
    data, _, _ = test_dataset_t[idx]
    data = np.expand_dims(data.to_device(ctx), axis=0)

    plt.imshow(img.squeeze().asnumpy(), cmap='gray')

    return data
python
# Run a single random sample through the network and decode both heads.
data = get_random_data()

digit, odd_even = net(data)

# Digit head: the highest-scoring class is the prediction.
digit = digit.argmax(axis=1)[0].asnumpy()
# Parity head: a sigmoid above 0.5 means "odd".
odd_even = (npx.sigmoid(odd_even)[0] > 0.5).asnumpy()

print("Predicted digit: {}, odd: {}".format(digit, odd_even))