docs/notebooks/evaluating-alignment-of-text-to-image-diffusion-models.ipynb
A common task is to evaluate how well a text-to-image model's output aligns with its prompt. One way to test this is to write a set of prompts that specify several objects and their basic physical properties (e.g., color), generate images from them, and evaluate the results manually. Object detection models can automate much of this manual checking.
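To make this concrete, here is a minimal standalone sketch of how such a prompt set could be generated together with the labels each image is expected to contain. The `build_scenarios` helper, the prompt template, and the shortened vocabulary are hypothetical choices for illustration. They are not part of this cookbook's pipeline.

```python
import random

# Illustrative vocabulary (hypothetical); swap in your own colors and objects.
COLORS = ["black", "blue", "white", "red"]
OBJECTS = ["cat", "ball", "car", "dog"]


def build_scenarios(num_scenarios, objects_per_prompt=3, seed=0):
    """Hypothetical helper: sample prompts together with their ground-truth labels."""
    rng = random.Random(seed)  # seeded so the scenario set is reproducible
    scenarios = []
    for _ in range(num_scenarios):
        # Distinct object types keep every "color object" label unique per prompt.
        objects = rng.sample(OBJECTS, objects_per_prompt)
        labels = [f"{rng.choice(COLORS)} {obj}" for obj in objects]
        prompt = "a photo of " + ", ".join(f"a {label}" for label in labels)
        scenarios.append({"prompt": prompt, "ground_truth": set(labels)})
    return scenarios


for scenario in build_scenarios(2):
    print(scenario["prompt"], "->", sorted(scenario["ground_truth"]))
```

Storing the expected labels next to each prompt means every generated image can later be checked against its own ground truth.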
Let's make sure that we have access to a GPU. We can use the nvidia-smi command to do that. In case of any problems, navigate to Edit -> Notebook settings -> Hardware accelerator, set it to GPU, and then click Save.
!nvidia-smi
In this cookbook, we'll leverage the following Python packages:
!pip install -q torch diffusers accelerate inference-gpu[yolo-world] dill git+https://github.com/openai/CLIP.git supervision
import itertools
import cv2
from diffusers import StableDiffusionXLPipeline
import numpy as np
from PIL import Image
import supervision as sv
import torch
from inference.models import YOLOWorld
We'll use the SDXL model to generate our image. Let's initialize the pipeline first:
# Load the SDXL base weights in half precision and move the pipeline to the GPU.
pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
).to("cuda")
In this example, we'll focus on generating an image of a black cat playing with a blue ball next to a parked white car. We don't care about the aesthetic quality of the image, only about whether the requested objects appear in it.
PROMPT = "a black cat playing with a blue ball next to a parked white car, wide angle, photorealistic"
NEGATIVE_PROMPT = "low quality, blurred, text, illustration"
WIDTH, HEIGHT = 1024, 768
SEED = 9213799
# Seeding the generator makes the run repeatable.
image = pipeline(
    prompt=PROMPT,
    negative_prompt=NEGATIVE_PROMPT,
    generator=torch.manual_seed(SEED),
    width=WIDTH,
    height=HEIGHT,
).images[0]
image
Not bad! The result seems to be well aligned with the prompt.
Now, let's see how we can detect the objects automatically. For this, we'll use the YOLO-World model from the inference library.
model = YOLOWorld(model_id="yolo_world/l")
YOLO-World is an open-vocabulary detector, so it lets us define our own set of labels. Let's create them by combining lists of predefined colors and objects.
COLORS = ["green", "yellow", "black", "blue", "red", "white", "orange"]
OBJECTS = ["car", "cat", "ball", "dog", "tree", "house", "person"]
CLASSES = [f"{color} {obj}" for color, obj in itertools.product(COLORS, OBJECTS)]
print("Number of labels:", len(CLASSES))
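Later, each detected `class_id` is turned back into a label with `CLASSES[class_id]`, which relies on the model returning the position of the label in the list passed to `set_classes`. Because `itertools.product` varies the last iterable fastest, that position can also be computed directly. The `class_index` helper below is a hypothetical illustration, not part of the inference API.

```python
import itertools

COLORS = ["green", "yellow", "black", "blue", "red", "white", "orange"]
OBJECTS = ["car", "cat", "ball", "dog", "tree", "house", "person"]
CLASSES = [f"{color} {obj}" for color, obj in itertools.product(COLORS, OBJECTS)]


def class_index(color, obj):
    """Hypothetical helper: index of "color obj" in CLASSES (color-major order)."""
    return COLORS.index(color) * len(OBJECTS) + OBJECTS.index(obj)


print(len(CLASSES))                          # 49 = 7 colors x 7 objects
print(class_index("black", "cat"))           # 15 = 2 * 7 + 1
print(CLASSES[class_index("black", "cat")])  # black cat
```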
Let's feed these labels into our model:
model.set_classes(CLASSES)
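Because every label is a "color object" pair, the same vocabulary can also be used to read the expected labels straight out of a prompt instead of writing them by hand. A minimal sketch, assuming exact phrase matching is good enough: the hypothetical `expected_labels` helper does not handle plurals or adjectives placed between the color and the object (e.g., "black fluffy cat").

```python
import itertools
import re

COLORS = ["green", "yellow", "black", "blue", "red", "white", "orange"]
OBJECTS = ["car", "cat", "ball", "dog", "tree", "house", "person"]


def expected_labels(prompt, colors=COLORS, objects=OBJECTS):
    """Hypothetical helper: return every "color object" phrase found in the prompt."""
    text = prompt.lower()
    found = set()
    for color, obj in itertools.product(colors, objects):
        # \b word boundaries stop "red cat" from matching inside "tired cat".
        pattern = rf"\b{re.escape(color)} {re.escape(obj)}\b"
        if re.search(pattern, text):
            found.add(f"{color} {obj}")
    return found


PROMPT = "a black cat playing with a blue ball next to a parked white car, wide angle, photorealistic"
print(sorted(expected_labels(PROMPT)))  # ['black cat', 'blue ball', 'white car']
```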
Time to detect some objects!
results = model.infer(image)
We'll convert the results to the sv.Detections format to enable features like filtering or annotations.
detections = sv.Detections.from_inference(results)
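For context, here is a standalone sketch of the box conversion involved. It assumes that inference returns each box as a center point (x, y) plus width and height, while `sv.Detections` stores corner coordinates in xyxy order. `center_to_xyxy` is a hypothetical helper written for illustration; it is not supervision's actual implementation.

```python
def center_to_xyxy(x, y, width, height):
    """Hypothetical helper: convert a center-based box to (x_min, y_min, x_max, y_max)."""
    half_w, half_h = width / 2, height / 2
    return (x - half_w, y - half_h, x + half_w, y + half_h)


# A made-up prediction in the assumed center-based format (illustrative values only).
prediction = {"x": 320.0, "y": 240.0, "width": 100.0, "height": 50.0}
print(center_to_xyxy(prediction["x"], prediction["y"], prediction["width"], prediction["height"]))
# (270.0, 215.0, 370.0, 265.0)
```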
Speaking of which: we only care about strong detections, so we filter out those with confidence below 0.6.
valid_detections = detections[detections.confidence >= 0.6]
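The indexing expression above is boolean masking: the comparison produces one True/False value per detection, and indexing keeps only the True entries. A minimal pure-Python sketch of the same idea, using made-up (label, confidence) pairs rather than real model output:

```python
# Made-up detections for illustration only: (label, confidence) pairs.
detections = [("black cat", 0.91), ("blue ball", 0.74), ("black dog", 0.32), ("white car", 0.60)]

THRESHOLD = 0.6

# Build one boolean per detection, then keep the entries whose flag is True.
mask = [confidence >= THRESHOLD for _, confidence in detections]
valid = [det for det, keep in zip(detections, mask) if keep]

# Same label format as used in this cookbook: name followed by a 2-decimal score.
labels = [f"{label} {confidence:0.2f}" for label, confidence in valid]
print(labels)  # ['black cat 0.91', 'blue ball 0.74', 'white car 0.60']
```

Note that `>=` keeps a detection scored exactly at the threshold.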
A quick peek at the detected labels and their scores:
labels = [
    f"{CLASSES[class_id]} {confidence:0.2f}"
    for class_id, confidence
    in zip(valid_detections.class_id, valid_detections.confidence)
]
labels
Now, let's use the power of supervision to visualize them. Our generated image is a Pillow image, and supervision annotators accept either a BGR np.ndarray or a PIL.Image.Image, so no conversion is needed.
Time to define how we want our detections to be visualized. A combination of sv.BoxAnnotator and sv.LabelAnnotator should be perfect.
bounding_box_annotator = sv.BoxAnnotator(thickness=2)
label_annotator = sv.LabelAnnotator(text_thickness=1, text_scale=0.5, text_color=sv.Color.BLACK)
Finally, annotating our image is as simple as calling each annotator's annotate method:
annotated_image = bounding_box_annotator.annotate(image, valid_detections)
annotated_image = label_annotator.annotate(annotated_image, valid_detections, labels)
sv.plot_image(annotated_image, (12, 12))
We can also test if all requested objects are in the generated image by comparing a set of ground-truth labels with predicted ones:
GROUND_TRUTH = {"black cat", "blue ball", "white car"}
prediction = {CLASSES[class_id] for class_id in valid_detections.class_id}
GROUND_TRUTH.issubset(prediction)
Using sv.Detections makes it super easy to do.
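A single subset test answers only one question. Here is a minimal sketch, using a hypothetical `check_alignment` helper, that reports both directions of the comparison: which requested objects were not detected, and which detected objects were never requested. The second set helps catch extra or wrongly colored objects.

```python
def check_alignment(ground_truth, predicted):
    """Hypothetical helper: compare requested labels with detected labels."""
    missing = ground_truth - predicted     # requested but not detected
    unexpected = predicted - ground_truth  # detected but never requested
    return {
        "all_present": not missing,        # equivalent to ground_truth.issubset(predicted)
        "missing": missing,
        "unexpected": unexpected,
    }


GROUND_TRUTH = {"black cat", "blue ball", "white car"}
print(check_alignment(GROUND_TRUTH, {"black cat", "white car", "black dog"}))
# {'all_present': False, 'missing': {'blue ball'}, 'unexpected': {'black dog'}}
```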
In this tutorial, you learned how to detect and visualize objects for a simple image generation evaluation study.
Now that we have a pipeline capable of evaluating a single image, the natural next step is to run it on a set of predefined scenarios and calculate metrics.
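As a starting point, here is a minimal sketch of how per-image results could be aggregated. It assumes each scenario has been reduced to a pair of sets (ground-truth labels, predicted labels), as in the set comparison above. The `summarize` helper and both metric names are hypothetical choices: scene accuracy is the fraction of images containing every requested object, and label recall is, for each label, the fraction of prompts requesting it in which it was detected. The inputs are toy values for illustration only, not measured results.

```python
from collections import Counter


def summarize(results):
    """Hypothetical helper: aggregate (ground_truth, predicted) label-set pairs."""
    if not results:
        raise ValueError("results must contain at least one scenario")
    requested = Counter()  # how many prompts asked for each label
    detected = Counter()   # how many of those images actually contained it
    fully_aligned = 0
    for ground_truth, predicted in results:
        for label in ground_truth:
            requested[label] += 1
            if label in predicted:
                detected[label] += 1
        # A scene counts as aligned only if every requested label was detected.
        if ground_truth <= predicted:
            fully_aligned += 1
    return {
        "scene_accuracy": fully_aligned / len(results),
        "label_recall": {label: detected[label] / n for label, n in requested.items()},
    }


# Toy inputs for illustration only; these are not measured results.
toy_results = [
    ({"black cat", "blue ball", "white car"}, {"black cat", "blue ball", "white car"}),
    ({"red dog", "green tree"}, {"red dog"}),
]
metrics = summarize(toy_results)
print(metrics["scene_accuracy"])                # 0.5
print(sorted(metrics["label_recall"].items()))
```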