Back to Ente

Prepping YOLOv5Face model for use in Ente

infra/ml/playground/YOLOv5Face/yoloface_onnx.ipynb

2.0.3425.5 KB
Original Source

Prepping YOLOv5Face model for use in Ente

Paper | Github

Setting up PyTorch weights and source code

Please manually place the PyTorch `.pt` weights (e.g. `yolov5s_face.pt`) in the `pytorch_weights` directory.

python
# Path to the PyTorch checkpoint (must be placed there manually, see above)
model_weights_path = "pytorch_weights/yolov5s_face.pt"
# Directory (with trailing slash) that the exported ONNX files are written to;
# it is concatenated directly with a file name when export paths are built.
models_path = "onnx_models/"
python
# Fetch the yolov5-face sources and copy its `models` and `utils` packages next
# to this notebook so the `models.*` / `utils.*` imports in the next cell resolve.
# NOTE: if models/ or utils/ already exist (e.g. on a re-run), `cp -r` copies
# into models/models/ and utils/utils/ instead; delete them first.
!mkdir yoloface_repo
%cd yoloface_repo
!git clone https://github.com/deepcam-cn/yolov5-face.git
%cd ..
!cp -r yoloface_repo/yolov5-face/models/ models/
!cp -r yoloface_repo/yolov5-face/utils/ utils/
# The clone itself is no longer needed once the packages are copied
!rm -rf yoloface_repo

Imports

python
# Libraries
import torch
import torch.nn as nn
from PIL import Image
import json  # NOTE(review): not used anywhere in this notebook
import numpy as np
import onnx
import onnxruntime as ort
print(ort.__version__)  # show the installed onnxruntime version for reference

# Source code (the yolov5-face packages copied in the previous cell)
from models.common import Conv, ShuffleV2Block
from models.experimental import attempt_load
from utils.activations import Hardswish, SiLU
from utils.general import set_logging, check_img_size  # set_logging is unused here

Export to ONNX

python
import os

onnx_opset = 18
img_size = [640, 640]  # (height, width): used as torch.zeros(batch, 3, H, W) below
batch_size = 1
dynamic_shapes = False  # True exports a symbolic batch dimension

# Load PyTorch model
model = attempt_load(
    model_weights_path, map_location=torch.device("cpu")
)  # load FP32 model
# Export-specific tweaks to the detection head (last module), following the
# upstream yolov5-face export script: replace anchor_grid with placeholders and
# set export_cat, which presumably makes the head emit one concatenated
# (batch, detections, 16) tensor -- TODO confirm against models/yolo.py.
delattr(model.model[-1], "anchor_grid")
model.model[-1].anchor_grid = [
    torch.zeros(1)
] * 3  # nl=3 number of detection layers
model.model[-1].export_cat = True
model.eval()
labels = model.names

# Checks
gs = int(max(model.stride))  # grid size (max stride)
img_size = [
    check_img_size(x, gs) for x in img_size
]  # verify img_size are gs-multiples

# Test input (all zeros; used for tracing and for the numerical comparison)
img = torch.zeros(batch_size, 3, *img_size)

# Update model: swap activations for export-friendly implementations
for _, m in model.named_modules():
    m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility
    if isinstance(m, Conv):  # assign export-friendly activations
        if isinstance(m.act, nn.Hardswish):
            m.act = Hardswish()
        elif isinstance(m.act, nn.SiLU):
            m.act = SiLU()
    if isinstance(m, ShuffleV2Block):  # shufflenet block nn.SiLU
        for i in range(len(m.branch1)):
            if isinstance(m.branch1[i], nn.SiLU):
                m.branch1[i] = SiLU()
        for i in range(len(m.branch2)):
            if isinstance(m.branch2[i], nn.SiLU):
                m.branch2[i] = SiLU()
# Dry run; also the PyTorch reference output compared against ONNX below.
# Note it is computed before model.fuse(), so tiny numerical differences are expected.
y = model(img)

# ONNX export
print("\nStarting ONNX export with onnx %s..." % onnx.__version__)
# e.g. "pytorch_weights/yolov5s_face.pt" -> "onnx_models/yolov5s_face.onnx".
# Later cells strip the trailing ".onnx" with [:-5], so the suffix must stay.
weights_stem = os.path.splitext(os.path.basename(model_weights_path))[0]
onnx_model_export_path = models_path + weights_stem + ".onnx"
# The export writes a file directly and does not create missing directories.
if models_path:
    os.makedirs(models_path, exist_ok=True)
