site/en/tutorials/images/transfer_learning_with_hub.ipynb
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
TensorFlow Hub is a repository of pre-trained TensorFlow models.
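This tutorial demonstrates how to use models from TensorFlow Hub with tf.keras: first by running a pre-trained image classifier as is, then by doing simple transfer learning to adapt a model to your own image classes.
import numpy as np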
import time
import PIL.Image as Image
import matplotlib.pylab as plt
import tensorflow as tf
import tensorflow_hub as hub
import datetime
%load_ext tensorboard
You'll start by using a classifier model pre-trained on the ImageNet benchmark dataset—no initial training required!
Select a <a href="https://arxiv.org/abs/1801.04381" class="external">MobileNetV2</a> pre-trained model from TensorFlow Hub and wrap it as a Keras layer with hub.KerasLayer. Any <a href="https://tfhub.dev/s?q=tf2&module-type=image-classification/" class="external">compatible image classifier model</a> from TensorFlow Hub will work here, including the examples provided in the drop-down below.
mobilenet_v2 = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/4"
inception_v3 = "https://tfhub.dev/google/imagenet/inception_v3/classification/5"
classifier_model = mobilenet_v2 #@param ["mobilenet_v2", "inception_v3"] {type:"raw"}
IMAGE_SHAPE = (224, 224)
classifier = tf.keras.Sequential([
    hub.KerasLayer(classifier_model, input_shape=IMAGE_SHAPE+(3,))
])
Download a single image to try the model on:
grace_hopper = tf.keras.utils.get_file('image.jpg','https://storage.googleapis.com/download.tensorflow.org/example_images/grace_hopper.jpg')
grace_hopper = Image.open(grace_hopper).resize(IMAGE_SHAPE)
grace_hopper
grace_hopper = np.array(grace_hopper)/255.0
grace_hopper.shape
Add a batch dimension (with np.newaxis) and pass the image to the model:
result = classifier.predict(grace_hopper[np.newaxis, ...])
result.shape
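Keras models expect a batch of inputs, so a single image needs a leading batch axis. The following minimal sketch uses a random array as a stand-in for the photo (an assumption for illustration, not the real image) to show how np.newaxis changes the shape:

```python
import numpy as np

# Stand-in for a 224x224 RGB image scaled to [0, 1]; the random values
# are illustrative only.
rng = np.random.default_rng(0)
fake_image = rng.random((224, 224, 3))

# np.newaxis inserts a leading axis of length 1, turning one image
# into a batch that contains one image.
batched = fake_image[np.newaxis, ...]

print(fake_image.shape)  # (224, 224, 3)
print(batched.shape)     # (1, 224, 224, 3)
```

The pixel data is unchanged; only the shape gains a dimension.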
The result is a 1001-element vector of logits, one per class, where a higher logit means the model rates that class as more likely for the image.
The top class ID can be found with tf.math.argmax:
predicted_class = tf.math.argmax(result[0], axis=-1)
predicted_class
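Logits are unnormalized scores, so if you want probabilities or the top few classes rather than only the best one, you can post-process them. The sketch below uses NumPy and a made-up five-class logit vector (not real model output); the softmax helper is written here for illustration and is not part of the tutorial's code:

```python
import numpy as np

def softmax(logits):
    """Convert a 1-D array of logits into probabilities that sum to 1."""
    shifted = logits - np.max(logits)  # Subtract the max for numerical stability.
    exp = np.exp(shifted)
    return exp / exp.sum()

# Hypothetical logits for a 5-class problem.
example_logits = np.array([2.0, 1.0, 0.1, -1.0, 3.0])
probabilities = softmax(example_logits)

# Indices of the three highest-scoring classes, best first.
top3 = np.argsort(example_logits)[::-1][:3]

print(probabilities.sum())
print(top3)  # [4 0 1]
```

Softmax preserves the ranking of the logits, so the argmax of the probabilities is the same class as the argmax of the logits.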
Take the predicted_class ID (such as 653) and fetch the ImageNet dataset labels to decode the predictions:
labels_path = tf.keras.utils.get_file('ImageNetLabels.txt','https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt')
imagenet_labels = np.array(open(labels_path).read().splitlines())
plt.imshow(grace_hopper)
plt.axis('off')
predicted_class_name = imagenet_labels[predicted_class]
_ = plt.title("Prediction: " + predicted_class_name.title())
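But what if you want to create a custom classifier for your own dataset, with classes that aren't included in the ImageNet dataset that the pre-trained model was trained on?
To do that, you can select a pre-trained model from TensorFlow Hub and retrain its top (last) layer to recognize the classes in your custom dataset.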
In this example, you will use the TensorFlow flowers dataset:
import pathlib
data_file = tf.keras.utils.get_file(
    'flower_photos.tgz',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    cache_dir='.',
    extract=True)
data_root = pathlib.Path(data_file).with_suffix('')
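The with_suffix('') call derives the extracted directory's path by dropping the archive's file extension. A small sketch of that behavior, using a hypothetical path rather than the real download location:

```python
import pathlib

# Hypothetical archive location, used only to illustrate the path handling.
archive = pathlib.Path('datasets') / 'flower_photos.tgz'

# with_suffix('') removes the final suffix ('.tgz'), leaving the directory name.
extracted_dir = archive.with_suffix('')
print(extracted_dir.name)  # flower_photos

# Only the last suffix is removed, so a '.tar.gz' name keeps its '.tar' part.
print(pathlib.Path('archive.tar.gz').with_suffix('').name)  # archive.tar
```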
First, load the images from disk with tf.keras.utils.image_dataset_from_directory, which generates a tf.data.Dataset:
batch_size = 32
img_height = 224
img_width = 224
train_ds = tf.keras.utils.image_dataset_from_directory(
    str(data_root),
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size
)
val_ds = tf.keras.utils.image_dataset_from_directory(
    str(data_root),
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size
)
The flowers dataset has five classes:
class_names = np.array(train_ds.class_names)
print(class_names)
Second, because TensorFlow Hub's convention for image models is to expect float inputs in the [0, 1] range, use the tf.keras.layers.Rescaling preprocessing layer to achieve this.
Note: You could also include the tf.keras.layers.Rescaling layer inside the model. Refer to the Working with preprocessing layers guide for a discussion of the tradeoffs.
normalization_layer = tf.keras.layers.Rescaling(1./255)
train_ds = train_ds.map(lambda x, y: (normalization_layer(x), y)) # Where x—images, y—labels.
val_ds = val_ds.map(lambda x, y: (normalization_layer(x), y)) # Where x—images, y—labels.
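The arithmetic behind the rescaling step is a multiply by a scale factor plus an optional offset. The NumPy sketch below mimics that computation on a made-up 2x2 patch of uint8 pixels; the rescale helper is a stand-in written for illustration, not the Keras layer itself:

```python
import numpy as np

def rescale(pixels, scale=1.0 / 255, offset=0.0):
    """NumPy stand-in for a rescaling step: pixels * scale + offset."""
    return pixels.astype(np.float32) * scale + offset

# Hypothetical 2x2 grayscale patch with uint8 pixel values.
patch = np.array([[0, 51], [102, 255]], dtype=np.uint8)
scaled = rescale(patch)

print(scaled.dtype)
print(scaled.min(), scaled.max())
```

With a scale of 1/255 and no offset, pixel values in [0, 255] land in the [0, 1] range that the TensorFlow Hub image models expect.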
Third, finish the input pipeline by adding Dataset.cache and buffered prefetching with Dataset.prefetch, so you can yield the data from disk without I/O blocking issues.
Dataset.cache and Dataset.prefetch are two of the most important tf.data methods to use when loading data. To learn more about them, as well as how to cache data to disk and other techniques, refer to the Better performance with the tf.data API guide.
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
for image_batch, labels_batch in train_ds:
  print(image_batch.shape)
  print(labels_batch.shape)
  break
Now, run the classifier on an image batch:
result_batch = classifier.predict(image_batch)
predicted_class_names = imagenet_labels[tf.math.argmax(result_batch, axis=-1)]
predicted_class_names
Check how these predictions line up with the images:
plt.figure(figsize=(10,9))
plt.subplots_adjust(hspace=0.5)
for n in range(30):
  plt.subplot(6,5,n+1)
  plt.imshow(image_batch[n])
  plt.title(predicted_class_names[n])
  plt.axis('off')
_ = plt.suptitle("ImageNet predictions")
Note: All images are licensed CC-BY; creators are listed in the LICENSE.txt file.
The results are far from perfect, but reasonable considering that these are not the classes the model was trained for (except for "daisy").
TensorFlow Hub also distributes models without the top classification layer. These can be used to easily perform transfer learning.
Select a <a href="https://arxiv.org/abs/1801.04381" class="external">MobileNetV2</a> pre-trained model <a href="https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4" class="external">from TensorFlow Hub</a>. Any <a href="https://tfhub.dev/s?module-type=image-feature-vector&q=tf2" class="external">compatible image feature vector model</a> from TensorFlow Hub will work here, including the examples from the drop-down menu.
mobilenet_v2 = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4"
inception_v3 = "https://tfhub.dev/google/tf2-preview/inception_v3/feature_vector/4"
feature_extractor_model = mobilenet_v2 #@param ["mobilenet_v2", "inception_v3"] {type:"raw"}
Create the feature extractor by wrapping the pre-trained model as a Keras layer with hub.KerasLayer. Use the trainable=False argument to freeze the variables, so that the training only modifies the new classifier layer:
feature_extractor_layer = hub.KerasLayer(
    feature_extractor_model,
    input_shape=(224, 224, 3),
    trainable=False)
With the MobileNetV2 model, the feature extractor returns a 1280-element vector for each image (the image batch size remains at 32 in this example):
feature_batch = feature_extractor_layer(image_batch)
print(feature_batch.shape)
To complete the model, wrap the feature extractor layer in a tf.keras.Sequential model and add a fully-connected layer for classification:
num_classes = len(class_names)
model = tf.keras.Sequential([
  feature_extractor_layer,
  tf.keras.layers.Dense(num_classes)
])
model.summary()
predictions = model(image_batch)
predictions.shape
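To see why the output has one value per class for each image, it can help to write out what a fully-connected layer without an activation computes. The sketch below uses random NumPy arrays as stand-ins for the feature batch and the layer's parameters (illustrative values, not trained weights), assuming 1280 features and five classes:

```python
import numpy as np

rng = np.random.default_rng(0)
batch, feature_dim, n_classes = 32, 1280, 5

# Random stand-ins for the feature extractor output and the dense layer's
# kernel and bias.
features = rng.normal(size=(batch, feature_dim))
kernel = rng.normal(scale=0.01, size=(feature_dim, n_classes))
bias = np.zeros(n_classes)

# A dense layer with no activation computes features @ kernel + bias.
logits = features @ kernel + bias
print(logits.shape)  # (32, 5)

# With the feature extractor frozen, these are the only trainable parameters.
n_trainable = kernel.size + bias.size
print(n_trainable)  # 6405
```

Under these assumptions the new classification head is small, which is part of why training only the top layer is fast.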
Use Model.compile to configure the training process and add a tf.keras.callbacks.TensorBoard callback to create and store logs:
model.compile(
  optimizer=tf.keras.optimizers.Adam(),
  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
  metrics=['acc'])
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=1) # Enable histogram computation for every epoch.
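The from_logits=True setting tells the loss that the model outputs raw scores, so the softmax is applied inside the loss computation. As a rough NumPy sketch of that calculation (a simplified stand-in that averages over the batch, not the Keras implementation), with made-up labels and logits:

```python
import numpy as np

def sparse_categorical_crossentropy_from_logits(labels, logits):
    """Mean cross-entropy for integer labels and unnormalized logits."""
    # Log-softmax, computed stably by subtracting the row-wise max.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Select the log-probability of the true class for each example.
    picked = log_probs[np.arange(len(labels)), labels]
    return -picked.mean()

# Hypothetical batch of two examples with three classes.
labels = np.array([0, 2])
logits = np.array([[5.0, 0.0, 0.0],
                   [0.0, 0.0, 0.0]])

loss = sparse_categorical_crossentropy_from_logits(labels, logits)
print(loss)
```

The first example is confidently correct and contributes almost nothing to the loss; the second has uniform logits and contributes ln(3), the loss of a uniform guess over three classes.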
Now use the Model.fit method to train the model.
To keep this example short, you'll train for just 10 epochs. Passing the TensorBoard callback to Model.fit stores the logs so that you can visualize the training progress later.
NUM_EPOCHS = 10
history = model.fit(train_ds,
                    validation_data=val_ds,
                    epochs=NUM_EPOCHS,
                    callbacks=[tensorboard_callback])
Start the TensorBoard to view how the metrics change with each epoch and to track other scalar values:
%tensorboard --logdir logs/fit
Obtain the ordered list of class names from the model predictions:
predicted_batch = model.predict(image_batch)
predicted_id = tf.math.argmax(predicted_batch, axis=-1)
predicted_label_batch = class_names[predicted_id]
print(predicted_label_batch)
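The lookup works because indexing a NumPy array of names with an array of integer IDs returns the name for each ID. A minimal sketch with made-up logits; the class names are assumed to be the alphabetically ordered folder names of the flowers dataset:

```python
import numpy as np

# Assumed class names, in the alphabetical order produced by loading
# images from per-class directories.
example_class_names = np.array(
    ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips'])

# Hypothetical logits for two images (not real model output).
example_logits = np.array([[0.1, 2.0, -1.0, 0.3, 0.0],
                           [1.5, 0.2, 0.1, 0.0, 3.0]])

# argmax over the last axis gives one class ID per image.
predicted_ids = np.argmax(example_logits, axis=-1)

# Integer-array indexing maps each ID to its class name.
predicted_labels = example_class_names[predicted_ids]
print(predicted_labels)  # ['dandelion' 'tulips']
```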
Plot the model predictions:
plt.figure(figsize=(10,9))
plt.subplots_adjust(hspace=0.5)
for n in range(30):
  plt.subplot(6,5,n+1)
  plt.imshow(image_batch[n])
  plt.title(predicted_label_batch[n].title())
  plt.axis('off')
_ = plt.suptitle("Model predictions")
Now that you've trained the model, export it as a SavedModel so you can reuse it later.
t = time.time()
export_path = "/tmp/saved_models/{}".format(int(t))
model.save(export_path)
export_path
Confirm that you can reload the SavedModel and that the model is able to output the same results:
reloaded = tf.keras.models.load_model(export_path)
result_batch = model.predict(image_batch)
reloaded_result_batch = reloaded.predict(image_batch)
abs(reloaded_result_batch - result_batch).max()
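A maximum absolute difference at or near zero indicates that the reloaded model reproduces the original outputs. If you prefer an explicit pass/fail check, one option is a tolerance-based comparison. The sketch below uses small made-up arrays in place of the two prediction batches, and the tolerance of 1e-6 is an arbitrary choice for illustration:

```python
import numpy as np

# Made-up outputs standing in for the original model's predictions.
original = np.array([[0.12, -1.30], [2.05, 0.44]])

# Simulate tiny floating-point differences after a save/load round trip.
reloaded_outputs = original + 1e-7

# Largest element-wise deviation between the two sets of outputs.
max_abs_diff = np.abs(reloaded_outputs - original).max()
print(max_abs_diff < 1e-6)

# np.allclose wraps the same idea in a single boolean check.
print(np.allclose(reloaded_outputs, original, atol=1e-6))
```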
reloaded_predicted_id = tf.math.argmax(reloaded_result_batch, axis=-1)
reloaded_predicted_label_batch = class_names[reloaded_predicted_id]
print(reloaded_predicted_label_batch)
plt.figure(figsize=(10,9))
plt.subplots_adjust(hspace=0.5)
for n in range(30):
  plt.subplot(6,5,n+1)
  plt.imshow(image_batch[n])
  plt.title(reloaded_predicted_label_batch[n].title())
  plt.axis('off')
_ = plt.suptitle("Model predictions")
You can load the SavedModel for inference, or convert it to a TensorFlow Lite model (for on-device machine learning) or a TensorFlow.js model (for machine learning in JavaScript).
Discover more tutorials to learn how to use pre-trained models from TensorFlow Hub on image, text, audio, and video tasks.