tensorflow/lite/g3doc/examples/jax_conversion/jax_to_tflite_resnet50.ipynb
# @title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
The orbax-export API is used to export the JAX module to a TensorFlow SavedModel, together with image pre/post-processing functions.
!pip install orbax-export
!pip install tf-nightly
!pip install --upgrade jax jaxlib
!pip install transformers flax
from PIL import Image
import jax
import requests

# Sample image used throughout this example.
url = "https://storage.googleapis.com/download.tensorflow.org/example_images/astrid_l_shaped.jpg"
# Use a timeout and check the HTTP status so network problems fail fast instead of
# hanging or handing an error page to PIL; the context manager closes the response.
with requests.get(url, stream=True, timeout=30) as response:
    response.raise_for_status()
    image = Image.open(response.raw)
    # Image.open is lazy: decode the pixels before the response stream is closed.
    image.load()
image
from transformers import ConvNextImageProcessor, FlaxResNetForImageClassification

# Reference pipeline: the Hugging Face image processor and the pretrained Flax
# ResNet-50 classifier, both loaded from the same "microsoft/resnet-50" checkpoint.
image_processor = ConvNextImageProcessor.from_pretrained("microsoft/resnet-50")
model = FlaxResNetForImageClassification.from_pretrained("microsoft/resnet-50")

inputs = image_processor(images=image, return_tensors="np")
outputs = model(**inputs)
logits = outputs.logits

# The model scores the 1000 ImageNet classes; the highest logit is the prediction.
predicted_class_idx = logits.argmax(axis=-1)
print("Predicted class:", model.config.id2label[predicted_class_idx.item()])
A wrapper is needed so that the model's inputs comply with what TFLite accepts: TFLite accepts a tensor or a tuple of tensors.
import functools

import flax.linen as nn
from transformers import FlaxResNetForImageClassification


@functools.lru_cache(maxsize=None)
def _load_pretrained_resnet(model_name):
    """Load a pretrained Flax ResNet classifier once per checkpoint name.

    Flax runs ``setup`` every time a module is bound (e.g. on each ``apply``),
    so without this cache the checkpoint would be re-loaded on every call.
    """
    return FlaxResNetForImageClassification.from_pretrained(model_name)


class Resnet50Wrapper(nn.Module):
    """Flax module exposing a pretrained ResNet-50 as a tensor-in/tensor-out function.

    TFLite accepts a tensor or a tuple of tensors, so this wrapper takes the pixel
    values as a single positional input and returns only the logits rather than
    the Hugging Face output object.
    """

    pretrained_model_name: str = "microsoft/resnet-50"  # Hugging Face checkpoint name

    def setup(self):
        # The Hugging Face model carries its own pretrained params; this module
        # declares no Flax variables (this example applies it with ``{}``).
        self.model = _load_pretrained_resnet(self.pretrained_model_name)

    def __call__(self, inputs):
        """Return classification logits for a batch of preprocessed images.

        Args:
          inputs: pixel values forwarded as the model's ``pixel_values``. In this
            example they come from ``resnet_image_processor`` as a
            (batch, 3, 224, 224) float32 array (assumed layout — confirm against
            the checkpoint's expectations).

        Returns:
          The raw logits; no softmax is applied.
        """
        outputs = self.model(pixel_values=inputs)
        return outputs.logits
This essentially implements the underlying logic of the `ConvNextImageProcessor` class in Hugging Face transformers:
image_processor = ConvNextImageProcessor.from_pretrained("microsoft/resnet-50")
inputs = image_processor(images=image, return_tensors="np")
This utility is reused later as the `tf_preprocessor` when exporting the model with orbax-export.
Note: the output of `ConvNextImageProcessor` could equally be used to run a TFLite model, but this example showcases how orbax-export handles input preprocessing and output postprocessing.
import tensorflow as tf
import numpy as np


def resnet_image_processor(image_tensor):
    """Preprocess RGB image(s) into channel-first float32 ResNet-50 inputs.

    Approximates the Hugging Face ``ConvNextImageProcessor`` pipeline with a
    plain bilinear resize to 224x224, scaling to [0, 1] and per-channel mean/std
    normalization. NOTE(review): the HF processor may resize/crop differently;
    verify if exact numerical parity matters.

    Args:
      image_tensor: a single (H, W, 3) image or a batch of shape (N, H, W, 3);
        pixel values are presumed to be in [0, 255].

    Returns:
      A float32 tensor of shape (1, 3, 224, 224) for a single image, or
      (N, 3, 224, 224) for a batch.
    """
    # 1. Resize to the model resolution and make the dtype explicitly float32.
    image_resized = tf.image.resize(
        image_tensor, (224, 224), method=tf.image.ResizeMethod.BILINEAR
    )
    image_float = tf.cast(image_resized, tf.float32)
    # 2. Scale to [0, 1], then normalize per channel (broadcasts over the last axis).
    mean = tf.constant([0.485, 0.456, 0.406])
    std = tf.constant([0.229, 0.224, 0.225])
    image_normalized = (image_float / 255.0 - mean) / std
    # 3./4. Move channels first and ensure a leading batch axis. The rank is
    # static at trace time, so this Python branch is safe inside tf.function.
    if image_normalized.shape.rank == 4:
        return tf.transpose(image_normalized, perm=[0, 3, 1, 2])
    image_transposed = tf.transpose(image_normalized, perm=[2, 0, 1])
    return tf.expand_dims(image_transposed, axis=0)