model.fuse()
# Later cells rely on these exact tensor names ("input" / "output").
input_names = ["input"]
output_names = ["output"]
torch.onnx.export(
    model,
    img,
    onnx_model_export_path,
    verbose=False,
    opset_version=onnx_opset,
    input_names=input_names,
    output_names=output_names,
    dynamic_axes=(
        {"input": {0: "batch"}, "output": {0: "batch"}} if dynamic_shapes else None
    ),
)

# Checks
onnx_model = onnx.load(onnx_model_export_path)  # load onnx model
onnx.checker.check_model(onnx_model)  # check onnx model

# onnx infer: run the exported model on the same input and compare with PyTorch
providers = ["CPUExecutionProvider"]
session = ort.InferenceSession(onnx_model_export_path, providers=providers)
im = img.cpu().numpy().astype(np.float32)  # torch to numpy
y_onnx = session.run(
    [session.get_outputs()[0].name], {session.get_inputs()[0].name: im}
)[0]
print("pred's shape is ", y_onnx.shape)
print("max(|torch_pred - onnx_pred|) =", abs(y.cpu().numpy() - y_onnx).max())
python
# Remove the copied yolov5-face packages again; the modules are already
# imported and the model object created above stays in memory.
!rm -rf models/
!rm -rf utils/

Altering ONNX model

Add preprocessing inside model

python
from onnxruntime_extensions.tools.pre_post_processing import PrePostProcessor, create_named_value, Resize, ImageBytesToFloat, Unsqueeze, Debug, LetterBox, ChannelsLastToChannelsFirst
python
# Wrap the exported model with pre-processing so that it consumes raw uint8 HWC
# image bytes directly. The named value below becomes the new graph input.
inputs = [create_named_value("input_to_preprocess", onnx.TensorProto.UINT8, ["H", "W", "C"])]

pipeline = PrePostProcessor(inputs, onnx_opset)

# NOTE: the 640 / (640, 640) target sizes are hard-coded and must match the
# `img_size` used for the export above (640x640 there).
pipeline.add_pre_processing(
    [
        Resize(640, layout= "HWC", policy="not_larger"), # Resize, keeping the aspect ratio, so the largest dimension is 640 (the smaller one ends up <= 640)
        # Debug(),
        LetterBox((640, 640), layout="HWC", fill_value=114),  # Pad with value 114 so the image is always exactly 640x640
        ChannelsLastToChannelsFirst(),  # Convert HWC -> CHW
        # Debug(),
        ImageBytesToFloat(),  # Convert to float in range 0..1 by dividing uint8 values by 255
        # Debug(),
        Unsqueeze([0]),  # add batch, CHW --> 1CHW
        # Debug(),
    ]
)

# No post-processing steps are added through the pipeline; NMS and friends are
# added manually to the graph further below.
# pipeline.add_post_processing()
onnx_model_prepro = pipeline.run(onnx_model)
onnx.checker.check_model(onnx_model_prepro)

To debug and visually inspect the preprocessing, uncomment the Debug() statements in the block above and run it again, then uncomment and run the code in the block below:

python
# onnx.save(onnx_model_prepro, "yolov5s_face_prepro.onnx")

# image_singapore = Image.open("../data/singapore.jpg").convert('RGB')
# image_singapore_onnx = np.array(image_singapore)
# print(image_singapore_onnx.shape)
# print(type(image_singapore_onnx))
# print(image_singapore_onnx.dtype)

# ort_session = ort.InferenceSession("yolov5s_face_prepro.onnx")
# test = ort_session.run(None, {"input_to_preprocess": image_singapore_onnx})

# preprocessed = test[4]
# print(preprocessed.shape)
# print(type(preprocessed))

# # import matplotlib#.pyplot as plt
# from IPython.display import display
# # matplotlib.use('TkAgg')

# displayable_array = preprocessed.reshape(3, 640, 640).transpose((1, 2, 0))
# # Display the image
# # matplotlib.pyplot.imshow(displayable_array)
# # matplotlib.pyplot.axis('off')  
# # matplotlib.pyplot.show()
# display(Image.fromarray((displayable_array * 255).astype(np.uint8)))

Add slice operator for use of RGBA input

python
# Create a new input with flexible channel dimension
new_input = onnx.helper.make_tensor_value_info(
    name="input",
    elem_type=onnx.TensorProto.UINT8,
    shape=["H", "W", "C"],  
)

