doc/source/tune/examples/tune_mnist_keras.ipynb
(tune-mnist-keras)=
:align: center
:alt: Keras & TensorFlow Logo
:height: 120px
:target: https://keras.io
:backlinks: none
:local: true
pip install "ray[tune]" tensorflow==2.18.0 filelock
import os
from filelock import FileLock
from tensorflow.keras.datasets import mnist
from ray import tune
from ray.tune.schedulers import AsyncHyperBandScheduler
from ray.tune.integration.keras import TuneReportCheckpointCallback
def train_mnist(config):
    """Train a small MNIST classifier with Keras and report results to Tune.

    Args:
        config: Hyperparameters sampled by Tune. Reads ``"hidden"`` (units in
            the hidden Dense layer), plus ``"learning_rate"`` and
            ``"momentum"`` (both passed to the SGD optimizer).

    Metrics are reported to Tune through ``TuneReportCheckpointCallback``
    under the key ``"accuracy"``. That key is the one the Tuner optimizes and
    uses as its stop criterion.
    """
    # TensorFlow is imported inside the trainable rather than at module level;
    # see https://github.com/tensorflow/tensorflow/issues/32159
    import tensorflow as tf

    batch_size = 128
    num_classes = 10
    epochs = 12

    # Several trials can run concurrently on one machine. The file lock
    # serializes dataset loading so they don't fetch/write the cache at once.
    with FileLock(os.path.expanduser("~/.data.lock")):
        (x_train, y_train), (x_test, y_test) = mnist.load_data()
    # Scale pixel values to the [0, 1] range.
    x_train, x_test = x_train / 255.0, x_test / 255.0

    model = tf.keras.models.Sequential(
        [
            # Keras 3 (bundled with TF >= 2.16) warns when ``input_shape`` is
            # passed to a layer; declare the input with ``Input`` instead.
            tf.keras.Input(shape=(28, 28)),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(config["hidden"], activation="relu"),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(num_classes, activation="softmax"),
        ]
    )
    model.compile(
        loss="sparse_categorical_crossentropy",
        optimizer=tf.keras.optimizers.SGD(
            learning_rate=config["learning_rate"], momentum=config["momentum"]
        ),
        metrics=["accuracy"],
    )
    model.fit(
        x_train,
        y_train,
        batch_size=batch_size,
        epochs=epochs,
        verbose=0,
        validation_data=(x_test, y_test),
        # Report held-out accuracy ("val_accuracy", computed on
        # ``validation_data``) under the "accuracy" key. Trials are then
        # ranked and stopped on generalization, not training-set accuracy.
        callbacks=[
            TuneReportCheckpointCallback(metrics={"accuracy": "val_accuracy"})
        ],
    )
def tune_mnist():
    """Run an ASHA-scheduled hyperparameter search over ``train_mnist``.

    The search does the following:

    - Samples 10 configurations of learning rate, momentum and hidden-layer
      width.
    - Runs each trial with 2 CPUs and no GPU.
    - Maximizes the ``"accuracy"`` metric reported by ``train_mnist``.
    - Stops any trial as soon as it reports an accuracy of 0.99.

    Returns:
        The result grid produced by ``Tuner.fit()``.
    """
    # train_mnist trains for 12 epochs. Assuming the Keras callback's default
    # per-epoch reporting, a trial never exceeds 12 training iterations.
    # grace_period must sit well below that. Otherwise ASHA never reaches its
    # first rung and never stops a trial early (grace_period=20 with 12
    # iterations made the scheduler a no-op). Keep max_t in sync with
    # ``epochs`` in train_mnist.
    sched = AsyncHyperBandScheduler(
        time_attr="training_iteration", max_t=12, grace_period=2
    )
    tuner = tune.Tuner(
        tune.with_resources(train_mnist, resources={"cpu": 2, "gpu": 0}),
        tune_config=tune.TuneConfig(
            metric="accuracy",
            mode="max",
            scheduler=sched,
            num_samples=10,
        ),
        run_config=tune.RunConfig(
            name="exp",
            stop={"accuracy": 0.99},
        ),
        param_space={
            # NOTE(review): "threads" is not read by train_mnist; it only
            # appears in each trial's reported config.
            "threads": 2,
            "learning_rate": tune.uniform(0.001, 0.1),
            "momentum": tune.uniform(0.1, 0.9),
            "hidden": tune.randint(32, 512),
        },
    )
    results = tuner.fit()
    return results
# Run the search, then report the single best trial's config and accuracy.
results = tune_mnist()
best_result = results.get_best_result()
print(
    f"Best hyperparameters found were: {best_result.config} | "
    f"Accuracy: {best_result.metrics['accuracy']}"
)
This should output something like:
Best hyperparameters found were: {'threads': 2, 'learning_rate': 0.07607440973606909, 'momentum': 0.7715363277240616, 'hidden': 452} | Accuracy: 0.98458331823349
- {doc}`/tune/examples/includes/pbt_memnn_example`: Example of training a Memory NN on bAbI with Keras using PBT.
- {doc}`/tune/examples/includes/tf_mnist_example`: Converts the Advanced TF2.0 MNIST example to use Tune
  with the Trainable. This uses `tf.function`.
  Original code from TensorFlow: https://www.tensorflow.org/tutorials/quickstart/advanced
- {doc}`/tune/examples/includes/pbt_tune_cifar10_with_keras`:
  A contributed example of tuning a Keras model on CIFAR10 with the PopulationBasedTraining scheduler.