Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file to you under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

Using a bi-lstm to sort a sequence of integers

python
import random
import string

import mxnet as mx
from mxnet import gluon, np
import numpy as onp

Data Preparation

python
max_num = 999
dataset_size = 60000
seq_len = 5
split = 0.8
batch_size = 512
ctx = mx.gpu() if mx.device.num_gpus() > 0 else mx.cpu()

We generate a dataset of dataset_size sequences of seq_len integers between 0 and max_num. We use split*100% of them for training and the rest for testing.

For example:

50 10 200 999 30

Should return

10 30 50 200 999

python
X = mx.np.random.uniform(low=0, high=max_num, size=(dataset_size, seq_len)).astype('int32').asnumpy()
Y = X.copy()
Y.sort() # sort the copy along each row to get the target
python
print("Input {}\nTarget {}".format(X[0].tolist(), Y[0].tolist()))

For the purpose of training, we encode the input as characters rather than numbers.

python
vocab = string.digits + " "
print(vocab)
vocab_idx = {c: i for i, c in enumerate(vocab)}
print(vocab_idx)

We write a transform that converts our numbers into a string of maximum length max_len and one-hot encodes the characters. For example:

"30 10" corresponds to the indices [3, 0, 10, 1, 0]

We then one-hot encode those indices to get a matrix representation of our input. We don't need to encode our target, as the loss we are going to use supports sparse labels.
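
We can check this character-to-index mapping directly (a quick illustration, not part of the original notebook):

python
example = "30 10"
print([vocab_idx[c] for c in example]) # [3, 0, 10, 1, 0]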

python
max_len = len(str(max_num))*seq_len+(seq_len-1) # seq_len numbers of up to len(str(max_num)) digits, plus seq_len-1 separating spaces
print("Maximum length of the string: %s" % max_len)
python
def transform(x, y):
    # Convert the sequence of numbers to a space-separated string, padded to max_len
    x_string = ' '.join(map(str, x.tolist()))
    x_string_padded = x_string + ' '*(max_len-len(x_string))
    x = [vocab_idx[c] for c in x_string_padded]
    y_string = ' '.join(map(str, y.tolist()))
    y_string_padded = y_string + ' '*(max_len-len(y_string))
    y = [vocab_idx[c] for c in y_string_padded]
    # One-hot encode the input characters; the target stays as sparse indices
    return mx.npx.one_hot(mx.np.array(x), len(vocab)), mx.np.array(y)
python
split_idx = int(split*len(X))
train_dataset = gluon.data.ArrayDataset(X[:split_idx], Y[:split_idx]).transform(transform)
test_dataset = gluon.data.ArrayDataset(X[split_idx:], Y[split_idx:]).transform(transform)
python
print("Input {}".format(X[0]))
print("Transformed data Input {}".format(train_dataset[0][0]))
print("Target {}".format(Y[0]))
print("Transformed data Target {}".format(train_dataset[0][1]))
python
train_data = gluon.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=20, last_batch='rollover')
test_data = gluon.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=5, last_batch='rollover')

Creating the network

python
net = gluon.nn.HybridSequential()
net.add(
    gluon.rnn.LSTM(hidden_size=128, num_layers=2, layout='NTC', bidirectional=True),
    gluon.nn.Dense(len(vocab), flatten=False)
)
python
net.initialize(mx.init.Xavier(), ctx=ctx)
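
We can sanity-check the shapes with a dummy batch (an illustrative check, not part of the original notebook): with layout='NTC' and flatten=False, the network produces one vector of vocabulary logits per character position.

python
# Illustrative shape check: expect (batch_size, max_len, len(vocab))
dummy = mx.np.zeros((batch_size, max_len, len(vocab))).to_device(ctx)
print(net(dummy).shape)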
python
loss = gluon.loss.SoftmaxCELoss()
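
With the default sparse_label=True, SoftmaxCELoss takes logits of shape (N, T, len(vocab)) and integer class indices of shape (N, T), which is why the target did not need to be one-hot encoded. A quick shape illustration (an added sketch, not part of the original notebook):

python
# One loss value per sequence: expect shape (2,)
dummy_logits = mx.np.zeros((2, max_len, len(vocab)))
dummy_labels = mx.np.zeros((2, max_len))
print(loss(dummy_logits, dummy_labels).shape)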

We use a learning rate schedule to improve the convergence of the model.

python
schedule = mx.lr_scheduler.FactorScheduler(step=len(train_data)*10, factor=0.75)
schedule.base_lr = 0.01
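
The FactorScheduler multiplies the learning rate by factor every step updates, i.e. roughly every 10 epochs here. A quick illustration using a separate scheduler instance so the one attached to the trainer is left untouched (an added sketch, not part of the original notebook):

python
# Illustration only: learning rate after roughly 0, 10 and 30 epochs of updates
demo_schedule = mx.lr_scheduler.FactorScheduler(step=len(train_data)*10, factor=0.75)
demo_schedule.base_lr = 0.01
for num_update in [1, len(train_data)*10 + 1, len(train_data)*30 + 1]:
    print("update {}: lr {:.6f}".format(num_update, demo_schedule(num_update)))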
python
trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate':0.01, 'lr_scheduler':schedule})

Training loop

python
epochs = 100
for e in range(epochs):
    epoch_loss = 0.
    for i, (data, label) in enumerate(train_data):
        data = data.to_device(ctx)
        label = label.to_device(ctx)

        with mx.autograd.record():
            output = net(data)
            l = loss(output, label)

        l.backward()
        trainer.step(data.shape[0])
    
        epoch_loss += l.mean()
        
    print("Epoch [{}] Loss: {}, LR {}".format(e, epoch_loss.item()/(i+1), trainer.learning_rate))

Testing
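
Before looking at individual predictions, we can estimate per-character accuracy over the whole test set (an added sketch, not part of the original notebook):

python
# Sketch: fraction of predicted characters that match the target string
correct, total = 0., 0.
for data, label in test_data:
    data = data.to_device(ctx)
    label = label.to_device(ctx)
    pred = net(data).argmax(axis=-1)
    correct += (pred.astype('int32') == label.astype('int32')).sum().item()
    total += label.size
print("Per-character accuracy: {:.4f}".format(correct / total))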

We pick a random element from the test set.

python
n = random.randint(0, len(test_dataset)-1)

x_orig = X[split_idx+n]
y_orig = Y[split_idx+n]
python
def get_pred(x):
    x, _ = transform(x, x)  # transform expects (x, y); we only need the encoded input
    output = net(mx.np.expand_dims(x.to_device(ctx), axis=0))

    # Convert output back to string
    pred = ''.join([vocab[int(o)] for o in output[0].argmax(axis=1).asnumpy().tolist()])
    return pred

Printing the result

python
x_ = ' '.join(map(str,x_orig))
label = ' '.join(map(str,y_orig))
print("X         {}\nPredicted {}\nLabel     {}".format(x_, get_pred(x_orig), label))

We can also pick our own example, and the network manages to sort it without a problem:

python
print(get_pred(onp.array([500, 30, 999, 10, 130])))

The model has even learned to generalize to examples not seen in the training set:

python
print("Only four numbers:", get_pred(onp.array([105, 302, 501, 202])))

However, we can see that it has trouble with other edge cases:

python
print("Small digits:", get_pred(onp.array([10, 3, 5, 2, 8])))
print("Small digits, 6 numbers:", get_pred(onp.array([10, 33, 52, 21, 82, 10])))

This could be improved by adjusting the training dataset accordingly.
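
For example, a minimal sketch of such an adjustment (not part of the original notebook) is to mix extra training sequences drawn from a smaller range into the training set, so that small numbers are better represented; the transform, network, and training loop stay unchanged:

python
# Sketch: augment the training set with sequences of small numbers (0-99)
X_small = mx.np.random.uniform(low=0, high=99, size=(dataset_size // 10, seq_len)).astype('int32').asnumpy()
Y_small = X_small.copy()
Y_small.sort()
X_aug = onp.concatenate([X[:split_idx], X_small])
Y_aug = onp.concatenate([Y[:split_idx], Y_small])
train_dataset_aug = gluon.data.ArrayDataset(X_aug, Y_aug).transform(transform)

Handling sequences of six numbers would additionally require training on longer sequences, since max_len is derived from seq_len.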