# Initialize the JAX model wrapper (applied below with an empty variable dict).
jax_model = Resnet50Wrapper()
# Convert the PIL image to an (H, W, 3) float32 tensor. convert("RGB") guarantees
# exactly three channels even for grayscale/RGBA/palette images, which the
# 3-channel normalization in the preprocessing relies on.
raw_image_tensor = tf.convert_to_tensor(
    np.array(image.convert("RGB"), dtype=np.float32)
)
# Apply the TF image preprocessing above to get an input tensor supported by ResNet-50.
input_tensor = resnet_image_processor(raw_image_tensor)
# Run the JAX model eagerly and decode the top-scoring class.
jax_logits = jax_model.apply({}, input_tensor.numpy())
jax_predicted_class_idx = jax.numpy.argmax(jax_logits, axis=-1)
print("Predicted class:", model.config.id2label[jax_predicted_class_idx.item()])
raw_image_tensor.shape
from orbax.export import ExportManager, JaxModule, ServingConfig

# Wrap the (empty) model variables and the apply function into a JaxModule.
jax_module = JaxModule({}, jax_model.apply, trainable=False)
# Serving signature: raw (H, W, 3) float32 image in, predicted class index out.
# The input shape is taken from the example image instead of being hard-coded,
# so the exported signature always matches the tensor later fed to TFLite.
serving_config = ServingConfig(
    "serving_default",
    input_signature=[
        tf.TensorSpec(raw_image_tensor.shape, tf.float32, name="inputs")
    ],
    tf_preprocessor=resnet_image_processor,
    tf_postprocessor=lambda x: tf.argmax(x, axis=-1),
)
export_manager = ExportManager(jax_module, [serving_config])
saved_model_dir = "resnet50_saved_model"
export_manager.save(saved_model_dir)

# Convert the exported SavedModel (pre/post-processing included) to TFLite.
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()
raw_image_tensor  # Notebook display of the raw input tensor.


def run_tflite_model(tflite_model_content, input_tensor):
    """Run a single-input TFLite model once and return its first output.

    Args:
      tflite_model_content: serialized TFLite model, e.g. the bytes returned by
        ``TFLiteConverter.convert()``.
      input_tensor: array-like (NumPy array or eager ``tf.Tensor``) whose shape
        matches the model's first input.

    Returns:
      The value of the model's first output tensor.
    """
    interpreter = tf.lite.Interpreter(model_content=tflite_model_content)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()[0]
    # Hand the interpreter a NumPy array of exactly the dtype the model declares.
    interpreter.set_tensor(
        input_details["index"],
        np.asarray(input_tensor, dtype=input_details["dtype"]),
    )
    interpreter.invoke()
    output_details = interpreter.get_output_details()[0]
    return interpreter.get_tensor(output_details["index"])
# The exported model already contains preprocessing and argmax, so it is fed the
# raw image and returns class indices directly.
output_data = run_tflite_model(tflite_model, raw_image_tensor)
print("Predicted class:", model.config.id2label[int(output_data[0])])
# Alternative path: save only the JAX computation (no pre/post-processing) as a
# plain SavedModel whose signature takes an already-preprocessed NCHW tensor.
saved_model_dir_2 = "resnet50_saved_model_1"
tf.saved_model.save(
    obj=jax_module,
    export_dir=saved_model_dir_2,
    signatures=jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
        tf.TensorSpec([1, 3, 224, 224], tf.float32, name="inputs")
    ),
    options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),
)

converter_1 = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir_2)
tflite_model_1 = converter_1.convert()
# This model outputs logits, so the argmax is taken here on the host side.
output_data_1 = run_tflite_model(tflite_model_1, input_tensor)
tfl_predicted_class_idx_1 = np.argmax(output_data_1, axis=-1)
print("Predicted class:", model.config.id2label[tfl_predicted_class_idx_1[0]])
# Alternative path: convert the concrete function directly, skipping SavedModel.
# The owning module is passed as ``trackable_obj``; omitting it is deprecated
# and makes the converter fall back to a legacy conversion path.
converter_2 = tf.lite.TFLiteConverter.from_concrete_functions(
    [
        jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
            tf.TensorSpec([1, 3, 224, 224], tf.float32, name="inputs")
        )
    ],
    trackable_obj=jax_module,
)
tflite_model_2 = converter_2.convert()
# The converted function outputs logits; take the argmax to get the class index.
output_data_2 = run_tflite_model(tflite_model_2, input_tensor)
tfl_predicted_class_idx_2 = tf.argmax(output_data_2, axis=-1).numpy()
print("Predicted class:", model.config.id2label[tfl_predicted_class_idx_2[0]])