
Serving a Stable Diffusion Model with Ray Serve

doc/source/templates/03_serving_stable_diffusion/start.ipynb

| Template Specification | Description |
| --- | --- |
| Summary | This template loads a pretrained stable diffusion model from HuggingFace and serves it to a local endpoint as a Ray Serve deployment. |
| Time to Run | Around 2 minutes to set up the models and generate your first image(s). Less than 10 seconds for every subsequent round of image generation (depending on the image size). |
| Minimum Compute Requirements | At least 1 GPU node. The default is 4 nodes, each with 1 NVIDIA T4 GPU. |
| Cluster Environment | This template uses a custom Docker image built on top of the Anyscale-provided Ray image using Python 3.9: anyscale/ray:latest-py39-cu118. See the appendix in the README for more details. |

By the end, we'll have an application that generates images using stable diffusion for a given prompt!

The application will look something like this:

text
Enter a prompt (or 'q' to quit):   twin peaks sf in basquiat painting style

Generating image(s)...

Generated 4 image(s) in 8.75 seconds to the directory: 58b298d9

Slot in your code below wherever you see the ✂️ icon to build on this template!

The framework and data format used in this template can be easily replaced to suit your own application!

We'll start with some imports and initialize Ray:

python
from fastapi import FastAPI
from fastapi.responses import Response
from io import BytesIO
import matplotlib.pyplot as plt
import numpy as np
import os
import requests
import time
import uuid

import ray
from ray import serve

ray.init()

Deploy the Ray Serve application locally

First, we define the Ray Serve application with the model loading and inference logic. This includes setting up:

  • The /imagine API endpoint that we query to generate the image.
  • The stable diffusion model loaded inside a Ray Serve Deployment. We'll specify the number of model replicas to keep active in our Ray cluster. These model replicas can process incoming requests concurrently.

✂️ Replace this value to change the number of model replicas to serve. The GPU resources required by each replica are set through ray_actor_options={"num_gpus": 1} in the deployment definition.

With more model replicas, more images can be generated in parallel!
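As a rough back-of-the-envelope sketch, assume each replica works on one image at a time and every image takes about the same time. These are simplifications: real latency varies with image size and queueing. Under those assumptions, the wall-clock time for a batch of requests scales with the number of "rounds" the replicas need. The helper name estimate_batch_seconds and the 8-second per-image figure are hypothetical.

```python
import math


def estimate_batch_seconds(
    num_images: int, num_replicas: int, seconds_per_image: float
) -> float:
    """Rough wall-clock estimate for generating `num_images` images.

    Hypothetical helper (not part of Ray Serve). Assumes each replica handles
    one request at a time and every image takes `seconds_per_image`.
    """
    if num_replicas < 1:
        raise ValueError("num_replicas must be at least 1")
    # Number of sequential "rounds" needed when replicas work in parallel.
    rounds = math.ceil(num_images / num_replicas)
    return rounds * seconds_per_image


# 8 images on 4 replicas -> 2 rounds of ~8 s each (illustrative numbers).
print(estimate_batch_seconds(8, 4, 8.0))  # 16.0
```

This estimate is why matching the number of images per prompt to the number of replicas tends to give the shortest wait per image.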

python
NUM_REPLICAS: int = 4

python
# The "GPU" key is missing when no GPUs are available (e.g., a CPU-only cluster).
if NUM_REPLICAS > ray.available_resources().get("GPU", 0):
    print(
        "Your cluster does not currently have enough resources to run with these settings. "
        "Consider decreasing the number of replicas, or decreasing the GPU resources needed "
        "per replica. Ignore this if your cluster auto-scales."
    )
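The resource check compares the number of replicas to the number of free GPUs. The following sketch is slightly more general. It assumes the dictionary shape returned by ray.available_resources() and adds a hypothetical gpus_per_replica parameter that mirrors num_gpus in ray_actor_options. With those two inputs, it computes how many replicas would fit right now:

```python
def replicas_that_fit(available: dict, gpus_per_replica: float = 1.0) -> int:
    """Hypothetical helper: how many replicas fit in the currently free GPUs.

    `available` is assumed to look like the dict from ray.available_resources();
    the "GPU" key may be absent when no GPUs are free.
    """
    if gpus_per_replica <= 0:
        raise ValueError("gpus_per_replica must be positive")
    free_gpus = available.get("GPU", 0.0)
    return int(free_gpus // gpus_per_replica)


# Example with hand-written resource dicts (no Ray cluster needed).
print(replicas_that_fit({"CPU": 16.0, "GPU": 4.0}))  # 4
print(replicas_that_fit({"CPU": 8.0}))  # 0
```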

We start with the Ray Serve deployment, which loads a stable diffusion model and performs inference with it.

✂️ Modify this block to load your own model, and change the generate method to perform your own online inference logic!

python
@serve.deployment(
    ray_actor_options={"num_gpus": 1},
    num_replicas=NUM_REPLICAS,
)
class StableDiffusionV2:
    def __init__(self):
        # <Replace with your own model loading logic>
        import torch
        from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline

        model_id = "stabilityai/stable-diffusion-2"
        scheduler = EulerDiscreteScheduler.from_pretrained(
            model_id, subfolder="scheduler"
        )
        self.pipe = StableDiffusionPipeline.from_pretrained(
            model_id, scheduler=scheduler, revision="fp16", torch_dtype=torch.float16
        )
        self.pipe = self.pipe.to("cuda")

    def generate(self, prompt: str, img_size: int = 776):
        # <Replace with your own model inference logic>
        assert len(prompt), "prompt parameter cannot be empty"
        image = self.pipe(prompt, height=img_size, width=img_size).images[0]
        return image

Next, we define the API endpoint, which lives at /imagine.

✂️ Modify this block to change the endpoint URL or the response schema, and to add any post-processing your model output needs!

python
app = FastAPI()


@serve.deployment(num_replicas=1)
@serve.ingress(app)
class APIIngress:
    def __init__(self, diffusion_model_handle) -> None:
        self.handle = diffusion_model_handle

    @app.get(
        "/imagine",
        responses={200: {"content": {"image/png": {}}}},
        response_class=Response,
    )
    async def generate(self, prompt: str, img_size: int = 776):
        assert len(prompt), "prompt parameter cannot be empty"

        image = await self.handle.generate.remote(prompt, img_size=img_size)

        file_stream = BytesIO()
        image.save(file_stream, "PNG")
        return Response(content=file_stream.getvalue(), media_type="image/png")

Now, we deploy the Ray Serve application locally at http://localhost:8000!

python
entrypoint = APIIngress.bind(StableDiffusionV2.bind())

# Shut down any existing Serve replicas, if they're still around.
serve.shutdown()
serve.run(entrypoint, name="serving_stable_diffusion_template")
print("Done setting up replicas! Now accepting requests...")
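Loading the model can take a while on a fresh cluster. Suppose you send prompts from a separate script or process instead of this notebook. In that case, a small polling helper can wait until the HTTP endpoint answers before sending prompts. The sketch below is a hypothetical wait_for_endpoint that uses only the standard library and is not part of Ray Serve. It treats any HTTP response as "server is up", including an error status such as 422 for a missing prompt parameter:

```python
import time
import urllib.error
import urllib.request


def wait_for_endpoint(url: str, timeout_s: float = 60.0, interval_s: float = 1.0) -> bool:
    """Poll `url` until an HTTP server answers, or give up after `timeout_s`."""
    # Bypass any HTTP(S)_PROXY environment settings, since the target is local.
    opener = urllib.request.build_opener(urllib.request.ProxyHandler({}))
    deadline = time.monotonic() + timeout_s
    while True:
        try:
            with opener.open(url, timeout=interval_s):
                return True
        except urllib.error.HTTPError:
            # An error status still means the server is up and answering.
            return True
        except OSError:
            # Connection refused, reset, or timed out: not ready yet.
            pass
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval_s)


# Example (blocks for up to 2 minutes):
# ready = wait_for_endpoint("http://localhost:8000/imagine", timeout_s=120.0)
```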

Make requests to the endpoint

Next, we'll build a simple client to submit prompts as HTTP requests to the local endpoint at http://localhost:8000/imagine.
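Under the hood, the client sends a plain GET request with the prompt and image size as query parameters. These match the prompt and img_size arguments of the /imagine handler. A hypothetical helper, build_imagine_url, can build the same kind of URL with the standard library. This is handy for testing the endpoint from curl or a browser. It assumes requests encodes params the same way urlencode does, with spaces becoming "+":

```python
from urllib.parse import urlencode


def build_imagine_url(endpoint: str, prompt: str, img_size: int) -> str:
    """Build a GET URL carrying the same query parameters as the client."""
    # urlencode percent-encodes the prompt and turns spaces into '+'.
    query = urlencode({"prompt": prompt, "img_size": img_size})
    return f"{endpoint}?{query}"


print(build_imagine_url("http://localhost:8000/imagine", "twin peaks sf", 640))
# http://localhost:8000/imagine?prompt=twin+peaks+sf&img_size=640
```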

Run the client code in the next few cells and generate your first image!

python
endpoint = "http://localhost:8000/imagine"


@ray.remote(num_cpus=0)
def generate_image(prompt, image_size):
    req = {"prompt": prompt, "img_size": image_size}
    resp = requests.get(endpoint, params=req)
    # Fail loudly on HTTP errors instead of saving an error body as a PNG.
    resp.raise_for_status()
    return resp.content


def show_images(filenames):
    fig, axs = plt.subplots(1, len(filenames), figsize=(4 * len(filenames), 4))
    for i, filename in enumerate(filenames):
        ax = axs if len(filenames) == 1 else axs[i]
        ax.imshow(plt.imread(filename))
        ax.axis("off")
    plt.show()


def main(
    interactive: bool = False,
    prompt: str = "twin peaks sf in basquiat painting style",
    num_images: int = 4,
    image_size: int = 640,
):
    try:
        requests.get(endpoint, timeout=0.1)
    except Exception as e:
        raise RuntimeError(
            "Did you setup the Ray Serve model replicas with `serve.run` "
            "in a previous cell?"
        ) from e

    generation_times = []
    while True:
        prompt = (
            prompt
            if not interactive
            else input("\nEnter a prompt (or 'q' to quit):  ")
        )
        if prompt.lower() == "q":
            break

        print("\nGenerating image(s)...\n")
        start = time.time()

        # Make `num_images` requests to the endpoint at once!
        images = ray.get(
            [generate_image.remote(prompt, image_size) for _ in range(num_images)]
        )

        dirname = f"{uuid.uuid4().hex[:8]}"
        os.makedirs(dirname)
        filenames = []
        for i, image in enumerate(images):
            filename = os.path.join(dirname, f"{i}.png")
            with open(filename, "wb") as f:
                f.write(image)
            filenames.append(filename)

        elapsed = time.time() - start
        generation_times.append(elapsed)
        print(
            f"\nGenerated {len(images)} image(s) in {elapsed:.2f} seconds to "
            f"the directory: {dirname}\n"
        )
        show_images(filenames)
        if not interactive:
            break
    return np.mean(generation_times) if generation_times else -1
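The client fans requests out as Ray tasks declared with num_cpus=0, so they don't reserve CPU resources. You may prefer to run the client outside a Ray cluster. In that case, the same fan-out can be sketched with the standard library's ThreadPoolExecutor. This is a swapped-in alternative, not what the template does, and fan_out and fake_fetch are hypothetical names:

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List


def fan_out(
    fetch: Callable[[str, int], bytes],
    prompt: str,
    image_size: int,
    num_images: int,
) -> List[bytes]:
    """Call `fetch(prompt, image_size)` `num_images` times concurrently.

    Hypothetical helper: a thread-pool stand-in for the Ray-task fan-out.
    """
    if num_images < 1:
        return []
    with ThreadPoolExecutor(max_workers=num_images) as pool:
        futures = [pool.submit(fetch, prompt, image_size) for _ in range(num_images)]
        # Collect results in submission order, similar to ray.get on a list of refs.
        return [future.result() for future in futures]


def fake_fetch(prompt: str, image_size: int) -> bytes:
    """Stand-in for an HTTP call, so the sketch runs without a server."""
    return f"{prompt}:{image_size}".encode()


print(fan_out(fake_fetch, "cat", 512, 3))  # [b'cat:512', b'cat:512', b'cat:512']
```

With a real server, fetch would be a function that calls requests.get(endpoint, params=...) and returns resp.content. Threads are a reasonable fit here because each call spends most of its time waiting on network I/O.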

Once the stable diffusion model finishes generating your image(s), each image is returned in the HTTP response body. The client saves all the images in a local directory for you to view, and they also show up in the notebook cell!
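The client writes each response body straight to a .png file. If a request fails, the body is not an image, for example FastAPI's JSON error for a missing parameter. A hypothetical sanity check is to test for the 8-byte PNG signature before saving. It assumes the endpoint returns PNG data, as declared by media_type="image/png":

```python
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"  # first 8 bytes of every PNG file


def looks_like_png(data: bytes) -> bool:
    """Return True if `data` starts with the 8-byte PNG file signature."""
    return data[:8] == PNG_SIGNATURE


print(looks_like_png(b"\x89PNG\r\n\x1a\n" + b"..."))  # True
print(looks_like_png(b'{"detail": "Not Found"}'))  # False
```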

✂️ Replace this value to change the number of images to generate per prompt.

Each image will be generated starting from a different set of random noise, so you'll be able to see multiple options per prompt!

Try starting with NUM_IMAGES_PER_PROMPT equal to NUM_REPLICAS from earlier.

Set INTERACTIVE to True to enter prompts one at a time, or leave it False to submit a single PROMPT.

python
NUM_IMAGES_PER_PROMPT: int = NUM_REPLICAS

# Control the output size: (IMAGE_SIZE, IMAGE_SIZE)
# The stable diffusion model requires `IMAGE_SIZE` to be a multiple of 8.
# NOTE: Generated image quality degrades rapidly if you reduce the size too much.
IMAGE_SIZE: int = 640

INTERACTIVE: bool = False
PROMPT = "twin peaks sf in basquiat painting style"
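The stable diffusion model requires IMAGE_SIZE to be a multiple of 8, and the pipeline rejects other heights and widths. A small hypothetical helper can snap an arbitrary size down to a valid value before it is sent to the endpoint. The 64-pixel floor is an arbitrary choice for illustration, not a documented limit:

```python
def round_to_multiple_of_8(size: int, minimum: int = 64) -> int:
    """Round `size` down to the nearest multiple of 8, but not below `minimum`.

    Hypothetical helper; `minimum` should itself be a multiple of 8.
    """
    return max(minimum, (size // 8) * 8)


print(round_to_multiple_of_8(650))  # 648
print(round_to_multiple_of_8(776))  # 776
```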

python
mean_generation_time = main(
    interactive=INTERACTIVE,
    prompt=PROMPT,
    num_images=NUM_IMAGES_PER_PROMPT,
    image_size=IMAGE_SIZE,
)

You've successfully served a stable diffusion model! You can modify this template and iterate on your model deployment directly on your cluster within your Anyscale Workspace, testing with the local endpoint.

python
# Shut down the model replicas once you're done!
serve.shutdown()

Summary

This template used Ray Serve to serve many replicas of a stable diffusion model.

At a high level, this template showed how to:

  1. Define a Ray Serve deployment to load a HuggingFace model and perform inference.
  2. Set up a local endpoint to accept and route requests to the different model replicas.
  3. Make multiple requests in parallel to generate many images at a time.

See the Ray Serve getting started guide for a more detailed walkthrough of Ray Serve.