tensorflow/compiler/mlir/quantization/stablehlo/python/integration_test/stablehlo_quantizer_odml_oss.ipynb
Copyright 2024 Google LLC.
Licensed under the Apache License, Version 2.0 (the "License");
This example shows a JAX Keras reference model converted into a StableHLO module via jax2tf, then quantized in the ODML Converter via the StableHLO Quantizer.
Note: This API is experimental and will likely have breakages with other models. Please reach out to [email protected] and we will support your use case.
StableHLO Quantizer is a quantization API to enable ML framework optionality and hardware retargetability.
# Environment setup: replace the stable TensorFlow build with the nightly
# build (the `experimental_use_stablehlo_quantizer` converter flag used below
# is only available in tf-nightly), and install keras-core so ResNet50 can
# run on the JAX backend.
!pip uninstall tensorflow --yes
!pip3 install tf-nightly
!pip3 install keras-core
import tensorflow as tf

# Sanity check: confirm the nightly build installed above is the one active.
print("TensorFlow version:", tf.__version__)

import os

# Select the JAX backend for keras-core; this must be set before
# `keras_core` is imported, or the default backend is used instead.
os.environ['KERAS_BACKEND'] = 'jax'

import jax.numpy as jnp
import numpy as np
# NOTE(review): duplicate import of tensorflow (already imported above);
# harmless, kept as-is.
import tensorflow as tf
from keras_core.applications import ResNet50
from jax.experimental import jax2tf
# NHWC input: a batch of one 224x224 RGB image (ResNet50's canonical size).
input_shape = (1, 224, 224, 3)

# Convert the JAX ResNet50 forward pass into a TF-callable via jax2tf.
# native_serialization=True embeds the computation as a StableHLO module,
# which is what the StableHLO Quantizer consumes downstream; gradients are
# dropped since this flow is inference-only.
jax_callable = jax2tf.convert(
    ResNet50(
        input_shape=input_shape[1:],  # model takes per-example shape, no batch dim
        pooling='avg',
    ).call,
    with_gradient=False,
    native_serialization=True,
    native_serialization_platforms=('cpu',))

# Wrap the callable in a tf.Module with a concrete input signature so it can
# be exported as a SavedModel. The input name 'lhs_operand' must match the
# key yielded by the calibration dataset defined later in this notebook.
tf_module = tf.Module()
tf_module.f = tf.function(
    jax_callable,
    autograph=False,
    input_signature=[
        tf.TensorSpec(input_shape, jnp.float32, 'lhs_operand')
    ],
)

# Export to disk; the TFLite converter below loads from this directory.
saved_model_dir = '/tmp/saved_model'
tf.saved_model.save(tf_module, saved_model_dir)
def calibration_dataset(shape=None, num_samples=2, seed=1235):
    """Yields representative inputs for post-training quantization calibration.

    Generalized from the original hard-coded version: the input shape, number
    of samples, and RNG seed are now keyword parameters whose defaults
    reproduce the original behavior exactly, so assigning the bare function to
    `converter.representative_dataset` still works unchanged.

    Args:
      shape: Shape of each sample; defaults to the module-level `input_shape`
        used when exporting the SavedModel.
      num_samples: Number of calibration samples to yield.
      seed: RNG seed; fixed by default so calibration is reproducible.

    Yields:
      Dicts mapping the exported signature's input name 'lhs_operand' to a
      float32 array drawn uniformly from [-1.0, 1.0).
    """
    if shape is None:
        # Fall back to the notebook-level input_shape the model was exported with.
        shape = input_shape
    rng = np.random.default_rng(seed=seed)
    for _ in range(num_samples):
        yield {
            'lhs_operand': rng.uniform(low=-1.0, high=1.0, size=shape).astype(
                np.float32
            )
        }
# Convert the exported SavedModel to a quantized TFLite flatbuffer.
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.SELECT_TF_OPS,  # enable TensorFlow ops.
    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TFL ops.
]
# Calibration inputs used to collect activation statistics for quantization.
converter.representative_dataset = calibration_dataset
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# Below flag controls whether to use StableHLO Quantizer or TFLite quantizer.
converter.experimental_use_stablehlo_quantizer = True
quantized_model = converter.convert()

# Serialize the quantized model and print its size; `>> 20` converts bytes
# to mebibytes (integer division by 2**20).
with open('/tmp/resnet50_quantized.tflite', 'wb') as f:
    f.write(quantized_model)
print(str(os.path.getsize('/tmp/resnet50_quantized.tflite') >> 20) + 'MB')