Copyright 2024 The TensorFlow Authors.
# @title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Converting Keras to TFLite (via the JAX backend)

<table class="tfo-notebook-buttons" align="left"> <td> <a target="_blank" href="https://www.tensorflow.org/lite/examples/keras/keras_jax_backend_to_tfl">View on TensorFlow.org</a> </td> <td> <a target="_blank" href="https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/examples/keras/keras_jax_backend_to_tfl.ipynb">Run in Google Colab</a> </td> <td> <a target="_blank" href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/examples/keras/keras_jax_backend_to_tfl.ipynb">View source on GitHub</a> </td> <td> <a href="https://storage.googleapis.com/tensorflow_docs/tensorflow/tensorflow/lite/g3doc/examples/keras/keras_jax_backend_to_tfl.ipynb">Download notebook</a> </td> </table>
import os

# Keras reads KERAS_BACKEND when it is first imported, so set it before `import keras`.
os.environ["KERAS_BACKEND"] = "jax"

Setup

import keras
import tensorflow as tf
import numpy as np

Get the test image data

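from PIL import Image
import requests

url = "https://storage.googleapis.com/download.tensorflow.org/example_images/astrid_l_shaped.jpg"
response = requests.get(url, stream=True)
response.raise_for_status()  # fail early on a bad download
# convert("RGB") guarantees three channels; ResNet50 expects 224x224 RGB input.
image = Image.open(response.raw).convert("RGB").resize((224, 224))
input_image = np.array(image)  # shape (224, 224, 3), dtype uint8
input_image = np.expand_dims(input_image, axis=0)  # add batch axis: (1, 224, 224, 3)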

Instantiate a ResNet50 model from Keras Applications

# include_top=True keeps the 1,000-class ImageNet classification head.
jax_model = keras.applications.resnet.ResNet50(include_top=True, weights="imagenet")

Run the Keras model (JAX backend) on the test input

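# ResNet50 preprocessing: RGB -> BGR, then per-channel ImageNet mean subtraction.
input_data = keras.applications.resnet50.preprocess_input(input_image)
jax_model_output = jax_model(input_data)

# decode_predictions returns, per image, a list of (class_id, class_name, score) tuples.
top_pred = keras.applications.resnet.decode_predictions(jax_model_output, top=1)[0][0]
print("Predicted class:", top_pred[1])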

Export the Keras model as a TensorFlow SavedModel

saved_model_dir = "resnet50_saved_model"
# Model.export() writes an inference-only TensorFlow SavedModel.
jax_model.export(saved_model_dir)

Convert the SavedModel to a TFLite model

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
# convert() returns the serialized TFLite FlatBuffer as bytes, held in memory.
tflite_model = converter.convert()
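`converter.convert()` returns the model as a `bytes` object, and nothing is written to disk until you save it yourself. Below is a minimal sketch for saving it, with a cheap sanity check first. TFLite models are FlatBuffers whose schema declares the file identifier `TFL3`, which occupies bytes 4 through 7 of the buffer (right after the 4-byte root offset). `looks_like_tflite` and `save_tflite` are hypothetical helpers. The check only inspects the header and does not validate the model itself.

```python
def looks_like_tflite(model_bytes: bytes) -> bool:
    """Hypothetical helper: True if the buffer carries the TFLite identifier b"TFL3"."""
    # A FlatBuffer starts with a 4-byte root offset followed by the 4-byte file identifier.
    return len(model_bytes) >= 8 and model_bytes[4:8] == b"TFL3"


def save_tflite(model_bytes: bytes, path: str) -> None:
    """Hypothetical helper: write converter output to `path` after a header check."""
    if not looks_like_tflite(model_bytes):
        raise ValueError("buffer does not look like a TFLite FlatBuffer")
    with open(path, "wb") as f:
        f.write(model_bytes)
```

For example, `save_tflite(tflite_model, "resnet50.tflite")` writes the converted model; the filename is arbitrary.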

Run the model with the TFLite interpreter

interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()

# Feed the same preprocessed float32 batch that was given to the Keras model.
input_details = interpreter.get_input_details()[0]
interpreter.set_tensor(input_details["index"], input_data)
interpreter.invoke()

output_details = interpreter.get_output_details()[0]
output_data = interpreter.get_tensor(output_details["index"])
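The model converted here has a float32 input, so `input_data` (already float32 after `preprocess_input`) can be passed to `set_tensor` unchanged. If the model is later quantized with integer inputs, values must instead be mapped using the tensor's scale and zero point, where real = scale * (q - zero_point). The sketch below handles both cases. `prepare_input` is a hypothetical helper. It assumes `details` is one entry of `interpreter.get_input_details()` and reads only its `"shape"`, `"dtype"` and `"quantization"` keys.

```python
import numpy as np


def prepare_input(details: dict, x: np.ndarray) -> np.ndarray:
    """Hypothetical helper: match `x` to the TFLite input tensor described by `details`."""
    expected_shape = tuple(int(d) for d in details["shape"])
    if tuple(x.shape) != expected_shape:
        raise ValueError(f"expected shape {expected_shape}, got {tuple(x.shape)}")
    dtype = np.dtype(details["dtype"])
    scale, zero_point = details["quantization"]
    if np.issubdtype(dtype, np.integer) and scale != 0:
        # Quantized tensor: real = scale * (q - zero_point), so q = x / scale + zero_point.
        info = np.iinfo(dtype)
        q = np.round(np.asarray(x, dtype=np.float64) / scale + zero_point)
        return np.clip(q, info.min, info.max).astype(dtype)
    # Float tensor (as in the float32 model converted above): a plain cast is enough.
    return np.asarray(x).astype(dtype)
```

With the interpreter above, call `interpreter.set_tensor(input_details["index"], prepare_input(input_details, input_data))`.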

# Decode the TFLite output the same way as the Keras output: (class_id, class_name, score).
tfl_top_pred = keras.applications.resnet.decode_predictions(output_data, top=1)[0][0]
print("Predicted class:", tfl_top_pred[1])