Prepping MobileCLIP model for use in Ente

infra/ml/playground/CLIP/mobileclip_onnx.ipynb

Paper | GitHub

Setting up PyTorch weights and source code

python
# !mkdir mobileclip_repo
# %cd mobileclip_repo
# !git clone https://github.com/apple/ml-mobileclip.git
# %cd ml-mobileclip
python
%cd mobileclip_repo/ml-mobileclip/
python
# !source get_pretrained_models.sh   # Files will be downloaded to `checkpoints` directory.
# %cd ../..

Imports

python
!uv pip install "clip-benchmark>=1.4.0" "datasets>=2.8.0" "open-clip-torch>=2.20.0" "timm>=0.9.5"
python
import torch
import torch.onnx
import torchvision
import torch.nn as nn
from PIL import Image
import mobileclip
import numpy as np
from numpy.linalg import norm
import onnx
import onnxruntime as ort
print(ort.__version__)
python
model, _, preprocess = mobileclip.create_model_and_transforms('mobileclip_s2', pretrained='checkpoints/mobileclip_s2.pt')
og_model = model
model.eval()
og_model.eval()
tokenizer = mobileclip.get_tokenizer('mobileclip_s2')

image = preprocess(Image.open("docs/fig_accuracy_latency.png").convert('RGB')).unsqueeze(0)
text = tokenizer(["Hello World!", "a diagram", "a dog", "a cat"])

with torch.no_grad(), torch.cuda.amp.autocast():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

print("Label probs:", text_probs)
python
%cd ../..
python
# !rm -rf mobileclip_repo
python
tokenizer(["This is a tokenized string"])
python
text_input = tokenizer(["Hello World! This is a super duper long piece of text of at least 77 tokens, purely to make sure that indeed this is a good input without any zeros that the exporter might somehow confuse with a boolean. Apparently we're still not at 77 tokens, so I just keep on monkey typing this story in the hope that someday I have a fully tokenized string of text that is longer than the required 77 tokens. Thank you for coming to my TED talk."])
text_emb = model.encode_text(text_input)[0].detach().numpy()
text_emb /= norm(text_emb)
python
preprocess
python
from PIL import Image
python
image_singapore = Image.open("../data/singapore.jpg").convert('RGBA')
image_input = preprocess(image_singapore).unsqueeze(0)
print(image_input.detach().numpy().shape)
print(1*3*256*256)  # = 196608, the element count of a 1x3x256x256 tensor
python
image_emb = model(image_input[:,:3,:,:])[0][0].detach().numpy()
print(image_emb.shape)
print(norm(image_emb))
image_emb[0:5]
python
image_singapore_onnx = np.array(image_singapore)
print(image_singapore_onnx.shape)
print(image_singapore_onnx.dtype)

Export to ONNX

python
onnx_opset = 18  # use opset 18 for Resize to antialias

Image model

python
class EncodeImageWrapper(nn.Module):
    def __init__(self, original_model):
        super(EncodeImageWrapper, self).__init__()
        self.original_model = original_model

    def forward(self, input):
        return self.original_model.encode_image(input)
python
image_model_wrapper = EncodeImageWrapper(model)
image_model_wrapper.eval()
image_model_wrapper.original_model.eval()
clip_image_onnx_export_path = "onnx_models/mobileclip_s2_image_float32.onnx"
torch.onnx.export(image_model_wrapper, image, clip_image_onnx_export_path, opset_version=onnx_opset, do_constant_folding=True, input_names=["input"], output_names=["output"])
python
mobileclip_image_onnx = onnx.load(clip_image_onnx_export_path)
onnx.checker.check_model(mobileclip_image_onnx)
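
As a quick sanity check (a sketch, not part of the original flow), the freshly exported model can be run through onnxruntime and compared against the PyTorch embedding, reusing the preprocessed image tensor from the smoke test above.

python
# Hedged sanity check: cosine similarity between the ONNX and PyTorch image embeddings.
check_sess = ort.InferenceSession(clip_image_onnx_export_path)
onnx_check = check_sess.run(["output"], {"input": image.numpy()})[0][0]
with torch.no_grad():
    torch_check = model.encode_image(image)[0].numpy()
print(np.dot(onnx_check / norm(onnx_check), torch_check / norm(torch_check)))  # expect ~1.0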

Text model

python
class EncodeTextWrapper(nn.Module):
    def __init__(self, original_model):
        super(EncodeTextWrapper, self).__init__()
        self.original_model = original_model

    def forward(self, input):
        return self.original_model.encode_text(input)
python
text_model_wrapper = EncodeTextWrapper(model)
text_model_wrapper.eval()
text_model_wrapper.original_model.eval()
clip_text_onnx_export_path = "onnx_models/mobileclip_s2_text_int64.onnx"
torch.onnx.export(text_model_wrapper, text_input, clip_text_onnx_export_path, opset_version=onnx_opset, do_constant_folding=True, input_names=['input'], output_names=['output'])
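
A similar hedged check (not in the original flow) for the text export, reusing text_input and the normalized text_emb computed earlier:

python
text_check_sess = ort.InferenceSession(clip_text_onnx_export_path)
text_check = text_check_sess.run(["output"], {"input": text_input.numpy()})[0][0]
text_check /= norm(text_check)
print(np.dot(text_check, text_emb))  # expect ~1.0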

Altering ONNX models

Image model

Rename the model input to og_input, so that the name input stays available for the altered model that includes preprocessing

python
og_input = onnx.helper.make_tensor_value_info(
    name="og_input",
    elem_type=onnx.TensorProto.FLOAT,
    shape=[1, 3, 256, 256],  
)

# Update the input names in the rest of the model
for node in mobileclip_image_onnx.graph.node:
    for i, input_name in enumerate(node.input):
        if input_name == "input":
            node.input[i] = "og_input"

graph = onnx.helper.make_graph(
    nodes=mobileclip_image_onnx.graph.node,
    name=mobileclip_image_onnx.graph.name,
    inputs=[og_input],
    outputs=mobileclip_image_onnx.graph.output,
    initializer=mobileclip_image_onnx.graph.initializer,
)
mobileclip_image_onnx = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", onnx_opset)])
onnx.save_model(mobileclip_image_onnx, clip_image_onnx_export_path)
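
A one-line inspection (not in the original flow) confirms the rename took effect:

python
print([inp.name for inp in mobileclip_image_onnx.graph.input])  # expect ['og_input']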

Add preprocessing to the model

python
from onnxruntime_extensions.tools.pre_post_processing import PrePostProcessor, create_named_value, Resize, ImageBytesToFloat, Unsqueeze, CenterCrop, Debug, ChannelsLastToChannelsFirst
python
inputs = [create_named_value("input_to_process", onnx.TensorProto.UINT8, ["H", "W", "C"])]

pipeline = PrePostProcessor(inputs, onnx_opset)

pipeline.add_pre_processing(
    [
        Resize(256),  # Resize the shortest side to 256, keeping aspect ratio (antialiasing needs opset 18)
        CenterCrop(256, 256),  # Crop to 256x256. NOTE: Currently only HWC input is handled.
        ChannelsLastToChannelsFirst(),  # Convert to CHW
        # Debug(),
        ImageBytesToFloat(),  # Convert to float in range 0..1 by dividing uint8 values by 255
        # Debug(),
        Unsqueeze([0]),  # add batch, CHW --> 1CHW
        # Debug(),
    ]
)

clip_image_with_preprocessing = pipeline.run(mobileclip_image_onnx)

onnx.checker.check_model(clip_image_with_preprocessing)
clip_image_onnx_rgb_path = f"onnx_models/mobileclip_s2_image_opset{onnx_opset}_rgb.onnx"
new_model_path = clip_image_onnx_rgb_path
onnx.save_model(clip_image_with_preprocessing, new_model_path)
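
As a hedged intermediate check (not in the original flow), the RGB model with built-in preprocessing can be fed raw HWC uint8 bytes and compared against the PyTorch preprocess-and-encode path. A small deviation is expected, since the ONNX Resize and CenterCrop are not bit-identical to torchvision's transforms.

python
rgb_array = np.array(image_singapore.convert('RGB'))  # raw HWC uint8 input
pre_sess = ort.InferenceSession(clip_image_onnx_rgb_path)
pre_emb = pre_sess.run(None, {"input_to_process": rgb_array})[0][0]
pre_emb /= norm(pre_emb)
print(np.dot(pre_emb, image_emb / norm(image_emb)))  # expect close to 1.0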

Add a slice node so that the model can take raw RGBA data as input (as well as standard RGB)

python
onnx_model = clip_image_with_preprocessing

# Create a new input with flexible channel dimension
new_input = onnx.helper.make_tensor_value_info(
    name="input",
    elem_type=onnx.TensorProto.UINT8,
    shape=["H", "W", "C"],  
)

# Create constant tensors for starts, ends, and axes
starts_tensor = onnx.helper.make_tensor(
    name="starts",
    data_type=onnx.TensorProto.INT64,
    dims=[1],
    vals=np.array([0], dtype=np.int64)
)
ends_tensor = onnx.helper.make_tensor(
    name="ends",
    data_type=onnx.TensorProto.INT64,
    dims=[1],
    vals=np.array([3], dtype=np.int64)
)
axes_tensor = onnx.helper.make_tensor(
    name="axes",
    data_type=onnx.TensorProto.INT64,
    dims=[1],
    vals=np.array([2], dtype=np.int64)
)
new_initializers = [starts_tensor, ends_tensor, axes_tensor] + list(onnx_model.graph.initializer)
slice_node = onnx.helper.make_node(
    "Slice",
    inputs=["input", "starts", "ends", "axes"],
    outputs=["sliced_input"],
    name="slice_rgba_input_node"
)


# Add the new input and Slice node to the graph
graph = onnx.helper.make_graph(
    [slice_node] + list(onnx_model.graph.node),  # Prepend Slice node to existing nodes
    onnx_model.graph.name,
    [new_input],
    list(onnx_model.graph.output),
    initializer=new_initializers,
    value_info=onnx_model.graph.value_info,
)

# Create the new model
mobileclip_image_onnx_rgba = onnx.helper.make_model(
    graph,
    opset_imports=[onnx.helper.make_opsetid("", onnx_opset)]
)


# Update the input names in the rest of the model
for node in mobileclip_image_onnx_rgba.graph.node:
    for i, input_name in enumerate(node.input):
        if input_name == "input_to_process":
            node.input[i] = "sliced_input"

# Save the new model
onnx.checker.check_model(mobileclip_image_onnx_rgba)
clip_image_onnx_rgba_path = f"onnx_models/mobileclip_s2_image_opset{onnx_opset}_rgba.onnx"
onnx.save(mobileclip_image_onnx_rgba, clip_image_onnx_rgba_path)
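
A hedged check (not in the original flow) that the Slice node indeed lets the model accept both RGBA and plain RGB raw bytes:

python
rgba_sess = ort.InferenceSession(clip_image_onnx_rgba_path)
emb_rgba = rgba_sess.run(None, {"input": np.array(image_singapore)})[0][0]  # (H, W, 4)
emb_rgb = rgba_sess.run(None, {"input": np.array(image_singapore.convert('RGB'))})[0][0]  # (H, W, 3)
print(np.dot(emb_rgba / norm(emb_rgba), emb_rgb / norm(emb_rgb)))  # expect ~1.0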

Simplify the model

python
clip_image_sim_path = f"onnx_models/mobileclip_s2_image_opset{onnx_opset}_rgba_sim.onnx"
python
!onnxsim {clip_image_onnx_rgba_path} {clip_image_sim_path}
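
If the onnxsim CLI is not available, the Python API should be equivalent (a sketch, assuming the onnxsim package is installed):

python
# import onnxsim
# simplified_model, ok = onnxsim.simplify(clip_image_onnx_rgba_path)
# assert ok, "onnx-simplifier could not validate the simplified model"
# onnx.save(simplified_model, clip_image_sim_path)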

Optimize the graph

python
image_opt_sess_options = ort.SessionOptions()

# Only apply the basic graph optimizations offline; extended and layout
# optimizations depend on the execution provider and hardware, so they are
# better applied online.
image_opt_sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC

clip_image_opt_path = f"onnx_models/mobileclip_s2_image_opset{onnx_opset}_rgba_opt.onnx"
image_opt_sess_options.optimized_model_filepath = clip_image_opt_path

opt_image_session = ort.InferenceSession(clip_image_sim_path, image_opt_sess_options)

Add metadata to the model

python
clip_image_opt = onnx.load(clip_image_opt_path)
clip_image_opt.producer_name = "EnteMobileCLIPImageEncoder"
clip_image_opt.doc_string = "MobileCLIP S2 Image Encoder with built-in preprocessing. Accepts both RGB and RGBA raw bytes input (uint8) in HWC format."
clip_image_opt.graph.doc_string = ""
clip_image_opt.graph.name = "SliceRGB+Resize+CenterCrop+ToFloat+Unsqueeze+MobileCLIP_S2_ImageEncoder"
onnx.save(clip_image_opt, clip_image_opt_path)
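
The metadata can be read back to verify it was written (a quick check, not in the original flow):

python
meta_check = onnx.load(clip_image_opt_path)
print(meta_check.producer_name)
print(meta_check.doc_string)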

Test the model

python
ort_session = ort.InferenceSession(clip_image_opt_path)
onnx_emb = ort_session.run(None, {"input": image_singapore_onnx})[0][0]
onnx_emb /= norm(onnx_emb)
np.dot(image_emb, onnx_emb)
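
The dot product above is the cosine similarity between the PyTorch embedding and the embedding from the final model with built-in preprocessing; values around 0.99 or higher indicate a faithful export. For automated runs, a hedged guard along these lines could be added (the threshold is an arbitrary choice, not from the original notebook):

python
similarity = np.dot(image_emb / norm(image_emb), onnx_emb)
assert similarity > 0.99, f"ONNX image embedding drifted: {similarity:.4f}"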
python
!rm {clip_image_onnx_export_path}
!rm {clip_image_onnx_rgb_path}
!rm {clip_image_onnx_rgba_path}
!rm {clip_image_sim_path}

Text model

Make sure the model can use int32 as input

python
mobileclip_text_onnx = onnx.load(clip_text_onnx_export_path)

for tensor in mobileclip_text_onnx.graph.input:
    if tensor.name == "input":
        tensor.type.tensor_type.elem_type = onnx.TensorProto.INT32
        break

# Save the modified model
clip_text_onnx_int32_path = "onnx_models/mobileclip_s2_text_int32.onnx"
onnx.save(mobileclip_text_onnx, clip_text_onnx_int32_path)

Simplify the model

python
clip_text_sim_path = f"onnx_models/mobileclip_s2_text_opset{onnx_opset}_int32_sim.onnx"
python
!onnxsim {clip_text_onnx_int32_path} {clip_text_sim_path}

Apply basic graph optimizations offline. Only the basic optimizations should be done offline; the extended and layout optimizations should be applied online, since they depend on the execution provider and hardware.

python
text_opt_sess_options = ort.SessionOptions()

# Only the basic graph optimizations, per the note above.
text_opt_sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC

clip_text_opt_path = f"onnx_models/mobileclip_s2_text_opset{onnx_opset}_int32_opt.onnx"
text_opt_sess_options.optimized_model_filepath = clip_text_opt_path

opt_text_session = ort.InferenceSession(clip_text_sim_path, text_opt_sess_options)

Add metadata to the model

python
clip_text_opt = onnx.load(clip_text_opt_path)
clip_text_opt.producer_name = "EnteMobileCLIPTextEncoder"
clip_text_opt.doc_string = "MobileCLIP S2 Text Encoder. Accepts an integer array (int32) of length 77. Longer arrays will be truncated."
clip_text_opt.graph.doc_string = ""
clip_text_opt.graph.name = "MobileCLIP_S2_TextEncoder"
onnx.save(clip_text_opt, clip_text_opt_path)

Test the model

python
mobileclip_text_ort_sess = ort.InferenceSession(clip_text_opt_path)
text_onnx_emb = mobileclip_text_ort_sess.run(["output"], {"input": text_input.numpy().astype("int32")})[0][0]
text_onnx_emb /= norm(text_onnx_emb)
np.dot(text_emb, text_onnx_emb)
python
!rm {clip_text_onnx_export_path}
!rm {clip_text_onnx_int32_path}
!rm {clip_text_sim_path}

Quantize text model

https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html

Quantization pre-processing (not to be confused with the image pre-processing above)

python
from onnxruntime.quantization import quant_pre_process
python
clip_text_quantized_preprocessed_path = "onnx_models/mobileclip_s2_text_quant_preprocessed.onnx"
quant_pre_process(clip_text_opt_path, clip_text_quantized_preprocessed_path)

Dynamic quantization

python
from onnxruntime.quantization import quantize_dynamic, quantize_static, QuantType
python
node_names = []
matmul_nodes_names = []
for node in clip_text_opt.graph.node:
    node_names.append(node.name)
    if node.op_type == "MatMul" and node.name != "/text_encoder/transformer.0/pre_norm_ffn/pre_norm_ffn.4/MatMul":
        matmul_nodes_names.append(node.name)
len(node_names)
python
clip_text_quantized_dynamic_path = f"onnx_models/mobileclip_s2_text_opset{onnx_opset}_quant.onnx"
quantize_dynamic(clip_text_quantized_preprocessed_path, clip_text_quantized_dynamic_path, nodes_to_exclude=node_names[28:29])  # nodes_to_exclude expects a list of names, not a single string
python
mobileclip_text_quant_dyn_ort_sess = ort.InferenceSession(clip_text_quantized_dynamic_path)
text_onnx_quant_dyn_emb = mobileclip_text_quant_dyn_ort_sess.run(["output"], {"input": text_input.numpy().astype("int32")})[0][0]
text_onnx_quant_dyn_emb /= norm(text_onnx_quant_dyn_emb)
np.dot(text_onnx_quant_dyn_emb, text_onnx_emb)

Quantization Debugging (uncomment if you want to try it)

python
# exclude_amount = 1


# for i in range(25, 30, exclude_amount):
#     begin = i
#     end = min(i+exclude_amount, len(node_names))
    
#     clip_text_quantized_dynamic_debug_path = f"onnx_models/mobileclip_s2_text_opset{onnx_opset}_int8dyn_opt_debug.onnx"
#     quantize_dynamic(clip_text_quantized_preprocessed_path, clip_text_quantized_dynamic_debug_path, nodes_to_exclude=node_names[begin:end])
#     mobileclip_text_quant_dyn_ort_sess_debug = ort.InferenceSession(clip_text_quantized_dynamic_debug_path)
#     text_onnx_quant_dyn_emb_debug = mobileclip_text_quant_dyn_ort_sess_debug.run(["output"], {"input": text_input.numpy().astype("int32")})[0][0]
#     text_onnx_quant_dyn_emb_debug /= norm(text_onnx_quant_dyn_emb_debug)
#     sim_debug = np.dot(text_onnx_quant_dyn_emb_debug, text_onnx_emb)
#     print(f"Skipping nodes from {begin} to {end} resulted in a similarity of {sim_debug:.4f}")
python
node_names[28:29]

Test on a dataset of image captions. Before continuing, download the Flickr8k captions dataset from Kaggle and put it in the ../data folder as flickr8k_captions.txt

python
import csv
from tqdm import tqdm
import time
import copy
import matplotlib.pyplot as plt
python
captions = []

with open('../data/flickr8k_captions.txt', 'r', encoding='utf-8') as file:
    csv_reader = csv.reader(file)
    next(csv_reader)
    for row in csv_reader:
        captions.append(row[1])

print(len(captions))
print(captions[:5])

Test accuracy of quantized model quickly

python
test_size = 600
similarities = np.zeros(test_size)
mobileclip_text_quant_dyn_ort_sess = ort.InferenceSession(clip_text_quantized_dynamic_path)

for i, caption in tqdm(enumerate(captions[:test_size])):
    text_input_test = tokenizer([caption])
    text_emb_test = model.encode_text(text_input_test)[0].detach().numpy()
    text_emb_test /= norm(text_emb_test)
    text_onnx_test_emb = mobileclip_text_quant_dyn_ort_sess.run(["output"], {"input": text_input_test.numpy().astype("int32")})[0][0]
    text_onnx_test_emb /= norm(text_onnx_test_emb)
    similarities[i] = np.dot(text_onnx_test_emb, text_emb_test)
python
print(f"Mean similarity: {similarities.mean()}")
print(f"Standard deviation: {similarities.std()}")
print(f"Minimum similarity: {similarities.min()}")
print(f"Maximum similarity: {similarities.max()}")

Test accuracy of quantized model extensively (uncomment code below)

python
# captions_extensive = copy.deepcopy(captions)

# for i in range(10000):
#     captions_extensive[i] = captions_extensive[i] + " " + captions_extensive[i + 10000] + " " + captions_extensive[i + 20000] + " " + captions_extensive[i + 30000]
#     captions_extensive[i + 10000] = captions_extensive[i + 10000] + " " + captions_extensive[i + 20000] + " " + captions_extensive[i + 30000]
#     captions_extensive[i + 20000] = captions_extensive[i + 20000] + " " + captions_extensive[i + 30000]
# captions_extensive = captions_extensive[:40000]

# test_size = len(captions_extensive)
# similarities_extensive = np.zeros(test_size)
# mobileclip_text_quant_dyn_ort_sess = ort.InferenceSession(clip_text_quantized_dynamic_path)

# for i, caption in tqdm(enumerate(captions_extensive[:test_size])):
#     text_input_test = tokenizer([caption])
#     text_emb_test = model.encode_text(text_input_test)[0].detach().numpy()
#     text_emb_test /= norm(text_emb_test)
#     text_onnx_test_emb = mobileclip_text_quant_dyn_ort_sess.run(["output"], {"input": text_input_test.numpy().astype("int32")})[0][0]
#     text_onnx_test_emb /= norm(text_onnx_test_emb)
#     similarities_extensive[i] = np.dot(text_onnx_test_emb, text_emb_test)
python
# print(f"Mean similarity: {similarities_extensive.mean()}")
# print(f"Standard deviation: {similarities_extensive.std()}")
# print(f"Minimum similarity: {similarities_extensive.min()}")
# print(f"Maximum similarity: {similarities_extensive.max()}")
# print(f"Percentage of similarities above 0.99: {np.sum(similarities_extensive > 0.99) / len(similarities_extensive) * 100:.2f}%")
# print(f"Percentage of similarities above 0.995: {np.sum(similarities_extensive > 0.995) / len(similarities_extensive) * 100:.2f}%")

Investigating the MatMul node excluded from quantization to preserve accuracy (uncomment code below)

python
# quant_model = onnx.load(clip_text_opt_path)
# node_name = node_names[28] # /text_encoder/transformer.0/pre_norm_ffn/pre_norm_ffn.4/MatMul
# # use_node_name = matmul_nodes_names[8]
# use_node_name = node_name

# # Find the MatMul node
# special_matmul_node = None
# for node in quant_model.graph.node:
#     if node.op_type == 'MatMul' and node.name == use_node_name:
#         special_matmul_node = node
#         print(f"MatMul node found: {special_matmul_node.name}")
#         break

# if special_matmul_node is None:
#     raise ValueError(f"MatMul node with name '{use_node_name}' not found in the model.")

# # Get the weight tensor
# weight_name = special_matmul_node.input[1]
# special_weight_tensor = None
# for init in quant_model.graph.initializer:
#     if init.name == weight_name:
#         special_weight_tensor = init
#         break

# if special_weight_tensor is None:
#     raise ValueError(f"Weight tensor for MatMul node '{use_node_name}' not found.")

# special_weight_array = onnx.numpy_helper.to_array(special_weight_tensor)

# mean = np.mean(special_weight_array)
# std = np.std(special_weight_array)
# min_val = np.min(special_weight_array)
# max_val = np.max(special_weight_array)

# print(f"Statistical Analysis for MatMul node '{use_node_name}':")
# print(f"Mean: {mean}")
# print(f"Standard Deviation: {std}")
# print(f"Minimum: {min_val}")
# print(f"Maximum: {max_val}")
# print(f"Dynamic Range: {max_val - min_val}")

# plt.figure(figsize=(10, 6))
# plt.hist(special_weight_array.flatten(), bins=50, edgecolor='black')
# plt.title(f"Histogram of Weights for MatMul node '{use_node_name}'")
# plt.xlabel("Weight Value")
# plt.ylabel("Frequency")
# plt.show()

Test speed of quantized model

python
# time_test_size = 1000
# mobileclip_text_quant_dyn_ort_sess = ort.InferenceSession(clip_text_quantized_dynamic_path)
# times_unquantized = np.zeros(time_test_size)
# times_quantized = np.zeros(time_test_size)

# # Time of unquantized model
# print("Timing unquantized model...")
# for i, caption in tqdm(enumerate(captions[:time_test_size])):
#     text_input_test = tokenizer([caption])
#     start = time.time()
#     _ = model.encode_text(text_input_test)
#     end = time.time()
#     times_unquantized[i] = end - start

# # Time of quantized model
# print("Timing quantized model...")
# for i, caption in tqdm(enumerate(captions[:time_test_size])):
#     text_input_test = tokenizer([caption]).numpy().astype("int32")
#     start = time.time()
#     _ = mobileclip_text_quant_dyn_ort_sess.run(["output"], {"input": text_input_test})
#     end = time.time()
#     times_quantized[i] = end - start

# original_mean = times_unquantized.mean()
# original_std = times_unquantized.std()
# quantized_mean = times_quantized.mean()
# quantized_std = times_quantized.std()

# print(f"Original model: {original_mean:.6f} ± {original_std:.6f} seconds")
# print(f"Quantized model: {quantized_mean:.6f} ± {quantized_std:.6f} seconds")
# print(f"Speedup: {original_mean / quantized_mean:.2f}x")
python
!rm {clip_text_quantized_preprocessed_path}

Quantizing image model

Quantizing the image model eventually reached roughly 0.996 similarity with the original model, at a size reduction of 54 MB (from 143 MB to 89 MB). Also not bad, but since the reduction is smaller than for the text model and the resulting embeddings will be stored permanently, we decided not to use it. Uncomment the code below to restart the investigation if wanted.

python
# image_node_names = []
# image_matmul_nodes_names = []
# image_conv_nodes_names = []
# for node in clip_image_opt.graph.node:
#     image_node_names.append(node.name)
#     if node.op_type == "MatMul":
#         image_matmul_nodes_names.append(node.name)
#     if node.op_type == "Conv":
#         image_conv_nodes_names.append(node.name)
# print(len(image_node_names))
# print(len(image_matmul_nodes_names))
# print(len(image_conv_nodes_names))
python
# clip_image_quantized_dynamic_path = f"onnx_models/mobileclip_s2_image_opset{onnx_opset}_int8_opt.onnx"
# exclude = list(set(image_node_names[:100] + image_conv_nodes_names))
# quantize_dynamic(clip_image_opt_path, clip_image_quantized_dynamic_path, weight_type=QuantType.QUInt8, nodes_to_exclude=exclude)

# mobileclip_image_quant_dyn_ort_sess = ort.InferenceSession(clip_image_quantized_dynamic_path)
# image_onnx_quant_dyn_emb = mobileclip_image_quant_dyn_ort_sess.run(["output"], {"input": image_singapore_onnx})[0][0]
# image_onnx_quant_dyn_emb /= norm(image_onnx_quant_dyn_emb)
# np.dot(image_onnx_quant_dyn_emb, image_emb)

Debug quantizations

python
# exclude_amount = 50
# exclude_for_sure = image_node_names[:100] + image_node_names[225:260] + image_node_names[280:300] + image_node_names[430:480] + image_node_names[510:560] + image_node_names[650:]

# image_test_quant = Image.open("../data/singapore.jpg").convert('RGB')
# image_test_quant_onnx = np.array(image_test_quant)

# clip_image_opt_sess = ort.InferenceSession(clip_image_opt_path)
# onnx_emb_quant_test = clip_image_opt_sess.run(None, {"input": image_test_quant_onnx})[0][0]
# onnx_emb_quant_test /= norm(onnx_emb_quant_test)


# for i in range(550, 600, exclude_amount):
#     begin = i
#     end = min(i+exclude_amount, len(image_node_names))
#     exclude = list(set(exclude_for_sure + image_node_names[begin:end]))
    
#     clip_image_quantized_dynamic_debug_path = f"onnx_models/mobileclip_s2_image_opset{onnx_opset}_int8dyn_opt_debug.onnx"
#     quantize_dynamic(clip_image_opt_path, clip_image_quantized_dynamic_debug_path, weight_type=QuantType.QUInt8, nodes_to_exclude=exclude)
#     mobileclip_image_quant_dyn_ort_sess_debug = ort.InferenceSession(clip_image_quantized_dynamic_debug_path)
#     image_onnx_quant_dyn_emb_debug = mobileclip_image_quant_dyn_ort_sess_debug.run(["output"], {"input": image_test_quant_onnx})[0][0]
#     image_onnx_quant_dyn_emb_debug /= norm(image_onnx_quant_dyn_emb_debug)
#     sim_debug = np.dot(image_onnx_quant_dyn_emb_debug, onnx_emb_quant_test)
#     print(f"Skipping nodes from {begin} to {end} resulted in a similarity of {sim_debug:.4f}")

Float16 conversion for the image model

https://onnxruntime.ai/docs/performance/model-optimizations/float16.html

python
from onnxconverter_common import convert_float_to_float16
python
check_nodes_names = []
skip_nodes_names = []
try_image_model = onnx.load(clip_image_opt_path)
for node in try_image_model.graph.node:
    check_nodes_names.append(node.name)
preprocess_nodes = check_nodes_names[:25]
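
The slice of 25 assumes the first 25 nodes form the preprocessing subgraph (Slice, Resize, CenterCrop, the division by 255, Unsqueeze), which is blocked from fp16 conversion so that pixel preprocessing stays in float32. A quick way to eyeball that assumption:

python
print(preprocess_nodes)  # should list only preprocessing ops, no encoder nodes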
python
clip_image_fp16 = convert_float_to_float16(try_image_model, keep_io_types=True, disable_shape_infer=True, node_block_list=preprocess_nodes)
clip_image_fp16_path = f"onnx_models/mobileclip_s2_image_opset{onnx_opset}_fp16.onnx"
onnx.save(clip_image_fp16, clip_image_fp16_path)

Test accuracy

python
image_onnx_input = np.array(Image.open("../data/singapore.jpg").convert('RGB'))
try_sess_options = ort.SessionOptions()
try_sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
# try_sess_options.inter_op_num_threads = 0
# try_sess_options.intra_op_num_threads = 0
# try_sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
# try_sess_options.enable_profiling = True
# try_sess_options.log_severity_level = 0 # Verbose
clip_image_fp16_sess = ort.InferenceSession(clip_image_fp16_path, try_sess_options)
clip_image_sess = ort.InferenceSession(clip_image_opt_path, try_sess_options)
image_onnx_fp16_emb = clip_image_fp16_sess.run(["output"], {"input": image_onnx_input})[0][0]
image_onnx_fp16_emb /= norm(image_onnx_fp16_emb)
image_onnx_emb = clip_image_sess.run(["output"], {"input": image_onnx_input})[0][0]
image_onnx_emb /= norm(image_onnx_emb)
print(np.dot(image_onnx_fp16_emb, image_onnx_emb))
print(image_onnx_emb[0:5])
print(image_onnx_fp16_emb[0:5])

Test speed

python
time_test_size = 100

begin_time_fp16 = time.time()
for i in tqdm(range(time_test_size)):
    _ = clip_image_fp16_sess.run(["output"], {"input": image_onnx_input})
end_time_fp16 = time.time()
time_fp16 = end_time_fp16 - begin_time_fp16

begin_time_opt = time.time()
for i in tqdm(range(time_test_size)):
    _ = clip_image_sess.run(["output"], {"input": image_onnx_input})
end_time_opt = time.time()
time_opt = end_time_opt - begin_time_opt



print(f"Optimized model: {time_opt:.6f} seconds, so {time_opt / time_test_size:.6f} seconds per inference")
print(f"FP16 model: {time_fp16:.6f} seconds, so {time_fp16 / time_test_size:.6f} seconds per inference")
print(f"Speed difference FP16: {time_opt / time_fp16:.2f}x")