# Create constant tensors for starts, ends, and axes and use them to create a Slice node
starts_tensor = onnx.helper.make_tensor(
    name="starts",
    data_type=onnx.TensorProto.INT64,
    dims=[1],
    vals=np.array([0], dtype=np.int64)
)
ends_tensor = onnx.helper.make_tensor(
    name="ends",
    data_type=onnx.TensorProto.INT64,
    dims=[1],
    vals=np.array([3], dtype=np.int64)
)
axes_tensor = onnx.helper.make_tensor(
    name="axes",
    data_type=onnx.TensorProto.INT64,
    dims=[1],
    vals=np.array([2], dtype=np.int64)
)
slice_node = onnx.helper.make_node(
    "Slice",
    inputs=["input", "starts", "ends", "axes"],
    outputs=["sliced_input"],
    name="slice_rgba_input_node",
)
# Combine initializers
initializers = [starts_tensor, ends_tensor, axes_tensor] + list(onnx_model_prepro.graph.initializer)

# Get the name of the original input
original_input_name = onnx_model_prepro.graph.input[0].name

# Make new graph by adding the new input and Slice node to the old graph
graph = onnx.helper.make_graph(
    [slice_node] + list(onnx_model_prepro.graph.node),  # Prepend Slice node to existing nodes
    onnx_model_prepro.graph.name,
    [new_input] + list(onnx_model_prepro.graph.input)[1:],  # Replace first input, keep others
    list(onnx_model_prepro.graph.output),
    initializer=initializers,
    value_info=onnx_model_prepro.graph.value_info,
)

# Create the new model
onnx_model_rgba = onnx.helper.make_model(
    graph,
    opset_imports=[onnx.helper.make_opsetid("", onnx_opset)]
)

# Update the input names in the rest of the model
for node in onnx_model_rgba.graph.node:
    for i, input_name in enumerate(node.input):
        if input_name == original_input_name:
            node.input[i] = "sliced_input"

# Save the new model
onnx.checker.check_model(onnx_model_rgba)
onnx_model_rgba_path = onnx_model_export_path[:-5] + "_rgba.onnx"
onnx.save(onnx_model_rgba, onnx_model_rgba_path)
python
# image = Image.open("../data/man.jpeg").convert('RGBA')
# image_onnx = np.array(image)
# print(image_onnx.shape)
# print(type(image_onnx))
# print(image_onnx.dtype)

# ort_session = ort.InferenceSession("yolov5s_face_rgba.onnx")
# test = ort_session.run(None, {"input": image_onnx})
# print(test[0].shape)
# scores_output = test[0][:,:,4]
# print(f"Highest score: {scores_output.max()}")

Add post-processing inside the model

First we rename the model's output so that the post-processed output can take the name `output`. Then we split the [1, 25200, 16] tensor into [1, 25200, 4], [1, 25200, 1] and [1, 25200, 11] (i.e. [1, detections, bbox], [1, detections, score], [1, detections, landmarks + class score]), named boxes, scores and masks. The last slice holds the 10 landmark coordinates (five x, y pairs) followed by the class score.

python
# Add a Split operator at the end so that its outputs can be fed to the NonMaxSuppression operator
# NOTE: 25200 detections is hard-coded; presumably 3 anchors x (80^2 + 40^2 + 20^2)
# grid cells for a 640x640 input -- must be updated if img_size changes.
num_det = 25200
graph = onnx_model_rgba.graph

# Let's first rename the output of the model so we can name the post-processed output as `output`
# (this mutates onnx_model_rgba's graph in place; it was already saved above)
for node in onnx_model_rgba.graph.node:
    for i, output_name in enumerate(node.output):
        if output_name == "output":
            node.output[i] = "og_output"
# NOTE(review): this value_info is never passed to make_graph below, so it is unused.
og_output = onnx.helper.make_tensor_value_info(
    name="og_output",
    elem_type=onnx.TensorProto.FLOAT,
    shape=[1, num_det, 16],  
)

# Create the split node: og_output is split along the last axis into
# 4 (box) + 1 (score) + 11 (remaining) values, and the three pieces become
# the new graph outputs.
boxes_output = onnx.helper.make_tensor_value_info(
    name="boxes_unsqueezed",
    elem_type=onnx.TensorProto.FLOAT,
    shape=[1, num_det, 4],  
)
scores_output = onnx.helper.make_tensor_value_info(
    name="scores_unsqueezed",
    elem_type=onnx.TensorProto.FLOAT,
    shape=[1, num_det, 1],  
)
masks_output = onnx.helper.make_tensor_value_info(
    name="masks_unsqueezed",
    elem_type=onnx.TensorProto.FLOAT,
    shape=[1, num_det, 11],  
)
# Split sizes are passed as an input tensor. This initializer is named "splits";
# it is renamed further below in the notebook because of an iOS issue with that name.
splits_tensor = onnx.helper.make_tensor(
    name="splits",
    data_type=onnx.TensorProto.INT64,
    dims=[3],
    vals=np.array([4, 1, 11], dtype=np.int64)
)
split_node = onnx.helper.make_node(
        "Split",
        inputs=["og_output", "splits"],
        outputs=["boxes_unsqueezed", "scores_unsqueezed", "masks_unsqueezed"],
        name="split_og_output",
        axis=2,
)

