
Quantization-aware Training (QAT) for Object Detection with Model Garden

official/projects/qat/vision/docs/qat_tutorial.ipynb



This tutorial demonstrates how to apply quantization-aware training (QAT) starting from a pretrained checkpoint, export the resulting checkpoint to a TFLite model, and run inference on an image for an object detection task, using the TensorFlow Model Garden library.

TensorFlow Model Garden contains a collection of state-of-the-art models implemented with TensorFlow's high-level APIs. The implementations demonstrate modeling best practices, letting users take full advantage of TensorFlow for their research and product development.

In this tutorial, we use a MobileNetV2 backbone with the RetinaNet framework as an example to walk you through applying QAT. The tutorial assumes you have already trained a model using TensorFlow Model Garden.
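To make the idea behind QAT concrete: during training, weights and activations are rounded to an integer grid and mapped back to floating point in the forward pass ("fake quantization"), so the model learns to tolerate the rounding and clipping error it will see after int8 conversion. The pure-Python sketch below shows that quantize-dequantize round trip for scalar values under an assumed asymmetric int8 scheme. The helper names `quantization_params` and `fake_quantize` are hypothetical and illustrate the concept only; they are not the Model Garden implementation.

```python
def quantization_params(min_val, max_val, qmin=-128, qmax=127):
    """Compute an affine (scale, zero_point) pair mapping [min_val, max_val] to [qmin, qmax]."""
    # The float range must contain 0.0 so that zero is exactly representable.
    min_val = min(min_val, 0.0)
    max_val = max(max_val, 0.0)
    scale = (max_val - min_val) / (qmax - qmin)
    if scale == 0.0:
        scale = 1.0  # Degenerate all-zero range; any positive scale works.
    zero_point = int(round(qmin - min_val / scale))
    zero_point = max(qmin, min(qmax, zero_point))
    return scale, zero_point


def fake_quantize(x, scale, zero_point, qmin=-128, qmax=127):
    """Quantize x to an integer, clamp it to [qmin, qmax], then dequantize back to float."""
    q = int(round(x / scale)) + zero_point
    q = max(qmin, min(qmax, q))
    return (q - zero_point) * scale


scale, zero_point = quantization_params(-1.0, 1.0)
for x in [-1.0, -0.3, 0.0, 0.1234, 1.0, 5.0]:
    # In-range values move by at most scale / 2; out-of-range values are clipped.
    print(f"{x:+.4f} -> {fake_quantize(x, scale, zero_point):+.4f}")
```

Because training sees these rounded values, the learned weights adapt to the quantization error instead of encountering it for the first time at conversion.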

Install Necessary Dependencies

python
!pip install -U tf-models-nightly

Import Libraries

python
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image

from io import BytesIO
from IPython import display
from urllib.request import urlopen

import numpy as np
import tensorflow as tf

import absl.logging
absl.logging.set_verbosity(absl.logging.ERROR)
tf.get_logger().setLevel(absl.logging.ERROR)

Download Pretrained Model

The model uses the implementation from the TensorFlow Model Garden GitHub repository and achieves 23.3 mAP on the COCO validation set. It uses a MobileNetV2 backbone and a RetinaNet decoder with a 256x256 input image.

python
! curl https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv2_ssd_coco/mobilenetv2_ssd_i256_ckpt.tar.gz --output model.tar.gz
python
# Extract pretrained checkpoint.
! tar -xvzf model.tar.gz

Launch QAT Training

You can follow the training guideline to start QAT training using the pretrained checkpoint.
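As a rough illustration only, the sketch below assembles a QAT training command line. It assumes the QAT project provides a `train` entry point (`official.projects.qat.vision.train`) that accepts the standard Model Garden flags `--experiment`, `--config_file`, `--model_dir`, `--mode`, and `--params_override`; the training guideline remains the authoritative reference. The helper `build_qat_train_command` and the `model_dir` value are hypothetical. The `task.quantization.pretrained_original_checkpoint` override and the checkpoint name are the same ones used in the export step of this tutorial.

```python
import shlex


def build_qat_train_command(model_dir, config_file, pretrained_ckpt,
                            experiment="retinanet_mobile_coco_qat",
                            mode="train_and_eval"):
    """Build (but do not run) a hypothetical QAT training command as an argv list."""
    return [
        "python3", "-m", "official.projects.qat.vision.train",  # Assumed entry point.
        f"--experiment={experiment}",
        f"--config_file={config_file}",
        f"--model_dir={model_dir}",
        f"--mode={mode}",
        # Start QAT from the float checkpoint extracted above.
        "--params_override=task.quantization.pretrained_original_checkpoint="
        f"'{pretrained_ckpt}'",
    ]


cmd = build_qat_train_command(
    model_dir="/tmp/qat_model_dir",
    config_file="coco_mobilenetv2_qat_tpu_e2e.yaml",
    pretrained_ckpt="mobilenetv2_ssd_i256_ckpt/ckpt-277200")
print(shlex.join(cmd))
# To launch it: import subprocess; subprocess.run(cmd, check=True)
```

Keeping the command as a list avoids shell-quoting mistakes around the nested quotes in `--params_override`.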

Export Model

After QAT training completes, we can export a SavedModel and convert it to a TFLite model. For demonstration purposes, we download an already QAT-trained checkpoint and use it in the steps below.

python
! curl https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv2_ssd_coco/mobilenetv2_ssd_i256_qat_ckpt.tar.gz --output model_qat.tar.gz
! tar -xvzf model_qat.tar.gz
python
! curl https://raw.githubusercontent.com/tensorflow/models/master/official/projects/qat/vision/configs/experiments/retinanet/coco_mobilenetv2_qat_tpu_e2e.yaml --output params.yaml
python
# Model export and convert.
# First export a SavedModel. Make sure batch_size=1 and input_type=tflite.
! python3 /usr/local/lib/python3.8/dist-packages/official/projects/qat/vision/serving/export_saved_model.py --experiment=retinanet_mobile_coco_qat --export_dir=${PWD}/mobilenetv2_ssd_i256_qat_savedmodel --checkpoint_path=${PWD}/mobilenetv2_ssd_i256_qat_ckpt --batch_size=1 --input_type=tflite --input_image_size=256,256 --config_file=${PWD}/params.yaml --params_override="task.quantization.pretrained_original_checkpoint='${PWD}/mobilenetv2_ssd_i256_ckpt/ckpt-277200'"

# Convert the SavedModel to TFLite
! python3 /usr/local/lib/python3.8/dist-packages/official/projects/qat/vision/serving/export_tflite.py --experiment=retinanet_mobile_coco_qat --saved_model_dir=${PWD}/mobilenetv2_ssd_i256_qat_savedmodel/saved_model --tflite_path=${PWD}/mobilenetv2_ssd_i256_qat_tflite --config_file=${PWD}/params.yaml --quant_type=qat
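The two commands above hard-code `/usr/local/lib/python3.8/dist-packages/...`, which only matches a Python 3.8 environment with that layout. A minimal sketch for finding where the `official` package (installed by `tf-models-nightly`) actually lives, using only the standard library; `find_package_dir` is a hypothetical helper name:

```python
import importlib.util
import os


def find_package_dir(package_name):
    """Return the directory of an installed package, or None if it is not installed."""
    spec = importlib.util.find_spec(package_name)
    # Plain modules and built-ins have no submodule search locations.
    if spec is None or not spec.submodule_search_locations:
        return None
    return list(spec.submodule_search_locations)[0]


official_dir = find_package_dir("official")
if official_dir is not None:
    serving_dir = os.path.join(official_dir, "projects", "qat", "vision", "serving")
    print(os.path.join(serving_dir, "export_saved_model.py"))
    print(os.path.join(serving_dir, "export_tflite.py"))
else:
    print("The `official` package (tf-models-nightly) is not installed.")
```

The printed paths can be substituted for the hard-coded script paths in the `! python3 ...` commands.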

Run Inference

Next, we use the converted TFLite model to run inference and obtain detection results. We provide a pre-converted TFLite model so this section can be run directly.

python
# First download the TFLite model.
! curl https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv2_ssd_coco/model_int8_qat.tflite --output model.tflite
python
# Define a helper function that loads an image from a URL or a local path.
def load_image_into_numpy_array(path, height, width):
  """Load an image into a uint8 numpy array that can be fed to the TFLite model.

  The image is converted to RGB, resized to (height, width), and given a
  leading batch dimension.

  Args:
    path: a local file path or an http(s) URL to the image.
    height: target image height in pixels.
    width: target image width in pixels.

  Returns:
    uint8 numpy array with shape (1, height, width, 3).
  """
  if path.startswith('http'):
    image_data = urlopen(path).read()
  else:
    with tf.io.gfile.GFile(path, 'rb') as f:
      image_data = f.read()

  # Convert to RGB so grayscale or RGBA images also yield 3 channels.
  image = Image.open(BytesIO(image_data)).convert('RGB')

  # PIL's resize() takes (width, height), not (height, width).
  image = image.resize((width, height))

  # Add a batch dimension: (height, width, 3) -> (1, height, width, 3).
  return np.asarray(image, dtype=np.uint8)[np.newaxis, ...]

Download a Sample Image

python
image_path = "https://djl.ai/examples/src/test/resources/dog_bike_car.jpg"
image_array = load_image_into_numpy_array(image_path, 256, 256)
image_array.shape
python
Image.fromarray(image_array[0])

Run Inference on Sample Image

python
# Load TFLite model.
tflite_path = 'model.tflite'

with tf.io.gfile.GFile(tflite_path, 'rb') as f:
  tflite_model = f.read()

interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
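# Cast to the dtype the model's input tensor expects before feeding it.
interpreter.set_tensor(input_details[0]['index'],
                       image_array.astype(input_details[0]['dtype']))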

interpreter.invoke()

# The function `get_tensor()` returns a copy of the tensor data.
# Use `tensor()` in order to get a pointer to the tensor.
outputs = []
for i in range(len(output_details)):
  outputs.append(interpreter.get_tensor(output_details[i]['index']))
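Since this is an int8 QAT model, it can be worth checking whether any output tensor is itself quantized. Each entry of `output_details` carries a `quantization` tuple `(scale, zero_point)`, and a scale of 0 means the tensor is not quantized. The sketch below applies `real = scale * (q - zero_point)` when needed; `dequantize_output` is a hypothetical helper, and the example detail dict uses made-up values rather than this model's actual parameters.

```python
import numpy as np


def dequantize_output(raw, detail):
    """Convert a raw TFLite output tensor to float using its (scale, zero_point).

    `detail` is one entry of `interpreter.get_output_details()`. A scale of 0
    means the tensor is not quantized, so it is returned unchanged.
    """
    scale, zero_point = detail["quantization"]
    if scale == 0:
        return raw
    return (raw.astype(np.float32) - zero_point) * scale


# Hypothetical quantization parameters, for illustration only.
fake_detail = {"quantization": (0.5, 10)}
print(dequantize_output(np.array([10, 12, 20], dtype=np.int8), fake_detail))  # [0. 1. 5.]
```

With the interpreter above, it could be applied as `outputs = [dequantize_output(interpreter.get_tensor(d['index']), d) for d in output_details]`.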
python
# `outputs` is a list of [detection_boxes, num_detections, detection_scores, detection_classes, image_info].
outputs
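`detection_classes` holds numeric class IDs. Assuming the model reports the standard COCO category IDs (for example 1 = person, 2 = bicycle, 3 = car, 18 = dog), a small lookup makes them readable. The map below is a hypothetical partial one covering only a few categories, and `class_name` is an illustrative helper; unknown IDs fall back to the raw number.

```python
# Partial COCO category-ID-to-name map. Assumption: the model outputs standard
# COCO category IDs; extend this with the remaining categories as needed.
COCO_LABELS = {1: "person", 2: "bicycle", 3: "car", 4: "motorcycle",
               6: "bus", 8: "truck", 17: "cat", 18: "dog"}


def class_name(class_id):
    """Map a (possibly float) class ID to a name, falling back to the raw ID."""
    class_id = int(class_id)
    return COCO_LABELS.get(class_id, f"class {class_id}")


print([class_name(c) for c in [18.0, 2.0, 3.0]])
```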

Visualize Detection Outputs

python
plt.imshow(Image.fromarray(image_array[0]))
ax = plt.gca()

boxes = outputs[0]          # Boxes as [ymin, xmin, ymax, xmax], indexed [0, i].
num_detection = outputs[1]  # Number of valid detections.
scores = outputs[2]         # Detection scores, indexed [0, i].
classes = outputs[3]        # Detection class IDs.

# Only show boxes with a detection score above 0.5.
threshold = 0.5

for i in range(int(num_detection[0])):
  if scores[0, i] > threshold:
    ymin, xmin, ymax, xmax = boxes[0, i]
    # matplotlib expects the top-left corner (x, y) plus width and height.
    rect = patches.Rectangle(
        (xmin, ymin), xmax - xmin, ymax - ymin,
        linewidth=1, edgecolor='r', facecolor='none')
    ax.add_patch(rect)

plt.show()
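The filtering and box conversion in the loop above can be factored into a plain function that is easy to test without a model or a plot. This is a minimal sketch: `detections_to_rects` is a hypothetical helper that takes one image's boxes and scores (no batch dimension) and keeps the same `[ymin, xmin, ymax, xmax]` convention the visualization code uses.

```python
def detections_to_rects(boxes, scores, num_detections, threshold=0.5):
    """Convert [ymin, xmin, ymax, xmax] boxes into (x, y, width, height) tuples.

    Only the first `num_detections` entries whose score exceeds `threshold`
    are kept, matching the visualization loop.
    """
    rects = []
    for i in range(int(num_detections)):
        if scores[i] > threshold:
            ymin, xmin, ymax, xmax = boxes[i]
            rects.append((xmin, ymin, xmax - xmin, ymax - ymin))
    return rects


rects = detections_to_rects(
    boxes=[[10, 20, 110, 70], [0, 0, 5, 5]],
    scores=[0.9, 0.3],
    num_detections=2)
print(rects)  # [(20, 10, 50, 100)]
```

With the TFLite outputs it would be called as `detections_to_rects(boxes[0], scores[0], num_detection[0])`, and each tuple can be passed to `patches.Rectangle((x, y), w, h, ...)`.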