example/bi-lstm-sort/bi-lstm-sort.ipynb
import random
import string
import mxnet as mx
from mxnet import gluon, np
import numpy as onp
max_num = 999
dataset_size = 60000
seq_len = 5
split = 0.8
batch_size = 512
ctx = mx.gpu() if mx.device.num_gpus() > 0 else mx.cpu()
We generate a dataset of dataset_size sequences of seq_len integers between 0 and max_num. We use split*100% of them for training and the rest for testing.
For example:
50 10 200 999 30
Should return
10 30 50 200 999
X = mx.np.random.uniform(low=0, high=max_num, size=(dataset_size, seq_len)).astype('int32').asnumpy()
Y = X.copy()
Y.sort()  # sort each sequence (row) of the copy to get the target
print("Input {}\nTarget {}".format(X[0].tolist(), Y[0].tolist()))
For the purpose of training, we encode the input as characters rather than numbers
vocab = string.digits + " "
print(vocab)
vocab_idx = { c:i for i,c in enumerate(vocab)}
print(vocab_idx)
We write a transform that converts our numbers into text of maximum length max_len (each number takes at most len(str(max_num)) characters, plus seq_len-1 separating spaces) and one-hot encodes the characters. For example:
the indices corresponding to "30 10" are [3, 0, 10, 1, 0]
We then one-hot encode that and get a matrix representation of our input. We don't need to encode our target, as the loss we are going to use supports sparse labels.
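As a quick sanity check of the encoding (a minimal sketch reusing the vocab and vocab_idx defined above):
indices = [vocab_idx[c] for c in "30 10"]
print(indices)  # [3, 0, 10, 1, 0]
print(mx.npx.one_hot(mx.np.array(indices), len(vocab)).shape)  # (5, 11): one row per character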
max_len = len(str(max_num))*seq_len+(seq_len-1)
print("Maximum length of the string: %s" % max_len)
def transform(x, y):
    # Render the sequence as a space-separated string, padded with spaces to max_len
    x_string = ' '.join(map(str, x.tolist()))
    x_string_padded = x_string + ' '*(max_len-len(x_string))
    x = [vocab_idx[c] for c in x_string_padded]
    y_string = ' '.join(map(str, y.tolist()))
    y_string_padded = y_string + ' '*(max_len-len(y_string))
    y = [vocab_idx[c] for c in y_string_padded]
    # One-hot encode the input only; the target stays as sparse indices
    return mx.npx.one_hot(mx.np.array(x), len(vocab)), mx.np.array(y)
split_idx = int(split*len(X))
train_dataset = gluon.data.ArrayDataset(X[:split_idx], Y[:split_idx]).transform(transform)
test_dataset = gluon.data.ArrayDataset(X[split_idx:], Y[split_idx:]).transform(transform)
print("Input {}".format(X[0]))
print("Transformed data Input {}".format(train_dataset[0][0]))
print("Target {}".format(Y[0]))
print("Transformed data Target {}".format(train_dataset[0][1]))
train_data = gluon.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=20, last_batch='rollover')
test_data = gluon.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=5, last_batch='rollover')
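Before training, it is worth checking the shapes the loader produces (a quick sanity check; the expected shapes assume the settings above):
for data, label in train_data:
    print(data.shape, label.shape)  # expect (512, 19, 11) one-hot inputs and (512, 19) sparse labels
    break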
net = gluon.nn.HybridSequential()
net.add(
    gluon.rnn.LSTM(hidden_size=128, num_layers=2, layout='NTC', bidirectional=True),
    gluon.nn.Dense(len(vocab), flatten=False)
)
net.initialize(mx.init.Xavier(), ctx=ctx)
loss = gluon.loss.SoftmaxCELoss()
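With flatten=False the Dense layer is applied at every time step, so the network emits one score per vocabulary character at each of the max_len positions, and SoftmaxCELoss scores each position against its sparse integer label. A quick shape check (a sketch run on a dummy batch):
dummy = mx.np.zeros((1, max_len, len(vocab))).to_device(ctx)
print(net(dummy).shape)  # (1, 19, 11): one distribution over the vocab per character position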
We use a learning rate schedule to improve the convergence of the model
schedule = mx.lr_scheduler.FactorScheduler(step=len(train_data)*10, factor=0.75)
schedule.base_lr = 0.01
trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate':0.01, 'lr_scheduler':schedule})
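Since step=len(train_data)*10, the scheduler multiplies the learning rate by 0.75 once every 10 epochs. A quick sketch of the resulting decay (plain arithmetic, not the scheduler object itself):
for epoch in [0, 10, 20, 50, 90]:
    print("epoch {:>2}: lr = {:.5f}".format(epoch, 0.01 * 0.75 ** (epoch // 10)))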
epochs = 100
for e in range(epochs):
    epoch_loss = 0.
    for i, (data, label) in enumerate(train_data):
        data = data.to_device(ctx)
        label = label.to_device(ctx)
        with mx.autograd.record():
            output = net(data)
            l = loss(output, label)
        l.backward()
        trainer.step(data.shape[0])
        epoch_loss += l.mean()
    print("Epoch [{}] Loss: {}, LR {}".format(e, epoch_loss.item()/(i+1), trainer.learning_rate))
We get a random element from the testing set
n = random.randint(0, len(test_dataset)-1)
x_orig = X[split_idx+n]
y_orig = Y[split_idx+n]
def get_pred(x):
    x, _ = transform(x, x)
    output = net(mx.np.expand_dims(x.to_device(ctx), axis=0))
    # Convert output back to string
    pred = ''.join([vocab[int(o)] for o in output[0].argmax(axis=1).asnumpy().tolist()])
    return pred
Printing the result
x_ = ' '.join(map(str,x_orig))
label = ' '.join(map(str,y_orig))
print("X {}\nPredicted {}\nLabel {}".format(x_, get_pred(x_orig), label))
We can also pick our own example, and the network sorts it without any problem:
print(get_pred(onp.array([500, 30, 999, 10, 130])))
The model has even learned to generalize to examples not in the training set:
print("Only four numbers:", get_pred(onp.array([105, 302, 501, 202])))
However, we can see it has trouble with other edge cases:
print("Small digits:", get_pred(onp.array([10, 3, 5, 2, 8])))
print("Small digits, 6 numbers:", get_pred(onp.array([10, 33, 52, 21, 82, 10])))
This could be improved by adjusting the training dataset accordingly, for example by including more sequences of small numbers or of varying lengths.
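A minimal sketch of one possible adjustment, assuming we simply oversample small numbers so that single-digit inputs are better represented (variable-length sequences would need similar treatment):
# Hypothetical augmentation: add extra sequences drawn from [0, 10) to the training pool
X_small = mx.np.random.uniform(low=0, high=10, size=(dataset_size // 4, seq_len)).astype('int32').asnumpy()
Y_small = X_small.copy()
Y_small.sort()
X_aug = onp.concatenate([X, X_small])
Y_aug = onp.concatenate([Y, Y_small])
# Rebuild the datasets from (X_aug, Y_aug) and retrain as above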