Back to Models

MobileBERT QAT Tutorial

official/projects/qat/nlp/docs/MobileBERT_QAT_tutorial.ipynb

2.20.0 · 3.8 KB
Original Source

MobileBERT QAT Tutorial

This notebook provides basic example code to build, run, and fine-tune MobileBERT with the QAT (quantization-aware training) toolkit.

The pretrained models are downloaded from TensorFlow Hub and the TensorFlow Model Garden; both are trained on the SQuAD dataset for the question-answering (Q&A) task. You will run inference on the models with dummy inputs.

Setup

python
# Install packages
# The leading `!` runs this line as a shell command (IPython/Colab cell syntax),
# so this cell is not plain Python. `-q` keeps pip's output quiet.
# NOTE(review): tensorflow_hub is imported in the next cell but not installed
# here — presumably it is preinstalled in the notebook runtime; verify.

# tf-models-official is the stable Model Garden package
# tf-models-nightly includes latest changes
!pip install -q tf-models-nightly
python
# Run imports
import os

import numpy as np
import tensorflow as tf
import tensorflow_hub as hub

Launch QAT Training

Follow the training guide to start QAT training from the pretrained checkpoint.

Running model from TFHub

Run the QAT-trained MobileBERT model from TF Hub. Note that the model contains fake-quant ops and all of its ops are float32; they become actual int8 ops when you convert the model to TFLite with the TFLite converter.

python
# Load the QAT-trained MobileBERT from TF Hub and grab its 'serving_default'
# signature, a callable that receives the model inputs as keyword tensors.
loaded_obj = hub.load("https://tfhub.dev/google/qat/nlp/mobilebert_xs_qat/1")
serving_model = loaded_obj.signatures['serving_default']

# Dummy inputs: int32 zeros of shape [1, 384] for each of the three features.
input_word_ids = tf.zeros(shape=[1, 384], dtype=tf.int32)
input_type_ids = tf.zeros_like(input_word_ids)
input_mask = tf.zeros_like(input_word_ids)

bert_inputs = {
    'input_type_ids': input_type_ids,
    'input_word_ids': input_word_ids,
    'input_mask': input_mask,
}

# The signature's result is indexed by output name.
bert_outputs = serving_model(**bert_inputs)
start_logits, end_logits = (
    bert_outputs[key] for key in ("start_logits", "end_logits"))

print(start_logits.shape)
print(end_logits.shape)

Running TFLite Model Inference

Run inference with the trained, quantized TFLite model on dummy data. We assume the input string has already been converted to integer token IDs using the vocabulary.

python
# First download the TFLite model.
! curl https://storage.googleapis.com/tf_model_garden/nlp/qat/mobilebert/model_qat.tflite --output model_qat.tflite
python
def get_dequantized_tensor(interpreter, output_detail):
  """Returns the tensor described by `output_detail` as real-valued data.

  Floating-point tensors, and tensors whose quantization scale is 0 (TFLite
  reports `(0.0, 0)` for tensors without quantization), are returned as
  stored. Quantized integer tensors are mapped back to float32 with
  `(q - zero_point) * scale`.

  Args:
    interpreter: A `tf.lite.Interpreter` that has already been invoked.
    output_detail: One dict from `interpreter.get_output_details()`; its
      'index', 'dtype' and (optional) 'quantization' entries are read.

  Returns:
    A NumPy array with the tensor contents, dequantized when applicable.
  """
  tensor = interpreter.get_tensor(output_detail['index'])
  scale, zero_point = output_detail.get('quantization', (0.0, 0))
  # Any floating dtype (not just float32) and any zero scale mean "nothing to
  # dequantize"; applying the affine formula with scale == 0 would silently
  # turn the whole tensor into zeros.
  if np.issubdtype(np.dtype(output_detail['dtype']), np.floating) or not scale:
    return tensor
  return (np.asarray(tensor, dtype=np.float32) - zero_point) * scale


def _match_details_by_name(details, name_fragments):
  """Orders `details` to follow `name_fragments`, falling back to given order.

  Each fragment must match exactly one detail's 'name' and all matches must be
  distinct tensors; otherwise the details are returned in the order the
  interpreter reported them.
  """
  picked = []
  for fragment in name_fragments:
    matches = [d for d in details if fragment in d.get('name', '')]
    if len(matches) != 1:
      return list(details)
    picked.append(matches[0])
  if len({d['index'] for d in picked}) != len(picked):
    return list(details)
  return picked


def run_tflite(interpreter, input_word_ids, input_mask, input_type_ids):
  """Runs the MobileBERT QA TFLite model and returns (start, end) logits.

  Inputs are matched to the model's input tensors by name ('word_ids',
  'input_mask', 'type_ids') because the interpreter does not promise any
  particular input order; if the names do not identify them unambiguously,
  the interpreter's reported order (word ids, mask, type ids) is assumed.
  Outputs are matched the same way on 'start_logits' / 'end_logits', falling
  back to reported order.

  Args:
    interpreter: A `tf.lite.Interpreter` with tensors already allocated.
    input_word_ids: Array fed to the word-ids input.
    input_mask: Array fed to the input-mask input.
    input_type_ids: Array fed to the type-ids input.

  Returns:
    A `(start_logits, end_logits)` tuple of NumPy arrays, dequantized via
    `get_dequantized_tensor`.
  """
  word_ids_detail, mask_detail, type_ids_detail = _match_details_by_name(
      interpreter.get_input_details(), ('word_ids', 'input_mask', 'type_ids'))
  interpreter.set_tensor(word_ids_detail['index'], input_word_ids)
  interpreter.set_tensor(mask_detail['index'], input_mask)
  interpreter.set_tensor(type_ids_detail['index'], input_type_ids)
  interpreter.invoke()

  start_logits_detail, end_logits_detail = _match_details_by_name(
      interpreter.get_output_details(), ('start_logits', 'end_logits'))

  return (get_dequantized_tensor(interpreter, start_logits_detail),
          get_dequantized_tensor(interpreter, end_logits_detail))
python
# Read the downloaded model file into memory and build an interpreter from
# those bytes. Per the TF docs, experimental_preserve_all_tensors=True keeps
# intermediate tensors readable after invoke() (useful for inspection,
# presumably at some extra memory cost).
tflite_file = 'model_qat.tflite'
with open(tflite_file, mode='rb') as model_file:
  tflite_model = model_file.read()

interpreter = tf.lite.Interpreter(
    model_content=tflite_model, experimental_preserve_all_tensors=True)
# Tensor buffers must be allocated before set_tensor()/invoke() are used.
interpreter.allocate_tensors()
python
# Dummy inputs: int32 zeros of shape [1, 384] for each of the three features.
# NOTE(review): in BERT-style models the mask is presumably 1 for real tokens
# and 0 for padding, so an all-zero mask is only suitable as a shape smoke test.
input_type_ids = np.zeros(shape=[1, 384], dtype=np.int32)
input_word_ids = np.zeros(shape=[1, 384], dtype=np.int32)
input_mask = np.zeros(shape=[1, 384], dtype=np.int32)

# Pass the arrays by keyword so each reaches the matching parameter of
# run_tflite(interpreter, input_word_ids, input_mask, input_type_ids);
# a positional call in a different order would silently swap the inputs.
start_logits, end_logits = run_tflite(
    interpreter,
    input_word_ids=input_word_ids,
    input_mask=input_mask,
    input_type_ids=input_type_ids)

print(start_logits.shape)
print(end_logits.shape)