The Task Library's NLClassifier API classifies input text into different
categories, and is a versatile and configurable API that can handle most text
classification models.
Takes a single string as input, classifies it, and outputs <Label, Score> pairs as the classification results.
Optional regex tokenization of the input text.
Configurable to adapt to different classification models.
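To illustrate what regex tokenization of the input text involves, here is a minimal Python sketch. It is not the library's implementation. It assumes a delimiter pattern of `[^\w']+`, a toy vocabulary, and `<START>`, `<UNKNOWN>` and `<PAD>` special tokens. The function name `regex_tokenize`, the `max_len` parameter, and all token ids are hypothetical.

```python
import re

# Hypothetical special tokens, assumed to be defined in the tokenizer's vocabulary.
START, UNKNOWN, PAD = "<START>", "<UNKNOWN>", "<PAD>"


def regex_tokenize(text, vocab, delim_pattern=r"[^\w']+", max_len=8):
    """Split `text` on a delimiter regex and map each token to an int32 id.

    Tokens missing from `vocab` map to the <UNKNOWN> id. The sequence starts
    with <START> (if the vocabulary defines it) and is truncated or padded
    with <PAD> to exactly `max_len` ids.
    """
    # re.split leaves empty strings at the edges (e.g. after a final "."), so drop them.
    tokens = [t for t in re.split(delim_pattern, text) if t]
    ids = [vocab[START]] if START in vocab else []
    unknown_id = vocab.get(UNKNOWN, 0)
    ids.extend(vocab.get(t, unknown_id) for t in tokens)
    ids = ids[:max_len]
    ids.extend([vocab.get(PAD, 0)] * (max_len - len(ids)))
    return ids


# Toy vocabulary for illustration only (hypothetical ids).
example_vocab = {PAD: 0, START: 1, UNKNOWN: 2, "what": 3, "a": 4,
                 "waste": 5, "of": 6, "time": 7}
print(regex_tokenize("what a waste of my time.", example_vocab))
# prints [1, 3, 4, 5, 6, 2, 7, 0]  ("my" is out of vocabulary)
```

The fixed-length, padded output reflects the fact that a kTfLiteInt32 input tensor has a fixed shape, so the token ids have to fit it exactly.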
The following models are guaranteed to be compatible with the NLClassifier
API.
The <a href="../../examples/text_classification/overview">movie review sentiment classification</a> model.
Models with the average_word_vec spec created by
TensorFlow Lite Model Maker for text classification.
Custom models that meet the model compatibility requirements.
See the
Text Classification reference app
for an example of how to use NLClassifier in an Android app.
Copy the .tflite model file to the assets directory of the Android module
where the model will be run. Specify that the file should not be compressed, and
add the TensorFlow Lite library to the module’s build.gradle file:
android {
// Other settings
// Specify tflite file should not be compressed for the app apk
aaptOptions {
noCompress "tflite"
}
}
dependencies {
// Other dependencies
// Import the Task Text Library dependency (NNAPI is included)
implementation 'org.tensorflow:tensorflow-lite-task-text:0.4.4'
// Import the GPU delegate plugin Library for GPU inference
implementation 'org.tensorflow:tensorflow-lite-gpu-delegate-plugin:0.4.4'
}
Note: starting from version 4.1 of the Android Gradle plugin, .tflite is added to the noCompress list by default, so the aaptOptions block above is no longer needed.
// Initialization, use NLClassifierOptions to configure input and output tensors
NLClassifierOptions options =
NLClassifierOptions.builder()
.setBaseOptions(BaseOptions.builder().useGpu().build())
.setInputTensorName(INPUT_TENSOR_NAME)
.setOutputScoreTensorName(OUTPUT_SCORE_TENSOR_NAME)
.build();
NLClassifier classifier =
NLClassifier.createFromFileAndOptions(context, modelFile, options);
// Run inference
List<Category> results = classifier.classify(input);
See the
source code
for more options to configure NLClassifier.
Add the TensorFlowLiteTaskText pod to your Podfile:
target 'MySwiftAppWithTaskAPI' do
use_frameworks!
pod 'TensorFlowLiteTaskText', '~> 0.4.4'
end
// Initialization
let modelOptions = TFLNLClassifierOptions()
modelOptions.inputTensorName = inputTensorName
modelOptions.outputScoreTensorName = outputScoreTensorName
let nlClassifier = TFLNLClassifier.nlClassifier(
modelPath: modelPath,
options: modelOptions)
// Run inference
let categories = nlClassifier.classify(text: input)
See the source code for more details.
// Initialization
NLClassifierOptions options;
options.mutable_base_options()->mutable_model_file()->set_file_name(model_path);
std::unique_ptr<NLClassifier> classifier = NLClassifier::CreateFromOptions(options).value();
// Run inference with your input, `input_text`.
std::vector<core::Category> categories = classifier->Classify(input_text);
See the source code for more details.
pip install tflite-support
# Imports
from tflite_support.task import text
# Initialization
classifier = text.NLClassifier.create_from_file(model_path)
# Run inference with your input, `input_text` (a name that does not shadow the imported `text` module)
text_classification_result = classifier.classify(input_text)
See the
source code
for more options to configure NLClassifier.
Here is an example of the classification results of the movie review model.
Input: "What a waste of my time."
Output:
category[0]: 'Negative' : '0.81313'
category[1]: 'Positive' : '0.18687'
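The API returns a score for every category rather than a single answer. The minimal Python sketch below shows how a caller might pick the predicted label and print results in the format above. It assumes the results are available as (label, score) tuples, standing in for each platform's Category type. `top_category` and `format_categories` are hypothetical helpers, not part of the Task Library.

```python
def top_category(categories):
    """Return the (label, score) pair with the highest score."""
    return max(categories, key=lambda c: c[1])


def format_categories(categories):
    """Render (label, score) pairs in the style of the sample output."""
    return "\n".join(
        f"category[{i}]: '{label}' : '{score}'"
        for i, (label, score) in enumerate(categories)
    )


# Scores taken from the movie review example above.
results = [("Negative", 0.81313), ("Positive", 0.18687)]
print(format_categories(results))
print(top_category(results))  # prints ('Negative', 0.81313)
```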
Try out the simple CLI demo tool for NLClassifier with your own model and test data.
Depending on the use case, the NLClassifier API can load a TFLite model with
or without TFLite Model Metadata. See examples
of creating metadata for natural language classifiers using the
TensorFlow Lite Metadata Writer API.
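The compatible models should meet the following requirements:
Input tensor: (kTfLiteString/kTfLiteInt32)
The input is either the raw input string (kTfLiteString) or the regex-tokenized indices of the raw input string (kTfLiteInt32).
For kTfLiteInt32 input, a RegexTokenizer needs to be set up in the input tensor's Metadata.
Output score tensor: (kTfLiteUInt8/kTfLiteInt8/kTfLiteInt16/kTfLiteFloat32/kTfLiteFloat64)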
Mandatory output tensor for the score of each category classified.
If the type is one of the integer types, the scores are dequantized to double/float on the corresponding platforms.
It can have an optional associated file in the output tensor's corresponding Metadata for category labels. The file should be a plain text file with one label per line, and the number of labels should match the number of categories the model outputs. See the example label file.
Output label tensor: (kTfLiteString/kTfLiteInt32)
Optional output tensor for the label of each category; it should have the same length as the output score tensor. If this tensor is not present, the API uses score indices as class names.
This tensor is ignored if an associated label file is present in the output score tensor's Metadata.
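The minimal Python sketch below shows how the output tensors described above can be turned into <Label, Score> pairs. It assumes standard TFLite affine dequantization (`real = scale * (q - zero_point)`) for integer score tensors, and the label priority stated in the requirements: label file first, then label tensor, then indices. The helper names, the plain Python lists standing in for tensors, and the use of stringified indices as fallback class names are assumptions for illustration, not the library's code.

```python
def dequantize(raw_scores, scale, zero_point):
    """Affine dequantization used by quantized TFLite tensors."""
    return [scale * (q - zero_point) for q in raw_scores]


def build_categories(scores, label_file_text=None, label_tensor=None):
    """Pair each score with a class name.

    Priority: labels from the label file associated with the score tensor,
    then the optional output label tensor, then the score index itself.
    """
    if label_file_text is not None:
        # Plain text, one label per line.
        labels = [line.strip() for line in label_file_text.splitlines()]
        if len(labels) != len(scores):
            raise ValueError("label file must list one label per output category")
    elif label_tensor is not None:
        if len(label_tensor) != len(scores):
            raise ValueError("label tensor must match the score tensor length")
        # kTfLiteString labels arrive as bytes here; kTfLiteInt32 labels as ints.
        labels = [l.decode("utf-8") if isinstance(l, bytes) else str(l)
                  for l in label_tensor]
    else:
        labels = [str(i) for i in range(len(scores))]
    return list(zip(labels, scores))


# Hypothetical uint8 scores with scale 1/256 and zero point 0.
scores = dequantize([208, 48], 1 / 256, 0)
print(build_categories(scores, label_file_text="Negative\nPositive\n"))
# prints [('Negative', 0.8125), ('Positive', 0.1875)]
```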