# Combine initializers
initializers = list(graph.initializer) + [splits_tensor]

# Make new graph by adding the new outputs and Split node to the old graph
# (the former "output" is no longer a graph output; the three split pieces are)
graph = onnx.helper.make_graph(
    list(graph.node) + [split_node],  # Append split node to existing nodes
    graph.name,
    list(graph.input), 
    [boxes_output, scores_output, masks_output],
    initializer=initializers,
    value_info=graph.value_info,
)

# Create the new model
onnx_model_split = onnx.helper.make_model(
    graph,
    opset_imports=[onnx.helper.make_opsetid("", onnx_opset)]
)

# Save the new model as "<export path without .onnx>_split.onnx"
onnx.checker.check_model(onnx_model_split)
onnx_model_split_path = onnx_model_export_path[:-5] + "_split.onnx"
onnx.save(onnx_model_split, onnx_model_split_path)

Now we can run NMS on these split outputs using the NonMaxSuppression operator.

python
# NOTE: num_det and original_output are not used in this cell.
num_det = 25200
graph = onnx_model_split.graph
nodes = list(graph.node)
outputs = list(graph.output)
initializers = list(graph.initializer)
original_output = graph.output[0]

# Create the Transpose node for the scores: ONNX NonMaxSuppression expects
# scores shaped [batch, num_classes, num_boxes], so [1, num_det, 1] becomes [1, 1, num_det]
transpose_node_score = onnx.helper.make_node(
        "Transpose",
        inputs=["scores_unsqueezed"],
        outputs=["scores_transposed"],
        name="transpose_scores",
        perm=[0, 2, 1],
)
nodes.append(transpose_node_score)

# Create the NMS node.
# Output: selected indices of shape [num_selected, 3] = (batch, class, box index).
# Parameters: at most 100 boxes per class, IoU threshold 0.4, score threshold 0.6.
nms_indices = onnx.helper.make_tensor_value_info("nms_indices", onnx.TensorProto.INT64, shape=["detections", 3])
max_output = onnx.helper.make_tensor("max_output",onnx.TensorProto.INT64, [1], np.array([100], dtype=np.int64))
iou_threshold = onnx.helper.make_tensor("iou_threshold",onnx.TensorProto.FLOAT, [1], np.array([0.4], dtype=np.float32))
score_threshold = onnx.helper.make_tensor("score_threshold",onnx.TensorProto.FLOAT, [1], np.array([0.6], dtype=np.float32))
initializers = initializers + [max_output, iou_threshold, score_threshold]
# center_point_box=1: boxes are given as [x_center, y_center, width, height]
nms_node = onnx.helper.make_node(
        "NonMaxSuppression",
        inputs=["boxes_unsqueezed", "scores_transposed", "max_output", "iou_threshold", "score_threshold"],
        outputs=["nms_indices"],
        name="perform_nms",
        center_point_box=1,
)
nodes.append(nms_node)
outputs.append(nms_indices)

# Make new graph by adding the Transpose and NMS nodes (NMS indices become an extra output)
graph = onnx.helper.make_graph(
    nodes,
    graph.name,
    list(graph.input), 
    outputs,
    initializer=initializers,
    value_info=graph.value_info,
)

# Create the new model
onnx_model_nms = onnx.helper.make_model(
    graph,
    opset_imports=[onnx.helper.make_opsetid("", onnx_opset)]
)

# Save the new model as "<export path without .onnx>_nms.onnx"
onnx.checker.check_model(onnx_model_nms)
onnx_model_nms_path = onnx_model_export_path[:-5] + "_nms.onnx"
onnx.save(onnx_model_nms, onnx_model_nms_path)
python
# image = Image.open("../data/man.jpeg").convert('RGBA')
# image_onnx = np.array(image)

# ort_session = ort.InferenceSession("yolov5s_face_nms.onnx")
# test = ort_session.run(None, {"input": image_onnx})
# print(test[3].shape)
# print(test[3])
# print(test[1][0, 24129, 0])

Now we need to add some Squeeze, Slice and Gather nodes to handle the indices returned by NMS. The goal is that the final output is a simple array of shape (detections, 16) containing only the relevant detections.

python
# Turn the NMS indices into the final (detections, 16) output by gathering the
# selected rows of the original [1, num_det, 16] model output ("og_output").
# NOTE: num_det and original_output are not used in this cell.
num_det = 25200
graph = onnx_model_nms.graph
nodes = list(graph.node)
outputs = list(graph.output)
initializers = list(graph.initializer)
original_output = graph.output[0]

