
Introduction

Welcome to the CenterNet on-device with TensorFlow Lite Colab. Here, we demonstrate how you can run a mobile-optimized version of the CenterNet architecture with TensorFlow Lite (a.k.a. TFLite).

Users can use this notebook as a reference for obtaining a TFLite version of CenterNet for object detection or keypoint detection. The code also shows how to perform pre-/post-processing and inference with TFLite's Python API.

NOTE: CenterNet support in TFLite is still experimental, and currently works with floating-point inference only.

Set Up

Libraries & Imports

!pip install tf-nightly
import os
import pathlib

# Clone the tensorflow models repository if it doesn't already exist
if "models" in pathlib.Path.cwd().parts:
  while "models" in pathlib.Path.cwd().parts:
    os.chdir('..')
elif not pathlib.Path('models').exists():
  !git clone --depth 1 https://github.com/tensorflow/models
# Install the Object Detection API
%%bash
cd models/research/
protoc object_detection/protos/*.proto --python_out=.
cp object_detection/packages/tf2/setup.py .
python -m pip install .
import matplotlib
import matplotlib.pyplot as plt

import os
import random
import io
import imageio
import glob
import numpy as np
from six import BytesIO
from PIL import Image, ImageDraw, ImageFont
from IPython.display import display, Javascript
from IPython.display import Image as IPyImage

import tensorflow as tf

from object_detection.utils import label_map_util
from object_detection.utils import config_util
from object_detection.utils import visualization_utils as viz_utils
from object_detection.utils import colab_utils
from object_detection.builders import model_builder

%matplotlib inline

Test Image from COCO

To showcase inference with CenterNet, we use a sample image containing people from the COCO'17 validation set.

# Download COCO'17 validation set for test image
%%bash
mkdir -p coco && cd coco
wget -q -N http://images.cocodataset.org/zips/val2017.zip
unzip -q -o val2017.zip && rm *.zip
cd ..
# Print the image we are going to test on as a sanity check.

def load_image_into_numpy_array(path):
  """Load an image from file into a numpy array.

  Puts image into numpy array to feed into tensorflow graph.
  Note that by convention we put it into a numpy array with shape
  (height, width, channels), where channels=3 for RGB.

  Args:
    path: a file path.

  Returns:
    uint8 numpy array with shape (img_height, img_width, 3)
  """
  img_data = tf.io.gfile.GFile(path, 'rb').read()
  image = Image.open(BytesIO(img_data))
  (im_width, im_height) = image.size
  return np.array(image.getdata()).reshape(
      (im_height, im_width, 3)).astype(np.uint8)

image_path = 'coco/val2017/000000013729.jpg'
plt.figure(figsize = (30, 20))
plt.imshow(load_image_into_numpy_array(image_path))

Utilities for Inference

The detect function shown below describes how input and output tensors from CenterNet (obtained in subsequent sections) can be processed. This logic can be ported to other languages depending on your application (e.g., to Java for Android apps).

def detect(interpreter, input_tensor, include_keypoint=False):
  """Run detection on an input image.

  Args:
    interpreter: tf.lite.Interpreter
    input_tensor: A [1, height, width, 3] Tensor of type tf.float32.
      Note that height and width must match the input size that the model
      was exported with; this notebook resizes the image before calling
      this function.
    include_keypoint: True if the model supports keypoints output. See
      https://cocodataset.org/#keypoints-2020

  Returns:
    A tuple containing the following output tensors:
      boxes: a numpy array of shape [1, N, 4] with box coordinates in
        normalized image frame (i.e. [0.0, 1.0]), as (ymin, xmin, ymax, xmax).
      classes: a numpy array of shape [1, N]. Note that class indices are
        0-based; add an offset of 1 to match the keys in the label map.
      scores: a numpy array of shape [1, N] with detection confidence scores.
      num_detections: a numpy array of shape [1] with the number of valid
        detections.
    If include_keypoint is True, the following are also returned:
      keypoints: a numpy array of shape [1, N, 17, 2] representing the
        yx-coordinates of the 17 COCO human keypoints for each detection
        (https://cocodataset.org/#keypoints-2020) in normalized image frame
        (i.e. [0.0, 1.0]).
      keypoint_scores: a numpy array of shape [1, N, 17] representing the
        keypoint prediction confidence scores.
  """
  input_details = interpreter.get_input_details()
  output_details = interpreter.get_output_details()

  interpreter.set_tensor(input_details[0]['index'], input_tensor.numpy())

  interpreter.invoke()

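  # Note: this output ordering (scores, boxes, count, classes) matches the
  # models exported in this notebook; if you use a different model, verify
  # the ordering with interpreter.get_output_details().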
  scores = interpreter.get_tensor(output_details[0]['index'])
  boxes = interpreter.get_tensor(output_details[1]['index'])
  num_detections = interpreter.get_tensor(output_details[2]['index'])
  classes = interpreter.get_tensor(output_details[3]['index'])

  if include_keypoint:
    kpts = interpreter.get_tensor(output_details[4]['index'])
    kpts_scores = interpreter.get_tensor(output_details[5]['index'])
    return boxes, classes, scores, num_detections, kpts, kpts_scores
  else:
    return boxes, classes, scores, num_detections

# Utility for visualizing results
def plot_detections(image_np,
                    boxes,
                    classes,
                    scores,
                    category_index,
                    keypoints=None,
                    keypoint_scores=None,
                    figsize=(12, 16),
                    image_name=None):
  """Wrapper function to visualize detections.

  Args:
    image_np: uint8 numpy array with shape (img_height, img_width, 3)
    boxes: a numpy array of shape [N, 4]
    classes: a numpy array of shape [N]. Note that class indices are 1-based,
      and match the keys in the label map.
    scores: a numpy array of shape [N] or None.  If scores=None, then
      this function assumes that the boxes to be plotted are groundtruth
      boxes and plots all boxes as black with no classes or scores.
    category_index: a dict containing category dictionaries (each holding
      category index `id` and category name `name`) keyed by category indices.
    keypoints: (optional) a numpy array of shape [N, 17, 2] representing the
      yx-coordinates of the 17 COCO human keypoints for each detection
      (https://cocodataset.org/#keypoints-2020) in normalized image frame
      (i.e. [0.0, 1.0]).
    keypoint_scores: (optional) a numpy array of shape [N, 17] representing the
      keypoint prediction confidence scores.
    figsize: size for the figure.
    image_name: a name for the image file.
  """

  # Edges between the 17 COCO keypoints, used to draw the human skeleton.
  keypoint_edges = [
      (0, 1), (0, 2), (1, 3), (2, 4), (0, 5), (0, 6),
      (5, 7), (7, 9), (6, 8), (8, 10), (5, 6), (5, 11),
      (6, 12), (11, 12), (11, 13), (13, 15), (12, 14), (14, 16)]
  image_np_with_annotations = image_np.copy()
  # Only visualize objects that get a score > 0.3.
  viz_utils.visualize_boxes_and_labels_on_image_array(
      image_np_with_annotations,
      boxes,
      classes,
      scores,
      category_index,
      keypoints=keypoints,
      keypoint_scores=keypoint_scores,
      keypoint_edges=keypoint_edges,
      use_normalized_coordinates=True,
      min_score_thresh=0.3)
  if image_name:
    plt.imsave(image_name, image_np_with_annotations)
  else:
    return image_np_with_annotations

Object Detection

Download Model from Detection Zoo

NOTE: Not all CenterNet models from the TF2 Detection Zoo work with TFLite, only the MobileNet-based version does.

# Get mobile-friendly CenterNet for Object Detection
# See TensorFlow 2 Detection Model Zoo for more details:
# https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2_detection_zoo.md

%%bash
wget http://download.tensorflow.org/models/object_detection/tf2/20210210/centernet_mobilenetv2fpn_512x512_coco17_od.tar.gz
tar -xf centernet_mobilenetv2fpn_512x512_coco17_od.tar.gz
rm centernet_mobilenetv2fpn_512x512_coco17_od.tar.gz*

Now that we have downloaded the CenterNet model that uses MobileNet as a backbone, we can obtain a TensorFlow Lite model from it.

The downloaded archive already contains a model.tflite that works with TensorFlow Lite, but we re-generate the model in the next sub-section to account for cases where you re-train the model on your own dataset (with corresponding changes to pipeline.config & the checkpoint directory).
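If you just want to sanity-check the bundled model first, you can inspect its input & output signatures with the TFLite interpreter. A minimal sketch, assuming the archive was extracted to centernet_mobilenetv2_fpn_od/ as above:

# Inspect the bundled TFLite model's I/O signature. This confirms the input
# size and the output tensor ordering that the detect() function above
# relies on.
probe = tf.lite.Interpreter(
    model_path='centernet_mobilenetv2_fpn_od/model.tflite')
probe.allocate_tensors()
for detail in probe.get_input_details():
  print('input :', detail['name'], detail['shape'], detail['dtype'])
for detail in probe.get_output_details():
  print('output:', detail['name'], detail['shape'], detail['dtype'])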

Generate TensorFlow Lite Model

First, we invoke export_tflite_graph_tf2.py to generate a TFLite-friendly intermediate SavedModel. This will then be passed to the TensorFlow Lite Converter for generating the final model.

This is similar to what we do for SSD architectures.

%%bash
# Export the intermediate SavedModel that outputs 10 detections & takes in an 
# image of dim 320x320.
# Modify these parameters according to your needs.

python models/research/object_detection/export_tflite_graph_tf2.py \
  --pipeline_config_path=centernet_mobilenetv2_fpn_od/pipeline.config \
  --trained_checkpoint_dir=centernet_mobilenetv2_fpn_od/checkpoint \
  --output_directory=centernet_mobilenetv2_fpn_od/tflite \
  --centernet_include_keypoints=false \
  --max_detections=10 \
  --config_override=" \
    model{ \
      center_net { \
        image_resizer { \
          fixed_shape_resizer { \
            height: 320 \
            width: 320 \
          } \
        } \
      } \
    }"
# Generate TensorFlow Lite model using the converter.
%%bash
tflite_convert --output_file=centernet_mobilenetv2_fpn_od/model.tflite \
  --saved_model_dir=centernet_mobilenetv2_fpn_od/tflite/saved_model
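If you prefer to stay in Python, the same conversion can be done with the TFLiteConverter API. A minimal sketch equivalent to the tflite_convert invocation above (note that we keep the default float32 path, since CenterNet currently supports floating-point inference only):

# Convert the TFLite-friendly SavedModel from Python instead of the CLI.
converter = tf.lite.TFLiteConverter.from_saved_model(
    'centernet_mobilenetv2_fpn_od/tflite/saved_model')
tflite_model = converter.convert()
with open('centernet_mobilenetv2_fpn_od/model.tflite', 'wb') as f:
  f.write(tflite_model)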

TensorFlow Lite Inference

Use a TensorFlow Lite Interpreter to detect objects in the test image.

%matplotlib inline

# Load the TFLite model and allocate tensors.
model_path = 'centernet_mobilenetv2_fpn_od/model.tflite'
label_map_path = 'centernet_mobilenetv2_fpn_od/label_map.txt'
image_path = 'coco/val2017/000000013729.jpg'

# Initialize TensorFlow Lite Interpreter.
interpreter = tf.lite.Interpreter(model_path=model_path)
interpreter.allocate_tensors()

# Label map can be used to figure out what class ID maps to what
# label. `label_map.txt` is human-readable.
category_index = label_map_util.create_category_index_from_labelmap(
    label_map_path)

label_id_offset = 1

image = tf.io.read_file(image_path)
image = tf.compat.v1.image.decode_jpeg(image)
image = tf.expand_dims(image, axis=0)
image_numpy = image.numpy()

input_tensor = tf.convert_to_tensor(image_numpy, dtype=tf.float32)
# Note that CenterNet doesn't require any pre-processing except resizing to
# the input size that the model was exported with (320x320 here).
input_tensor = tf.image.resize(input_tensor, (320, 320))
boxes, classes, scores, num_detections = detect(interpreter, input_tensor)

vis_image = plot_detections(
    image_numpy[0],
    boxes[0],
    classes[0].astype(np.uint32) + label_id_offset,
    scores[0],
    category_index)
plt.figure(figsize = (30, 20))
plt.imshow(vis_image)
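Besides visualizing, you can consume the raw outputs directly. A small sketch that prints the detections above the same 0.3 threshold used by plot_detections:

# Print the detections as text, mirroring the 0.3 score threshold used for
# visualization above. Boxes are normalized (ymin, xmin, ymax, xmax).
for i in range(int(num_detections[0])):
  if scores[0][i] < 0.3:
    continue
  class_id = int(classes[0][i]) + label_id_offset
  name = category_index.get(class_id, {}).get('name', '???')
  print('%s: score=%.2f box=%s' %
        (name, scores[0][i], np.round(boxes[0][i], 3)))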

Keypoints

Unlike SSDs, CenterNet also supports COCO keypoint detection. More specifically, the 'keypoints' version of CenterNet shown here provides keypoints as an [N, 17, 2]-shaped tensor representing the (normalized) yx-coordinates of the 17 COCO human keypoints.

See the detect() function in the Utilities for Inference section to better understand the keypoints output.

Download Model from Detection Zoo

NOTE: Not all CenterNet models from the TF2 Detection Zoo work with TFLite, only the MobileNet-based version does.

# Get mobile-friendly CenterNet for Keypoint detection task.
# See TensorFlow 2 Detection Model Zoo for more details:
# https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2_detection_zoo.md

%%bash
wget http://download.tensorflow.org/models/object_detection/tf2/20210210/centernet_mobilenetv2fpn_512x512_coco17_kpts.tar.gz
tar -xf centernet_mobilenetv2fpn_512x512_coco17_kpts.tar.gz
rm centernet_mobilenetv2fpn_512x512_coco17_kpts.tar.gz*

Generate TensorFlow Lite Model

As before, we leverage export_tflite_graph_tf2.py to generate a TFLite-friendly intermediate SavedModel. This is then passed to the TFLite converter to generate the final model.

Note that we need to pass an additional keypoint_label_map_path parameter to export the keypoint outputs.

%%bash
# Export the intermediate SavedModel that outputs 10 detections & takes in an 
# image of dim 320x320.
# Modify these parameters according to your needs.

python models/research/object_detection/export_tflite_graph_tf2.py \
  --pipeline_config_path=centernet_mobilenetv2_fpn_kpts/pipeline.config \
  --trained_checkpoint_dir=centernet_mobilenetv2_fpn_kpts/checkpoint \
  --output_directory=centernet_mobilenetv2_fpn_kpts/tflite \
  --centernet_include_keypoints=true \
  --keypoint_label_map_path=centernet_mobilenetv2_fpn_kpts/label_map.txt \
  --max_detections=10 \
  --config_override=" \
    model{ \
      center_net { \
        image_resizer { \
          fixed_shape_resizer { \
            height: 320 \
            width: 320 \
          } \
        } \
      } \
    }"
# Generate TensorFlow Lite model using the converter.

%%bash
tflite_convert --output_file=centernet_mobilenetv2_fpn_kpts/model.tflite \
  --saved_model_dir=centernet_mobilenetv2_fpn_kpts/tflite/saved_model
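As a quick sanity check, you can confirm that the converted model now exposes six output tensors, including the two keypoint outputs that detect() expects:

# Verify the keypoint model's outputs: scores, boxes, count, classes,
# keypoints, and keypoint scores (the ordering assumed by detect() above).
probe = tf.lite.Interpreter(
    model_path='centernet_mobilenetv2_fpn_kpts/model.tflite')
probe.allocate_tensors()
for detail in probe.get_output_details():
  print(detail['name'], detail['shape'])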

TensorFlow Lite Inference

Use a TensorFlow Lite Interpreter to detect people & their keypoints in the test image.

%matplotlib inline

# Load the TFLite model and allocate tensors.
model_path = 'centernet_mobilenetv2_fpn_kpts/model.tflite'
image_path = 'coco/val2017/000000013729.jpg'

# Initialize TensorFlow Lite Interpreter.
interpreter = tf.lite.Interpreter(model_path=model_path)
interpreter.allocate_tensors()

# Keypoints are only relevant for people, so we only care about that
# category Id here.
category_index = {1: {'id': 1, 'name': 'person'}}

label_id_offset = 1

image = tf.io.read_file(image_path)
image = tf.compat.v1.image.decode_jpeg(image)
image = tf.expand_dims(image, axis=0)
image_numpy = image.numpy()

input_tensor = tf.convert_to_tensor(image_numpy, dtype=tf.float32)
# Note that CenterNet doesn't require any pre-processing except resizing to
# the input size that the model was exported with (320x320 here).
input_tensor = tf.image.resize(input_tensor, (320, 320))
(boxes, classes, scores, num_detections, kpts, kpts_scores) = detect(
    interpreter, input_tensor, include_keypoint=True)

vis_image = plot_detections(
    image_numpy[0],
    boxes[0],
    classes[0].astype(np.uint32) + label_id_offset,
    scores[0],
    category_index,
    keypoints=kpts[0],
    keypoint_scores=kpts_scores[0])
plt.figure(figsize = (30, 20))
plt.imshow(vis_image)
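To work with the keypoints numerically rather than visually, you can index the [N, 17, 2] tensor directly. A small sketch that prints named keypoints for the top detection, assuming the standard COCO ordering of the 17 keypoints (worth verifying against the keypoint label map shipped with the model):

# Standard COCO ordering of the 17 human keypoints.
COCO_KEYPOINT_NAMES = [
    'nose', 'left_eye', 'right_eye', 'left_ear', 'right_ear',
    'left_shoulder', 'right_shoulder', 'left_elbow', 'right_elbow',
    'left_wrist', 'right_wrist', 'left_hip', 'right_hip',
    'left_knee', 'right_knee', 'left_ankle', 'right_ankle']

# Coordinates are (y, x) in normalized image frame, i.e. [0.0, 1.0].
person = 0  # Index of the detection to inspect.
for name, (y, x), kp_score in zip(
    COCO_KEYPOINT_NAMES, kpts[0][person], kpts_scores[0][person]):
  if kp_score > 0.3:  # Same threshold as the visualization above.
    print('%-14s y=%.3f x=%.3f (score %.2f)' % (name, y, x, kp_score))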

Running On Mobile

As mentioned earlier, both of the above models can run on mobile phones with TensorFlow Lite. See our inference documentation for general guidelines on platform-specific APIs and leveraging hardware acceleration. Both the object-detection and keypoint-detection versions of CenterNet are compatible with our GPU delegate. We are working on quantized versions of this model.
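For realistic numbers you would use the TFLite benchmark tool on an actual device, but as a rough sketch you can time interpreter.invoke() on desktop, re-using the interpreter and input set up in the inference section above (on-device latency, especially with delegates, will differ):

import time

# Rough desktop latency estimate; not representative of on-device performance.
warmup, runs = 3, 20
for _ in range(warmup):
  interpreter.invoke()
start = time.perf_counter()
for _ in range(runs):
  interpreter.invoke()
print('Mean latency: %.1f ms' % ((time.perf_counter() - start) / runs * 1e3))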

To leverage object detection in your Android app, the simplest way is to use TFLite's ObjectDetector Task API. It is a high-level API that encapsulates complex but common image pre- and post-processing logic, so inference can be done in as few as 5 lines of code. It is supported in Java for Android and C++ for native code. We are working on a Swift API for iOS, as well as support for the keypoint-detection model.

To use the Task API, the model needs to be packaged with TFLite Metadata. This metadata helps the inference code perform the correct pre- and post-processing as required by the model. Use the following code to create the metadata.

!pip install tflite_support_nightly
from tflite_support.metadata_writers import object_detector
from tflite_support.metadata_writers import writer_utils

ObjectDetectorWriter = object_detector.MetadataWriter

_MODEL_PATH = "centernet_mobilenetv2_fpn_od/model.tflite"
_SAVE_TO_PATH = "centernet_mobilenetv2_fpn_od/model_with_metadata.tflite"
_LABEL_PATH = "centernet_mobilenetv2_fpn_od/tflite_label_map.txt"

# We need to convert the Detection API's label map into what the Task API
# needs: a txt file with one class name per line, from index 0 to N.
# Index 0 is reserved for the background class.
# This code assumes COCO detection, which has 90 classes; if you re-trained
# the model, write a label map file matching your own dataset.
od_label_map_path = 'centernet_mobilenetv2_fpn_od/label_map.txt'
category_index = label_map_util.create_category_index_from_labelmap(
    od_label_map_path)
with open(_LABEL_PATH, 'w') as f:
  for class_id in range(1, 91):
    if class_id not in category_index:
      f.write('???\n')
      continue
    f.write(category_index[class_id]['name'] + '\n')

writer = ObjectDetectorWriter.create_for_inference(
    writer_utils.load_file(_MODEL_PATH), input_norm_mean=[0], 
    input_norm_std=[1], label_file_paths=[_LABEL_PATH])
writer_utils.save_file(writer.populate(), _SAVE_TO_PATH)

Visualize the metadata you just created with the following code:

from tflite_support import metadata

displayer = metadata.MetadataDisplayer.with_model_file(_SAVE_TO_PATH)
print("Metadata populated:")
print(displayer.get_metadata_json())
print("=============================")
print("Associated file(s) populated:")
print(displayer.get_packed_associated_file_list())

See our public documentation for more information about object-detection models. The Object Detection example app is a good starting point for integrating the model into your Android or iOS app, and it contains examples of using both the TFLite Task Library and the TFLite Interpreter API.