official/projects/qat/vision/docs/qat_tutorial.ipynb
This tutorial demonstrates how to apply quantization-aware training (QAT) starting from a pre-trained checkpoint, export the resulting model to TFLite, and run inference on an image for an object detection task, using the TensorFlow Model Garden library.
TensorFlow Model Garden contains a collection of state-of-the-art models implemented with TensorFlow's high-level APIs. The implementations demonstrate best practices for modeling, letting users take full advantage of TensorFlow for their research and product development.
In this tutorial, we will use a MobileNetV2 backbone with the RetinaNet framework as an example to walk you through the process of applying QAT. This assumes you have already trained a model using TensorFlow Model Garden.
!pip install -U tf-models-nightly
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image
from six import BytesIO
from IPython import display
from urllib.request import urlopen
import numpy as np
import tensorflow as tf
import absl.logging
absl.logging.set_verbosity(absl.logging.ERROR)
tf.get_logger().setLevel(absl.logging.ERROR)
The model uses the implementation from the TensorFlow Model Garden GitHub repository and achieves 23.3 mAP on the COCO validation set. It uses a MobileNetV2 backbone with a RetinaNet detector on 256x256 input images.
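If you would like to inspect the training configuration programmatically, the Model Garden experiment factory can build an experiment config by name. The snippet below is a minimal sketch, assuming the base experiment name 'retinanet_mobile_coco' is registered by the installed tf-models package (the QAT variant 'retinanet_mobile_coco_qat' used later in this tutorial is registered by the QAT project configs).
# Optional sketch: build the base experiment config by name and inspect it.
# Assumes 'retinanet_mobile_coco' is registered by the installed tf-models package.
import tensorflow_models as tfm

exp_config = tfm.core.exp_factory.get_exp_config('retinanet_mobile_coco')
print(exp_config.task.model.input_size)     # Input size used by this experiment.
print(exp_config.task.model.backbone.type)  # Backbone type, e.g. 'mobilenet'.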
! curl https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv2_ssd_coco/mobilenetv2_ssd_i256_ckpt.tar.gz --output model.tar.gz
# Extract pretrained checkpoint.
! tar -xvzf model.tar.gz
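As an optional sanity check, you can list a few variables stored in the downloaded checkpoint. This sketch assumes the tarball extracted to the mobilenetv2_ssd_i256_ckpt directory referenced by the export command below.
# Optional: list a few variables from the pretrained checkpoint.
ckpt_prefix = 'mobilenetv2_ssd_i256_ckpt/ckpt-277200'
for name, shape in tf.train.list_variables(ckpt_prefix)[:10]:
  print(name, shape)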
You can follow the training guidelines to start QAT training from the pretrained checkpoint.
After QAT training completes, we can export a SavedModel and convert it to a TFLite model. For demonstration purposes, we download a QAT-trained model checkpoint and work with it.
! curl https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv2_ssd_coco/mobilenetv2_ssd_i256_qat_ckpt.tar.gz --output model_qat.tar.gz
! tar -xvzf model_qat.tar.gz
! curl https://raw.githubusercontent.com/tensorflow/models/master/official/projects/qat/vision/configs/experiments/retinanet/coco_mobilenetv2_qat_tpu_e2e.yaml --output params.yaml
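The params.yaml file holds the experiment overrides used for QAT, including the quantization settings. If you want to see exactly what will be overridden, you can load it as a dictionary; a small sketch using PyYAML (installed with the Model Garden dependencies):
# Peek at the QAT experiment overrides downloaded above.
import pprint
import yaml

with open('params.yaml') as f:
  params = yaml.safe_load(f)
pprint.pprint(params)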
# Model export and conversion.
# First export a SavedModel. Make sure batch_size=1 and input_type=tflite.
! python3 /usr/local/lib/python3.8/dist-packages/official/projects/qat/vision/serving/export_saved_model.py --experiment=retinanet_mobile_coco_qat --export_dir=${PWD}/mobilenetv2_ssd_i256_qat_savedmodel --checkpoint_path=${PWD}/mobilenetv2_ssd_i256_qat_ckpt --batch_size=1 --input_type=tflite --input_image_size=256,256 --config_file=${PWD}/params.yaml --params_override="task.quantization.pretrained_original_checkpoint='${PWD}/mobilenetv2_ssd_i256_ckpt/ckpt-277200'"
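After the export script finishes, the SavedModel is written under the export directory (the saved_model subdirectory used by the conversion command below). As an optional check, you can load it and inspect its serving signature; a minimal sketch, assuming the default 'serving_default' signature name:
# Optional: load the exported SavedModel and inspect its serving signature.
loaded = tf.saved_model.load('mobilenetv2_ssd_i256_qat_savedmodel/saved_model')
infer = loaded.signatures['serving_default']
print(infer.structured_input_signature)
print(infer.structured_outputs)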
# Convert the SavedModel to TFLite
! python3 /usr/local/lib/python3.8/dist-packages/official/projects/qat/vision/serving/export_tflite.py --experiment=retinanet_mobile_coco_qat --saved_model_dir=${PWD}/mobilenetv2_ssd_i256_qat_savedmodel/saved_model --tflite_path=${PWD}/mobilenetv2_ssd_i256_qat_tflite --config_file=${PWD}/params.yaml --quant_type=qat
Now we will show how to use the converted TFLite model to run inference and obtain detection results. We also provide our converted TFLite model, which can be used directly for this.
# First download the TFLite model.
! curl https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv2_ssd_coco/model_int8_qat.tflite --output model.tflite
# Define a helper function to load a sample image into a numpy array.
def load_image_into_numpy_array(path, height, width):
  """Loads an image from a file or URL into a numpy array.

  Puts the image into a numpy array to feed into the TensorFlow graph.
  Note that by convention we put it into a numpy array with shape
  (1, height, width, channels), where channels=3 for RGB.

  Args:
    path: the file path or URL of the image.
    height: height to resize the image to.
    width: width to resize the image to.

  Returns:
    uint8 numpy array with shape (1, height, width, 3).
  """
  if path.startswith('http'):
    response = urlopen(path)
    image_data = BytesIO(response.read())
    image = Image.open(image_data)
  else:
    image_data = tf.io.gfile.GFile(path, 'rb').read()
    image = Image.open(BytesIO(image_data))
  # PIL's resize expects (width, height).
  image = image.resize((width, height))
  image = np.array(image.getdata()).reshape(
      (1, height, width, 3)).astype(np.uint8)
  return image
image_path = "https://djl.ai/examples/src/test/resources/dog_bike_car.jpg"
image_array = load_image_into_numpy_array(image_path, 256, 256)
image_array.shape
Image.fromarray(image_array[0])
# Load TFLite model.
tflite_path = 'model.tflite'
with tf.io.gfile.GFile(tflite_path, 'rb') as f:
tflite_model = f.read()
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
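# Before invoking the interpreter, it can help to confirm that the model's
# input spec matches the preprocessed image. For the model exported above, the
# input is expected to be uint8 with shape [1, 256, 256, 3], which is what
# load_image_into_numpy_array produces.
print('Input shape:', input_details[0]['shape'], 'dtype:', input_details[0]['dtype'])
print('Image shape:', image_array.shape, 'dtype:', image_array.dtype)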
interpreter.set_tensor(input_details[0]['index'], image_array)
interpreter.invoke()
# The function `get_tensor()` returns a copy of the tensor data.
# Use `tensor()` in order to get a pointer to the tensor.
outputs = []
for detail in output_details:
  outputs.append(interpreter.get_tensor(detail['index']))
# The final `outputs` is a list of [detection_boxes, num_detections, detection_scores, detection_classes, image_info].
outputs
plt.imshow(Image.fromarray(image_array[0]))
scores = outputs[2]
num_detection = outputs[1]
boxes = outputs[0]
classes = outputs[3]
# We only show boxes that have a detection score larger than 0.5.
threshold = 0.5
ax = plt.gca()
for i in range(int(num_detection[0])):
  if scores[0, i] > threshold:
    ymin, xmin, ymax, xmax = boxes[0, i]
    rect = patches.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                             linewidth=1, edgecolor='r', facecolor='none')
    ax.add_patch(rect)
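To make the result easier to read, you can also annotate each kept box with its class ID and detection score. The class-ID-to-name mapping depends on the label map used during training, so only the numeric IDs are shown in this sketch.
# Redraw the image with boxes plus per-box class ID and score labels.
plt.figure(figsize=(6, 6))
plt.imshow(Image.fromarray(image_array[0]))
ax = plt.gca()
for i in range(int(num_detection[0])):
  if scores[0, i] > threshold:
    ymin, xmin, ymax, xmax = boxes[0, i]
    ax.add_patch(patches.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   linewidth=1, edgecolor='r', facecolor='none'))
    ax.text(xmin, ymin, f'class {int(classes[0, i])}: {scores[0, i]:.2f}',
            color='white', fontsize=8,
            bbox=dict(facecolor='red', alpha=0.5, pad=1))
plt.axis('off')
plt.show()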