# Create Slice node to slice the NMS indices from (detections, 3) to (detections, 1) by taking the third column (the box index)
sliced_indices = onnx.helper.make_tensor_value_info("sliced_indices", onnx.TensorProto.INT64, shape=["detections", 1])
outputs.append(sliced_indices)
starts_slice_tensor = onnx.helper.make_tensor(
    name="starts_slice_tensor",
    data_type=onnx.TensorProto.INT64,
    dims=[1],
    vals=np.array([2], dtype=np.int64)
)
ends_slice_tensor = onnx.helper.make_tensor(
    name="ends_slice_tensor",
    data_type=onnx.TensorProto.INT64,
    dims=[1],
    vals=np.array([3], dtype=np.int64)
)
axes_slice_tensor = onnx.helper.make_tensor(
    name="axes_slice_tensor",
    data_type=onnx.TensorProto.INT64,
    dims=[1],
    vals=np.array([1], dtype=np.int64)
)
initializers = initializers + [starts_slice_tensor, ends_slice_tensor, axes_slice_tensor]
slice_node = onnx.helper.make_node(
    "Slice",
    inputs=["nms_indices", "starts_slice_tensor", "ends_slice_tensor", "axes_slice_tensor"],
    outputs=["sliced_indices"],
    name="slice_nms_indices",
)
nodes.append(slice_node)

# Create Squeeze node to squeeze the sliced indices: (detections, 1) -> (detections,)
squeezed_indices = onnx.helper.make_tensor_value_info("squeezed_indices", onnx.TensorProto.INT64, shape=["detections"])
outputs.append(squeezed_indices)
squeeze_slice_tensor = onnx.helper.make_tensor("squeeze_slice_axis",onnx.TensorProto.INT64, [1], np.array([1], dtype=np.int64))
initializers.append(squeeze_slice_tensor)
squeeze_slice_node = onnx.helper.make_node(
        "Squeeze",
        inputs=["sliced_indices", "squeeze_slice_axis"],
        outputs=["squeezed_indices"],
        name="squeeze_sliced_indices",
)
nodes.append(squeeze_slice_node)

# Create Squeeze node to squeeze the original output: [1, 25200, 16] -> [25200, 16].
# Dropping axis 0 means this post-processing only works for batch size 1.
squeezed_output = onnx.helper.make_tensor_value_info("squeezed_output", onnx.TensorProto.FLOAT, shape=[25200, 16])
outputs.append(squeezed_output)
squeeze_tensor = onnx.helper.make_tensor("squeeze_axis",onnx.TensorProto.INT64, [1], np.array([0], dtype=np.int64))
initializers.append(squeeze_tensor)
squeeze_node = onnx.helper.make_node(
        "Squeeze",
        inputs=["og_output", "squeeze_axis"],
        outputs=["squeezed_output"],
        name="squeeze_output",
)
nodes.append(squeeze_node)


# Create Gather node to gather the relevant NMS indices from the original output
# (Gather with the default axis 0 picks the selected rows -> (detections, 16))
postpro_output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape=["detections", 16])
outputs.append(postpro_output)
gather_node = onnx.helper.make_node(
    "Gather",
    inputs=["squeezed_output", "squeezed_indices"],
    outputs=["output"],
    name="gather_output",
)
nodes.append(gather_node)


# Make the new graph. Only `output` is exposed; the intermediate outputs
# appended to `outputs` above are not used here.
graph = onnx.helper.make_graph(
    nodes,
    graph.name,
    list(graph.input), 
    [postpro_output],
    initializer=initializers,
    value_info=graph.value_info,
)

# Create the new model
onnx_model_prepostpro = onnx.helper.make_model(
    graph,
    opset_imports=[onnx.helper.make_opsetid("", onnx_opset)]
)

# Save the new model as "<export path without .onnx>_prepostpro.onnx"
onnx.checker.check_model(onnx_model_prepostpro)
onnx_model_prepostpro_path = onnx_model_export_path[:-5] + "_prepostpro.onnx"
onnx.save(onnx_model_prepostpro, onnx_model_prepostpro_path)
python
# image = Image.open("../data/people.jpeg").convert('RGBA')
# image_onnx = np.array(image)

# ort_session = ort.InferenceSession("yolov5s_face_prepostpro.onnx")
# test = ort_session.run(None, {"input": image_onnx})
# test[0].shape

Optimize model

python
# Output path for the onnxsim-simplified model,
# e.g. "onnx_models/yolov5s_face_opset18_rgba_sim.onnx"
onnx_model_sim_path = onnx_model_export_path[:-5] + f"_opset{onnx_opset}_rgba_sim.onnx"

