# tensorflow/lite/g3doc/examples/keras/keras_jax_backend_to_tfl.ipynb
# @title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
os.environ["KERAS_BACKEND"] = "jax"
import keras
import tensorflow as tf
import numpy as np
from PIL import Image
import requests
# Download the sample image that is classified throughout this example.
url = "https://storage.googleapis.com/download.tensorflow.org/example_images/astrid_l_shaped.jpg"
# Bound the request time and fail on HTTP errors instead of handing an error
# page to PIL; the context manager releases the streamed connection.
with requests.get(url, stream=True, timeout=30) as response:
    response.raise_for_status()
    # Let urllib3 undo any Content-Encoding so PIL receives the image bytes.
    response.raw.decode_content = True
    # convert() decodes the image while the connection is still open and
    # guarantees three channels, matching the (224, 224, 3) input used below.
    image = Image.open(response.raw).convert("RGB")
image = image.resize((224, 224))
# Add a leading batch axis: (224, 224, 3) -> (1, 224, 224, 3).
input_image = np.expand_dims(np.array(image), axis=0)
# Build ResNet50 (classification head included) with ImageNet weights.
jax_model = keras.applications.resnet.ResNet50(weights="imagenet", include_top=True)
# Apply the ResNet50 input preprocessing to the batched image.
input_data = keras.applications.resnet50.preprocess_input(input_image)
jax_model_output = jax_model(input_data)
# decode_predictions yields, per batch item, a list of prediction entries;
# keep the single top entry for the only image in the batch.
top_predictions = keras.applications.resnet.decode_predictions(jax_model_output, top=1)
decoded_preds = top_predictions[0][0]
print("Predicted class:", decoded_preds[1])
# Export the Keras model to a local directory, then feed that directory to the
# TFLite converter as a SavedModel.
saved_model_dir = "resnet50_saved_model"
jax_model.export(saved_model_dir)
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
# The converted model is kept in memory and passed to the interpreter below
# via `model_content`; it is not written to disk here.
tflite_model = converter.convert()
# Run the converted model with the TFLite interpreter on the same input that
# was given to the Keras model, so the two predictions can be compared.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()[0]
# set_tensor() rejects arrays whose dtype differs from the tensor's dtype, so
# cast to the dtype the interpreter reports rather than relying on the
# preprocessing output happening to match.
interpreter.set_tensor(
    input_details["index"],
    np.asarray(input_data, dtype=input_details["dtype"]),
)
interpreter.invoke()
output_details = interpreter.get_output_details()
output_data = interpreter.get_tensor(output_details[0]["index"])
# NOTE: despite its name, this holds the whole top-1 entry returned by
# decode_predictions (not an index); element [1] is what gets printed as the
# predicted class.
tfl_predicted_class_idx = keras.applications.resnet.decode_predictions(
    output_data, top=1
)[0][0]
print("Predicted class:", tfl_predicted_class_idx[1])