Evaluating Alignment of Text-to-image Diffusion Models

Click the Open in Colab button to run the cookbook on Google Colab.

Introduction

It is a common scenario to evaluate a text-to-image model for its alignment to the prompt. One way to test this is to use a set of prompts, each describing a number of objects and their basic physical properties (e.g. color), to generate images and manually evaluate the results. This process can be greatly improved using object detection models.

Before you start

Let's make sure that we have access to a GPU. We can use the nvidia-smi command to do that. In case of any problems, navigate to Edit -> Notebook settings -> Hardware accelerator, set it to GPU, and then click Save.

python
!nvidia-smi
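
You can also ask PyTorch directly whether a GPU is visible. This is a minimal sketch; it assumes torch is already available in your runtime (on Colab it is pre-installed):

python
import torch

# True if a CUDA-capable GPU is visible to PyTorch
print(torch.cuda.is_available())
if torch.cuda.is_available():
    # Name of the first GPU, e.g. a Tesla T4 on Colab
    print(torch.cuda.get_device_name(0))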

Install required packages

In this cookbook, we'll leverage the following Python packages:

python
!pip install -q torch diffusers accelerate inference-gpu[yolo-world] dill git+https://github.com/openai/CLIP.git supervision

Imports

python
import itertools
import cv2
from diffusers import StableDiffusionXLPipeline
import numpy as np
from PIL import Image
import supervision as sv
import torch
from inference.models import YOLOWorld

Generating an image

We'll use the SDXL model to generate our image. Let's initialize our pipeline first:

python
pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
).to("cuda")
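
If you hit out-of-memory errors on a smaller GPU, diffusers can offload sub-models to the CPU and move them to the GPU only when they are needed. A minimal sketch of this alternative initialization (it trades generation speed for memory):

python
# Alternative setup for memory-constrained GPUs: note that
# enable_model_cpu_offload() replaces the explicit .to("cuda")
pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)
pipeline.enable_model_cpu_offload()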

In this example, we'll focus on generating an image of a black cat playing with a blue ball next to a parked white car. We don't care about the aesthetic aspect of the image.

python
PROMPT = "a black cat playing with a blue ball next to a parked white car, wide angle, photorealistic"
NEGATIVE_PROMPT = "low quality, blurred, text, illustration"
WIDTH, HEIGHT = 1024, 768
SEED = 9213799

image = pipeline(
    prompt=PROMPT,
    negative_prompt=NEGATIVE_PROMPT,
    generator=torch.manual_seed(SEED),
    width=WIDTH,
    height=HEIGHT,
).images[0]
image
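
Since the pipeline returns a Pillow image, persisting it for later inspection is a one-liner (the filename is just an example):

python
# Save the generated image to disk; the path is arbitrary
image.save("generated_image.png")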

Not bad! The results seem to be well-aligned with the prompt.

Detecting objects

Now, let's see how we can detect the objects automatically. For this, we'll use the YOLO-World model from the inference library.

python
model = YOLOWorld(model_id="yolo_world/l")

The YOLO-World model allows us to define our own set of labels. Let's create one by combining lists of pre-defined colors and objects.

python
COLORS = ["green", "yellow", "black", "blue", "red", "white", "orange"]
OBJECTS = ["car", "cat", "ball", "dog", "tree", "house", "person"]
CLASSES = [f"{color} {obj}" for color, obj in itertools.product(COLORS, OBJECTS)]
print("Number of labels:", len(CLASSES))

Let's feed these labels into our model:

python
model.set_classes(CLASSES)

Time to detect some objects!

python
results = model.infer(image)

We'll convert the results to the sv.Detections format to enable features like filtering or annotations.

python
detections = sv.Detections.from_inference(results)

Speaking of which: we only care about strong detections, so we filter out those with confidence below 0.6.

python
valid_detections = detections[detections.confidence >= 0.6]
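
sv.Detections supports NumPy-style boolean masks, so multiple conditions can be combined. As a sketch, you could also drop tiny boxes; the area threshold and the filtered_detections name below are illustrative, not part of the cookbook:

python
# Combine masks with NumPy's logical operators: keep detections
# that are both confident and reasonably large
mask = (detections.confidence >= 0.6) & (detections.area > 1_000)
filtered_detections = detections[mask]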

A quick peek at the detected labels and their scores:

python
labels = [
    f"{CLASSES[class_id]} {confidence:0.2f}"
    for class_id, confidence
    in zip(valid_detections.class_id, valid_detections.confidence)
]
labels

Visualizing results

Now, let's use the power of supervision to visualize them. Our output image is in Pillow format; annotators accept either a BGR np.ndarray or a Pillow PIL.Image.Image.
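
If you ever need the NumPy form explicitly (e.g. for OpenCV post-processing), the conversion is straightforward. A small sketch, using the cv2 and numpy imports from earlier:

python
# Pillow stores channels as RGB; OpenCV expects BGR
image_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)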

Time to define how we want our detections to be visualized. A combination of sv.BoundingBoxAnnotator and sv.LabelAnnotator should be perfect.

python
bounding_box_annotator = sv.BoundingBoxAnnotator(thickness=2)
label_annotator = sv.LabelAnnotator(
    text_thickness=1, text_scale=0.5, text_color=sv.Color.BLACK
)

Finally, annotating our image is as simple as calling annotate methods from our annotators:

python
annotated_image = bounding_box_annotator.annotate(image, valid_detections)
annotated_image = label_annotator.annotate(annotated_image, valid_detections, labels)

sv.plot_image(annotated_image, (12, 12))
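
supervision's annotators return the same image type they receive, so with our Pillow input the annotated result can be saved directly (again, the filename is just an example):

python
# annotated_image is still a Pillow image here, so .save() works
annotated_image.save("annotated_image.png")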

Testing it automatically

We can also test whether all requested objects appear in the generated image by comparing a set of ground-truth labels with the predicted ones:

python
GROUND_TRUTH = {"black cat", "blue ball", "white car"}
prediction = {CLASSES[class_id] for class_id in valid_detections.class_id}

# every requested object must appear among the predictions
GROUND_TRUTH.issubset(prediction)
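
Note that sets collapse duplicates, so this check cannot tell one black cat from two. If your prompts specify object counts, a multiset comparison is a natural extension. A minimal sketch using the standard library; the GROUND_TRUTH_COUNTS and predicted_counts names are illustrative:

python
from collections import Counter

# counts matter here: a prompt could ask for e.g. two black cats
GROUND_TRUTH_COUNTS = Counter({"black cat": 1, "blue ball": 1, "white car": 1})
predicted_counts = Counter(
    CLASSES[class_id] for class_id in valid_detections.class_id
)

# every requested label must be detected at least as many times as requested
all(
    predicted_counts[label] >= count
    for label, count in GROUND_TRUTH_COUNTS.items()
)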

Using sv.Detections makes it super easy to do.

Next steps

In this tutorial, you learned how to detect and visualize objects for a simple image generation evaluation study.

With a pipeline capable of evaluating a single image, the natural next step is to run it on a set of pre-defined scenarios and calculate metrics, as sketched below.
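
Here is a minimal sketch of such a loop. The SCENARIOS list, the seed handling, and the accuracy metric are all illustrative assumptions, not part of the cookbook above; both expected label sets stay within the CLASSES vocabulary we defined earlier:

python
# Hypothetical scenario list: (prompt, expected labels)
SCENARIOS = [
    (
        "a black cat playing with a blue ball next to a parked white car",
        {"black cat", "blue ball", "white car"},
    ),
    (
        "a red dog sitting in front of a green house",
        {"red dog", "green house"},
    ),
]

aligned = 0
for prompt, expected in SCENARIOS:
    generated = pipeline(
        prompt=prompt,
        negative_prompt=NEGATIVE_PROMPT,
        generator=torch.manual_seed(SEED),
        width=WIDTH,
        height=HEIGHT,
    ).images[0]
    detections = sv.Detections.from_inference(model.infer(generated))
    detections = detections[detections.confidence >= 0.6]
    predicted = {CLASSES[class_id] for class_id in detections.class_id}
    aligned += expected.issubset(predicted)

# fraction of scenarios where every requested object was detected
print(f"Alignment accuracy: {aligned / len(SCENARIOS):.2f}")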