official/projects/qat/nlp/docs/MobileBERT_QAT_tutorial.ipynb
This notebook provides basic example code to build, run, and fine-tune MobileBERT with the QAT toolkit.
The pretrained models are downloaded from TensorFlow Hub and the TensorFlow Model Garden; both are trained on the SQuAD dataset for the question-answering task. You will run inference on the models with dummy inputs.
# Install packages
# tf-models-official is the stable Model Garden package
# tf-models-nightly includes latest changes
!pip install -q tf-models-nightly
# Run imports
import os
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
Follow the training guideline to start QAT training from the pretrained checkpoint.
Run the QAT-trained MobileBERT model from TensorFlow Hub. Note that it contains fake-quant ops and all ops are float32; they become actual int8 ops once you convert the model to TFLite with the TFLite converter.
loaded_obj = hub.load("https://tfhub.dev/google/qat/nlp/mobilebert_xs_qat/1")
serving_model = loaded_obj.signatures['serving_default']
# Dummy inputs
input_type_ids = tf.zeros(shape=[1, 384], dtype=tf.int32)
input_word_ids = tf.zeros(shape=[1, 384], dtype=tf.int32)
input_mask = tf.zeros(shape=[1, 384], dtype=tf.int32)
bert_inputs = dict(
    input_type_ids=input_type_ids,
    input_word_ids=input_word_ids,
    input_mask=input_mask)
bert_outputs = serving_model(**bert_inputs)
start_logits = bert_outputs["start_logits"]
end_logits = bert_outputs["end_logits"]
print(start_logits.shape)
print(end_logits.shape)
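The model above still runs in float32 with fake-quant ops. As a minimal sketch of the conversion step (the path below is a placeholder, and this is not necessarily the exact recipe used to produce the released model), a QAT SavedModel can be turned into an int8 TFLite flatbuffer like this:
# Sketch: convert a QAT (fake-quant) SavedModel to an int8 TFLite model.
# 'path/to/qat_saved_model' is a placeholder for your exported model.
converter = tf.lite.TFLiteConverter.from_saved_model('path/to/qat_saved_model')
# DEFAULT optimization lets the converter honor the fake-quant annotations
# and emit integer kernels where possible.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()
with open('model_qat.tflite', 'wb') as f:
  f.write(tflite_model)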
Run inference with the trained, quantized TFLite model on a dummy dataset. We assume the input string has already been converted to integer token IDs using the vocabulary.
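For reference, here is a minimal, hypothetical sketch of that preprocessing; a real pipeline uses the WordPiece vocabulary shipped with the checkpoint, which is stubbed out as a tiny dict here.
# Hypothetical sketch: turn a question/context pair into model inputs.
# `vocab` is a stand-in for the real WordPiece vocabulary.
vocab = {'[PAD]': 0, '[UNK]': 1, '[CLS]': 101, '[SEP]': 102}
def encode(question, context, seq_len=384):
  tokens = ['[CLS]'] + question.split() + ['[SEP]'] + context.split() + ['[SEP]']
  ids = [vocab.get(t, vocab['[UNK]']) for t in tokens][:seq_len]
  mask = [1] * len(ids)
  # Segment 0 covers [CLS], the question, and the first [SEP]; segment 1 the rest.
  n_question = len(question.split()) + 2
  type_ids = ([0] * n_question + [1] * (len(tokens) - n_question))[:seq_len]
  pad = seq_len - len(ids)
  return (np.array([ids + [0] * pad], dtype=np.int32),
          np.array([mask + [0] * pad], dtype=np.int32),
          np.array([type_ids + [0] * pad], dtype=np.int32))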
# First download the TFLite model.
! curl https://storage.googleapis.com/tf_model_garden/nlp/qat/mobilebert/model_qat.tflite --output model_qat.tflite
def get_dequantized_tensor(interpreter, output_detail):
  # Float outputs are returned as-is.
  if ('quantization' not in output_detail or
      np.dtype(output_detail['dtype']) == np.dtype(np.float32)):
    return interpreter.get_tensor(output_detail['index'])
  # Quantized outputs are mapped back to float with
  # real_value = scale * (quantized_value - zero_point).
  output_scale, output_zero_point = output_detail['quantization']
  return (np.array(interpreter.get_tensor(output_detail['index']),
                   dtype=np.float32) - output_zero_point) * output_scale
def run_tflite(interpreter, input_word_ids, input_mask, input_type_ids):
  # Assumes the interpreter's input details are ordered as
  # (input_word_ids, input_mask, input_type_ids).
  input_word_ids_index, input_mask_index, input_type_ids_index = [
      detail['index'] for detail in interpreter.get_input_details()]
  interpreter.set_tensor(input_word_ids_index, input_word_ids)
  interpreter.set_tensor(input_mask_index, input_mask)
  interpreter.set_tensor(input_type_ids_index, input_type_ids)
  interpreter.invoke()
  # The model has two outputs: start and end logits for the answer span.
  start_logits_detail, end_logits_detail = interpreter.get_output_details()
  return (get_dequantized_tensor(interpreter, start_logits_detail),
          get_dequantized_tensor(interpreter, end_logits_detail))
tflite_file = 'model_qat.tflite'
with open(tflite_file, 'rb') as fp:
  tflite_model = fp.read()
# experimental_preserve_all_tensors keeps intermediate tensors around, which
# is useful for debugging quantization but not required for inference.
interpreter = tf.lite.Interpreter(
    model_content=tflite_model,
    experimental_preserve_all_tensors=True)
interpreter.allocate_tensors()
# Dummy inputs
input_type_ids = np.zeros(shape=[1, 384], dtype=np.int32)
input_word_ids = np.zeros(shape=[1, 384], dtype=np.int32)
input_mask = np.zeros(shape=[1, 384], dtype=np.int32)
start_logits, end_logits = run_tflite(interpreter, input_word_ids, input_mask, input_type_ids)
print(start_logits.shape)
print(end_logits.shape)
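The logits have shape (1, 384), one score per token position. As a minimal sketch (a full SQuAD decoder also enforces start <= end and limits the span length), the predicted answer span is just the argmax of each:
# Pick the most likely start and end positions of the answer span.
start_index = int(np.argmax(start_logits[0]))
end_index = int(np.argmax(end_logits[0]))
print('predicted answer span:', start_index, end_index)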