Simplify the model

python
# Simplify the pre/post-processed model with the onnxsim CLI (input path, output path)
!onnxsim {onnx_model_prepostpro_path} {onnx_model_sim_path}
python
# !onnxsim yolov5s_face_prepostpro.onnx yolov5s_face_opset18_rgba_sim.onnx

Optimize the graph

python
# Let onnxruntime apply its basic graph optimizations to the simplified model
# and write the optimized graph to disk. The session object is not used later;
# creating it is what triggers writing `optimized_model_filepath`.
opt_sess_options = ort.SessionOptions()
opt_sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC

onnx_model_opt_path = onnx_model_export_path[:-5] + f"_opset{onnx_opset}_rgba_opt.onnx"
opt_sess_options.optimized_model_filepath = onnx_model_opt_path

# Pin the CPU provider explicitly. ORT builds with several providers reject
# sessions without `providers`, and optimizing for another provider could
# bake provider-specific changes into the saved model.
opt_session = ort.InferenceSession(
    onnx_model_sim_path, opt_sess_options, providers=["CPUExecutionProvider"]
)

Prevent splits initializer issue

For unknown reasons, the model can cause issues on iOS when it contains an initializer named "splits". To prevent that, we check for such an initializer and rename it.

python
# Reload the onnxruntime-optimized model for the initializer inspection and renaming below
current_model = onnx.load(onnx_model_opt_path)
python
def find_duplicates(name_list):
    """Return every name that appears more than once in ``name_list``.

    Each repeated name is reported exactly once. The result is built from a
    set, so its order is unspecified.
    """
    first_occurrences = set()
    repeated_names = set()

    for candidate in name_list:
        if candidate not in first_occurrences:
            first_occurrences.add(candidate)
            continue
        repeated_names.add(candidate)

    return list(repeated_names)

# Collect the initializer names so we can (a) check whether an initializer
# called "splits" survived simplification/optimization and (b) detect names
# that occur more than once.
initializers = current_model.graph.initializer
init_names = [init.name for init in initializers]
initializer_dict = {init.name: init for init in initializers}

# .get() avoids a KeyError when no "splits" initializer exists. The value is
# also not indexed inside the f-string: nesting double quotes there is a
# SyntaxError before Python 3.12.
splits_initializer = initializer_dict.get("splits")
if splits_initializer is None:
    print("No initializer named 'splits' found; nothing to rename.")
else:
    print(f"splits initializer: \n {splits_initializer}")

duplicate_names = find_duplicates(init_names)

print("Duplicate names:", duplicate_names)
python
def rename_initializer(model, old_name, new_name):
    """Rename initializer ``old_name`` to ``new_name`` and update its references.

    The first initializer called ``old_name`` is renamed. Then every matching
    graph input, ``value_info`` entry and top-level node input is rewritten so
    the graph stays connected. Nodes inside subgraphs (e.g. If/Loop bodies)
    are not visited.

    Args:
        model: ONNX ``ModelProto``, modified in place.
        old_name: current initializer name.
        new_name: name to give the initializer.

    Returns:
        True if an initializer called ``old_name`` was found and renamed.
        Otherwise the graph is left untouched and False is returned, so a
        non-initializer tensor with that name is never half-renamed.
    """
    graph = model.graph

    for initializer in graph.initializer:
        if initializer.name == old_name:
            initializer.name = new_name
            break
    else:
        return False

    # Older IR versions also list initializers among the graph inputs.
    for graph_input in graph.input:
        if graph_input.name == old_name:
            graph_input.name = new_name

    # Keep any shape/type annotation attached to the tensor consistent.
    for info in graph.value_info:
        if info.name == old_name:
            info.name = new_name

    # Rewire every node that consumes the tensor.
    for node in graph.node:
        for i, input_name in enumerate(node.input):
            if input_name == old_name:
                node.input[i] = new_name

    return True
python
# Rename the "splits" initializer (nothing matches if none exists) and save the
# result as a new "*_nosplits.onnx" file. The previous path is kept so that
# intermediate file can be deleted later.
# NOTE: re-running this cell appends "_nosplits" a second time.
rename_initializer(current_model, "splits", "splits_initializer_unique")

# Save the modified model
onnx_model_opt_with_splits_path = onnx_model_opt_path
onnx_model_opt_path = onnx_model_opt_path[:-5] + "_nosplits.onnx"
onnx.save(current_model, onnx_model_opt_path)

Add metadata to model

https://onnx.ai/onnx/intro/python.html#opset-and-metadata

