tensorflow/lite/g3doc/inference_with_metadata/task_library/image_classifier.md
Image classification is a common use of machine learning: predicting what an image represents, such as the type of animal that appears in a given picture. An image classifier is trained to recognize a fixed set of image classes. For example, a model might be trained to recognize photos of three different types of animals: rabbits, hamsters, and dogs. See the image classification overview for more information about image classifiers.
Use the Task Library ImageClassifier API to deploy your custom image
classifiers or pretrained ones into your mobile apps.
Key features of the ImageClassifier API:
Input image processing, including rotation, resizing, and color space conversion.
Region of interest of the input image.
Label map locale.
Score threshold to filter results.
Top-k classification results.
Label allowlist and denylist.
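The score threshold, top-k, and allowlist/denylist features can be pictured as a post-processing pass over the raw scores. The following is a minimal sketch in plain Python, not the Task Library's implementation: `Category` and `filter_categories` are hypothetical names, the `>=` comparison against the threshold is an assumption, and treating the allowlist and denylist as mutually exclusive is a choice made for this sketch.

```python
from typing import Iterable, List, NamedTuple, Optional, Set


class Category(NamedTuple):
    """Hypothetical stand-in for one classification result."""
    label: str
    score: float


def filter_categories(
    categories: Iterable[Category],
    score_threshold: float = 0.0,
    max_results: Optional[int] = None,
    allowlist: Optional[Set[str]] = None,
    denylist: Optional[Set[str]] = None,
) -> List[Category]:
    """Apply label filters and a score threshold, then keep the top-k by score."""
    # Assumption of this sketch: only one of the two label lists may be set.
    if allowlist and denylist:
        raise ValueError("allowlist and denylist are mutually exclusive")
    kept = [
        c for c in categories
        if c.score >= score_threshold
        and (not allowlist or c.label in allowlist)
        and (not denylist or c.label not in denylist)
    ]
    # Highest score first, so slicing yields the top-k results.
    kept.sort(key=lambda c: c.score, reverse=True)
    return kept if max_results is None else kept[:max_results]


# Example scores for the three animal classes mentioned earlier.
scores = [Category("rabbit", 0.70), Category("hamster", 0.20), Category("dog", 0.10)]
top = filter_categories(scores, score_threshold=0.15, max_results=1)
print(top)
```

Here `top` holds only the rabbit entry: the dog score falls below the threshold, and `max_results=1` drops the hamster.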
The following models are guaranteed to be compatible with the ImageClassifier
API.
Models created by TensorFlow Lite Model Maker for Image Classification.
The pretrained image classification models on TensorFlow Hub.
Models created by AutoML Vision Edge Image Classification.
Custom models that meet the model compatibility requirements.
See the Image Classification reference app for an example of how to use ImageClassifier in an Android app.
Copy the .tflite model file to the assets directory of the Android module
where the model will be run. Specify that the file should not be compressed, and
add the TensorFlow Lite library to the module’s build.gradle file:
android {
    // Other settings
    // Specify tflite file should not be compressed for the app apk
    aaptOptions {
        noCompress "tflite"
    }
}
dependencies {
    // Other dependencies
    // Import the Task Vision Library dependency (NNAPI is included)
    implementation 'org.tensorflow:tensorflow-lite-task-vision'
    // Import the GPU delegate plugin Library for GPU inference
    implementation 'org.tensorflow:tensorflow-lite-gpu-delegate-plugin'
}
// Initialization
ImageClassifierOptions options =
    ImageClassifierOptions.builder()
        .setBaseOptions(BaseOptions.builder().useGpu().build())
        .setMaxResults(1)
        .build();
ImageClassifier imageClassifier =
    ImageClassifier.createFromFileAndOptions(
        context, modelFile, options);
// Run inference
List<Classifications> results = imageClassifier.classify(image);
See the source code and javadoc for more options to configure ImageClassifier.
The Task Library supports installation using CocoaPods. Make sure that CocoaPods is installed on your system. Please see the CocoaPods installation guide for instructions.
Please see the CocoaPods guide for details on adding pods to an Xcode project.
Add the TensorFlowLiteTaskVision pod in the Podfile.
target 'MyAppWithTaskAPI' do
  use_frameworks!
  pod 'TensorFlowLiteTaskVision'
end
Make sure that the .tflite model you will be using for inference is present in
your app bundle.
// Imports
import TensorFlowLiteTaskVision
// Initialization
guard let modelPath = Bundle.main.path(forResource: "birds_V1",
                                       ofType: "tflite") else { return }
let options = ImageClassifierOptions(modelPath: modelPath)
// Configure any additional options:
// options.classificationOptions.maxResults = 3
let classifier = try ImageClassifier.classifier(options: options)
// Convert the input image to MLImage.
// There are other sources for MLImage. For more details, please see:
// https://developers.google.com/ml-kit/reference/ios/mlimage/api/reference/Classes/GMLImage
guard let image = UIImage(named: "sparrow.jpg"),
      let mlImage = MLImage(image: image) else { return }
// Run inference
let classificationResults = try classifier.classify(mlImage: mlImage)
// Imports
#import <TensorFlowLiteTaskVision/TensorFlowLiteTaskVision.h>
// Initialization
NSString *modelPath = [[NSBundle mainBundle] pathForResource:@"birds_V1" ofType:@"tflite"];
TFLImageClassifierOptions *options =
    [[TFLImageClassifierOptions alloc] initWithModelPath:modelPath];
// Configure any additional options:
// options.classificationOptions.maxResults = 3;
TFLImageClassifier *classifier = [TFLImageClassifier imageClassifierWithOptions:options
                                                                          error:nil];
// Convert the input image to GMLImage.
// There are other sources for GMLImage. For more details, please see:
// https://developers.google.com/ml-kit/reference/ios/mlimage/api/reference/Classes/GMLImage
UIImage *image = [UIImage imageNamed:@"sparrow.jpg"];
GMLImage *gmlImage = [[GMLImage alloc] initWithImage:image];
// Run inference
TFLClassificationResult *classificationResult =
    [classifier classifyWithGMLImage:gmlImage error:nil];
See the source code for more options to configure TFLImageClassifier.
pip install tflite-support
# Imports
from tflite_support.task import vision
from tflite_support.task import core
from tflite_support.task import processor
# Initialization
base_options = core.BaseOptions(file_name=model_path)
classification_options = processor.ClassificationOptions(max_results=2)
options = vision.ImageClassifierOptions(
    base_options=base_options,
    classification_options=classification_options)
classifier = vision.ImageClassifier.create_from_options(options)
# Alternatively, you can create an image classifier in the following manner:
# classifier = vision.ImageClassifier.create_from_file(model_path)
# Run inference
image = vision.TensorImage.create_from_file(image_path)
classification_result = classifier.classify(image)
See the source code for more options to configure ImageClassifier.
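Once `classify` returns, the result still has to be unpacked. The sketch below shows one way to do that, under the assumption that the result exposes `classifications[i].categories` and that each category carries `index`, `score`, `category_name`, and `display_name` attributes; check these names against your installed `tflite-support` version. `summarize` is a hypothetical helper, and the `SimpleNamespace` object is only a stand-in so the example runs without a model.

```python
from types import SimpleNamespace


def summarize(classification_result):
    """Return (index, score, name) tuples for the first classification head.

    Assumes the result exposes `classifications[i].categories`, each with
    `index`, `score`, `category_name`, and `display_name` attributes.
    """
    if not classification_result.classifications:
        return []
    rows = []
    for category in classification_result.classifications[0].categories:
        # Prefer the display name, then the class name, then the bare index.
        name = category.display_name or category.category_name or str(category.index)
        rows.append((category.index, category.score, name))
    return rows


# Stand-in object with the assumed shape, so the sketch runs without a model.
fake_result = SimpleNamespace(classifications=[SimpleNamespace(categories=[
    SimpleNamespace(index=671, score=0.91406, category_name="/m/01bwb9",
                    display_name="Passer domesticus"),
    SimpleNamespace(index=670, score=0.00391, category_name="/m/01bwbt",
                    display_name=""),
])])
rows = summarize(fake_result)
for index, score, name in rows:
    print(f"{index}: {name} ({score:.5f})")
```

With a real model, pass the object returned by `classifier.classify(image)` in place of `fake_result`.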
// Initialization
ImageClassifierOptions options;
options.mutable_base_options()->mutable_model_file()->set_file_name(model_path);
std::unique_ptr<ImageClassifier> image_classifier =
    ImageClassifier::CreateFromOptions(options).value();
// Create input frame_buffer from your inputs, `image_data` and `image_dimension`.
// See more information here: tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h
std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer(
    image_data, image_dimension);
// Run inference
const ClassificationResult result = image_classifier->Classify(*frame_buffer).value();
See the source code for more options to configure ImageClassifier.
Here is an example of the classification results of a bird classifier.
Results:
Rank #0:
index : 671
score : 0.91406
class name : /m/01bwb9
display name: Passer domesticus
Rank #1:
index : 670
score : 0.00391
class name : /m/01bwbt
display name: Passer montanus
Rank #2:
index : 495
score : 0.00391
class name : /m/0bwm6m
display name: Passer italiae
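If you need these results in structured form, for example to compare runs, the text layout above is simple to parse. This is a rough sketch that assumes exactly the layout shown (a `Rank #k:` line followed by `key : value` lines); `parse_results` is a hypothetical helper, and real tool output may differ in spacing or fields.

```python
def parse_results(text):
    """Parse 'Rank #k:' blocks of 'key : value' lines into a list of dicts."""
    converters = {"index": int, "score": float}
    records = []
    for line in text.splitlines():
        key, sep, value = line.partition(":")
        key, value = key.strip(), value.strip()
        if not sep:
            continue  # Not a 'key : value' line.
        if key.startswith("Rank #"):
            # A new rank header starts a new record.
            records.append({"rank": int(key[len("Rank #"):])})
        elif records and value:
            name = key.replace(" ", "_")
            records[-1][name] = converters.get(name, str)(value)
    return records


sample = """Results:
  Rank #0:
   index       : 671
   score       : 0.91406
   class name  : /m/01bwb9
   display name: Passer domesticus
  Rank #1:
   index       : 670
   score       : 0.00391
   class name  : /m/01bwbt
   display name: Passer montanus
"""
parsed = parse_results(sample)
print(parsed[0]["display_name"], parsed[0]["score"])
```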
Try out the simple CLI demo tool for ImageClassifier with your own model and test data.
The ImageClassifier API expects a TFLite model with mandatory
TFLite Model Metadata.
See examples of creating metadata for image classifiers using the
TensorFlow Lite Metadata Writer API.
Compatible image classifier models must meet the following requirements:
Input image tensor (kTfLiteUInt8/kTfLiteFloat32)
    Image input of size [batch x height x width x channels].
    Batch inference is not supported (batch is required to be 1).
    Only RGB inputs are supported (channels is required to be 3).
Output score tensor (kTfLiteUInt8/kTfLiteFloat32)
    With N classes and either 2 or 4 dimensions, i.e. [1 x N] or [1 x 1 x 1 x N].
    Optional (but recommended) label map(s) as AssociatedFile-s with type TENSOR_AXIS_LABELS, containing one label per line. The first such AssociatedFile (if any) is used to fill the label field (named as class_name in C++) of the results. The display_name field is filled from the AssociatedFile (if any) whose locale matches the display_names_locale field of the ImageClassifierOptions used at creation time ("en" by default, i.e. English). If none of these are available, only the index field of the results will be filled.
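The shape requirements above can be checked before shipping a model. The following is a minimal sketch in plain Python that checks only tensor shapes (not tensor types or metadata); `check_shapes` is a hypothetical helper, and you would obtain the two shapes from whatever tool you use to inspect the model.

```python
def check_shapes(input_shape, output_shape):
    """Return a list of problems; an empty list means the shapes look compatible."""
    problems = []

    # Input must be [batch x height x width x channels] with batch 1 and 3 channels.
    if len(input_shape) != 4:
        problems.append("input must be 4-D [batch x height x width x channels]")
    else:
        batch, _, _, channels = input_shape
        if batch != 1:
            problems.append("batch must be 1")
        if channels != 3:
            problems.append("channels must be 3")

    # Output must be [1 x N] or [1 x 1 x 1 x N]: every leading dimension is 1.
    if len(output_shape) not in (2, 4):
        problems.append("output must have 2 or 4 dimensions")
    elif any(dim != 1 for dim in output_shape[:-1]):
        problems.append("output must be [1 x N] or [1 x 1 x 1 x N]")

    return problems


print(check_shapes((1, 224, 224, 3), (1, 1001)))
print(check_shapes((2, 224, 224, 1), (1, 1, 1, 965)))
```

The first call reports no problems; the second reports the batch and channel violations.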