python
# Attach descriptive metadata to the final model and overwrite it in place.
# The doc_string describes the 16 values per detection row: indices 0-3 box,
# 4 score, 5-14 the five landmark (x, y) pairs, 15 the class score.
new_yolo_face_model = onnx.load(onnx_model_opt_path)
new_yolo_face_model.producer_name = "EnteYOLOv5Face"
new_yolo_face_model.doc_string = (
    "YOLOv5 Face detector with built-in pre- and post-processing. "
    "Accepts both RGB and RGBA raw bytes input (uint8) in HWC format. "
    "Outputs the relevant detections in the format (detections, 16) where the "
    "first 4 values are the bounding box (center x, center y, width, height) in "
    "the coordinates of the 640x640 letterboxed input, the fifth is the "
    "confidence score, the next 10 are the (x, y) coordinates of the 5 facial "
    "landmarks, and the last one is the class score."
)
new_yolo_face_model.graph.doc_string = ""
new_yolo_face_model.graph.name = "SliceRGB+Resize+LetterBox+ToFloat+Unsqueeze+YOLOv5Face+NMS+Slice+Gather"
onnx.save(new_yolo_face_model, onnx_model_opt_path)
python
# Delete all intermediate model files; only the final model at
# onnx_model_opt_path (the "_nosplits" file) is kept.
!rm {onnx_model_export_path}
!rm {onnx_model_rgba_path}
!rm {onnx_model_split_path}
!rm {onnx_model_nms_path}
!rm {onnx_model_prepostpro_path}
!rm {onnx_model_sim_path}
!rm {onnx_model_opt_with_splits_path}

Tune some settings

python
# from tqdm import tqdm
# import time
python
# image = Image.open("../data/people.jpeg").convert('RGBA')
# image_onnx = np.array(image)
# time_test_size = 500

# sess_options1 = ort.SessionOptions()
# sess_options1.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
# # sess_options.enable_profiling = True
# # sess_options.log_severity_level = 0 # Verbose
# sess_options1.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
# ort_session1 = ort.InferenceSession(onnx_model_opt_path, sess_options1)

# begin_time_1 = time.time()
# for i in tqdm(range(time_test_size)):
#     _ = ort_session1.run(None, {"input": image_onnx})
# end_time_1 = time.time()
# time_1 = end_time_1 - begin_time_1


# sess_options2 = ort.SessionOptions()
# sess_options2.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
# # sess_options.enable_profiling = True
# # sess_options.log_severity_level = 0 # Verbose
# sess_options2.inter_op_num_threads = 4
# # sess_options2.intra_op_num_threads = 4
# sess_options2.execution_mode = ort.ExecutionMode.ORT_PARALLEL
# ort_session2 = ort.InferenceSession(onnx_model_opt_path, sess_options2, providers=["CPUExecutionProvider"])

# begin_time_2 = time.time()
# for i in tqdm(range(time_test_size)):
#     _ = ort_session2.run(None, {"input": image_onnx})
# end_time_2 = time.time()
# time_2 = end_time_2 - begin_time_2

# print(f"Time for first execution: {time_1}")
# print(f"Time for second execution: {time_2}")

Lessons learned:

  1. Use sequential execution
  2. Use extended optimizations
  3. Number of inter op doesn't have significant impact
  4. Number of intra op doesn't have significant impact

One final test:

python
# End-to-end smoke test of the final model on a real image. The image is loaded
# as RGBA to exercise the model's built-in channel slicing.
image = Image.open("../data/man.jpeg").convert('RGBA')
imageWidth, imageHeight = image.size
inputWidth, inputHeight = 640, 640  # must match the model's letterboxed input size
print(imageWidth, imageHeight)
image_onnx = np.array(image)  # uint8 array, (H, W, 4)

sess_options1 = ort.SessionOptions()
sess_options1.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
# sess_options.enable_profiling = True
# sess_options.log_severity_level = 0 # Verbose
# sess_options1.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
# Pass the options configured above (the session previously ignored them)
# and pin the CPU provider explicitly.
ort_session = ort.InferenceSession(
    onnx_model_opt_path, sess_options1, providers=["CPUExecutionProvider"]
)
detections = ort_session.run(None, {"input": image_onnx})[0]  # (detections, 16)
if len(detections) == 0:
    raise ValueError("No face detected in the test image; nothing to display.")
raw_detection = detections[0]  # first detection returned by NMS
print(raw_detection.shape)
raw_detection
python
from PIL import Image, ImageDraw
from IPython.display import display

def display_face_detection(image_path, face_box, landmarks, radius=3):
    """Draw a face bounding box and its landmark points on an image and show it.

    The file on disk is not modified; drawing happens on an in-memory copy.

    Args:
        image_path: path of the image to annotate.
        face_box: (left, top, right, bottom) in absolute pixel coordinates.
        landmarks: iterable of (x, y) points in absolute pixel coordinates.
        radius: radius in pixels of the dot drawn for each landmark.
    """
    # Copy the pixels so the underlying file handle is closed right away
    # (a bare Image.open keeps the file open until the image is released).
    with Image.open(image_path) as source:
        img = source.copy()

    draw = ImageDraw.Draw(img)

    # Bounding box outline
    draw.rectangle(face_box, outline="red", width=2)

    # One filled dot per landmark
    for x, y in landmarks:
        draw.ellipse([x - radius, y - radius, x + radius, y + radius], fill="blue")

    display(img)
python
def correct_detection_and_display(image_path, raw_detection, imageWidth, imageHeight, inputWidth, inputHeight):
    """Map one raw model detection back to original-image pixels and draw it.

    The model input is the original image resized with its aspect ratio kept
    and then letterboxed to ``inputWidth`` x ``inputHeight``. This function
    undoes that transform, assuming the padding is split evenly on both sides.

    Args:
        image_path: path of the original image (used for display only).
        raw_detection: one row of the model output:
            [cx, cy, w, h, score, x1, y1, ..., x5, y5, ...] in input pixels.
        imageWidth, imageHeight: size of the original image in pixels.
        inputWidth, inputHeight: size of the letterboxed model input in pixels.

    Returns:
        (box, landmarks): box as [left, top, right, bottom] and a list of five
        (x, y) landmark tuples, all in original-image pixels.
    """
    # Center/size box -> corner box, relative to the model input
    box = [
        (raw_detection[0] - raw_detection[2] / 2) / inputWidth,
        (raw_detection[1] - raw_detection[3] / 2) / inputHeight,
        (raw_detection[0] + raw_detection[2] / 2) / inputWidth,
        (raw_detection[1] + raw_detection[3] / 2) / inputHeight,
    ]
    # Landmarks are stored as (x, y) pairs at indices 5..14
    landmarks = [
        (x / inputWidth, y / inputHeight)
        for x, y in zip(raw_detection[5:15:2], raw_detection[6:15:2])
    ]

    # Correct the bounding box and landmarks for letterboxing during preprocessing
    scale = min(inputWidth / imageWidth, inputHeight / imageHeight)
    scaledWidth = round(imageWidth * scale)
    scaledHeight = round(imageHeight * scale)
    print(f"scaledWidth: {scaledWidth}, scaledHeight: {scaledHeight}")

    halveDiffX = (inputWidth - scaledWidth) / 2
    halveDiffY = (inputHeight - scaledHeight) / 2
    print(f"halveDiffX: {halveDiffX}, halveDiffY: {halveDiffY}")
    # Each axis is rescaled by its own input dimension (width for x, height
    # for y), so non-square model inputs are handled correctly.
    scaleX = inputWidth / scaledWidth
    scaleY = inputHeight / scaledHeight
    translateX = - halveDiffX / inputWidth
    translateY = - halveDiffY / inputHeight
    print(f"scaleX: {scaleX}, scaleY: {scaleY}")
    print(f"translateX: {translateX}, translateY: {translateY}")

    # Remove the padding offset, then stretch back to the un-padded image
    box = [
        (box[0] + translateX) * scaleX,
        (box[1] + translateY) * scaleY,
        (box[2] + translateX) * scaleX,
        (box[3] + translateY) * scaleY,
    ]
    landmarks = [
        ((x + translateX) * scaleX, (y + translateY) * scaleY) for x, y in landmarks
    ]

    # Convert the bounding box and landmarks to absolute values
    box = [box[0] * imageWidth, box[1] * imageHeight, box[2] * imageWidth, box[3] * imageHeight]
    landmarks = [(x * imageWidth, y * imageHeight) for x, y in landmarks]

    print("Bounding box:", box)
    print("Landmarks:", landmarks)

    display_face_detection(image_path, box, landmarks)
    return box, landmarks

python
# Must be the same image that produced `raw_detection` above
image_path = "../data/man.jpeg"
# Example of the argument formats expected by display_face_detection:
# face_box = (50, 10, 100, 100)  # (left, top, right, bottom)
# landmarks = [
#     (30, 30),  # Left eye
#     (80, 30),  # Right eye
#     (55, 50),  # Nose
#     (35, 80),  # Left mouth corner
#     (75, 80)   # Right mouth corner
# ]

# Map the raw detection back to original-image pixels and draw it
correct_detection_and_display(image_path, raw_detection, imageWidth, imageHeight, inputWidth, inputHeight)