doc/source/templates/03_serving_stable_diffusion/start.ipynb
| Template Specification | Description |
|---|---|
| Summary | This template loads a pretrained stable diffusion model from HuggingFace and serves it to a local endpoint as a Ray Serve deployment. |
| Time to Run | Around 2 minutes to set up the models and generate your first image(s). Less than 10 seconds for every subsequent round of image generation (depending on the image size). |
| Minimum Compute Requirements | At least 1 GPU node. The default is 4 nodes, each with 1 NVIDIA T4 GPU. |
| Cluster Environment | This template uses a custom Docker image built on top of the Anyscale-provided Ray image using Python 3.9: anyscale/ray:latest-py39-cu118. See the appendix in the README for more details. |
By the end, we'll have an application that generates images using stable diffusion for a given prompt!
The application will look something like this:
Enter a prompt (or 'q' to quit): twin peaks sf in basquiat painting style
Generating image(s)...
Generated 4 image(s) in 8.75 seconds to the directory: 58b298d9
Slot in your code below wherever you see the ✂️ icon to build off of this template!
The framework and data format used in this template can be easily replaced to suit your own application!
We'll start with some imports and initialize Ray:
from fastapi import FastAPI
from fastapi.responses import Response
from io import BytesIO
import matplotlib.pyplot as plt
import numpy as np
import os
import requests
import time
import uuid
import ray
from ray import serve
ray.init()
First, we define the Ray Serve application with the model loading and inference logic. This includes setting up the /imagine API endpoint that we query to generate the image.
✂️ Replace these values to change the number of model replicas to serve, as well as the GPU resources required by each replica.
With more model replicas, more images can be generated in parallel!
NUM_REPLICAS: int = 4
# Use .get() so that a cluster with no GPUs reports 0 instead of raising a KeyError.
if NUM_REPLICAS > ray.available_resources().get("GPU", 0):
    print(
        "Your cluster does not currently have enough resources to run with these settings. "
        "Consider decreasing the number of replicas, or decreasing the resources needed "
        "per replica. Ignore this if your cluster auto-scales."
    )
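The check above compares the requested replica count with the GPUs that Ray currently reports. As a rough sketch of the same arithmetic, the helper below estimates how many replicas would fit in a given GPU budget. The function name `max_replicas_that_fit` is hypothetical and not part of the Ray API, and the sketch assumes that every replica reserves the same number of GPUs.

```python
import math


def max_replicas_that_fit(available_gpus: float, gpus_per_replica: float = 1.0) -> int:
    """Return how many replicas fit in the given GPU budget.

    Hypothetical helper for illustration only; it is not part of Ray.
    """
    if gpus_per_replica <= 0:
        raise ValueError("gpus_per_replica must be positive")
    # Ray accepts fractional GPU requests, so divide first and round down.
    return max(0, math.floor(available_gpus / gpus_per_replica))


# The default cluster for this template has 4 GPUs in total.
print(max_replicas_that_fit(4))       # 4
print(max_replicas_that_fit(4, 0.5))  # 8
print(max_replicas_that_fit(0))       # 0
```

On a live cluster, the first argument could be `ray.available_resources().get("GPU", 0)`.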
First, we define the Ray Serve Deployment, which will load a stable diffusion model and perform inference with it.
✂️ Modify this block to load your own model, and change the `generate` method to perform your own online inference logic!
@serve.deployment(
    ray_actor_options={"num_gpus": 1},
    num_replicas=NUM_REPLICAS,
)
class StableDiffusionV2:
    def __init__(self):
        # <Replace with your own model loading logic>
        import torch
        from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline

        model_id = "stabilityai/stable-diffusion-2"
        scheduler = EulerDiscreteScheduler.from_pretrained(
            model_id, subfolder="scheduler"
        )
        self.pipe = StableDiffusionPipeline.from_pretrained(
            model_id, scheduler=scheduler, revision="fp16", torch_dtype=torch.float16
        )
        self.pipe = self.pipe.to("cuda")

    def generate(self, prompt: str, img_size: int = 776):
        # <Replace with your own model inference logic>
        assert len(prompt), "prompt parameter cannot be empty"
        image = self.pipe(prompt, height=img_size, width=img_size).images[0]
        return image
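The `generate` method passes `img_size` to the pipeline unchanged, and the configuration comments later in this template note that the model requires the size to be a multiple of 8. One possible way to guard against invalid sizes is to normalize the value before calling the pipeline. The helper below is a minimal sketch; its name, `nearest_valid_size`, and its round-down behavior are assumptions made for illustration, not part of the template or of any library.

```python
def nearest_valid_size(requested: int, multiple: int = 8) -> int:
    """Round `requested` down to the nearest positive multiple of `multiple`.

    Hypothetical helper for illustration only.
    """
    if requested < multiple:
        raise ValueError(f"image size must be at least {multiple}")
    # Integer division drops the remainder, giving the largest multiple <= requested.
    return (requested // multiple) * multiple


print(nearest_valid_size(776))  # 776 (already a multiple of 8)
print(nearest_valid_size(650))  # 648
```

Rounding down rather than up keeps memory use at or below what the caller asked for, which seems the safer default on a GPU with limited memory.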
Next, we'll define the actual API endpoint to live at /imagine.
✂️ Modify this block to change the endpoint URL and response schema, and to add any post-processing logic that your model output needs!
app = FastAPI()


@serve.deployment(num_replicas=1)
@serve.ingress(app)
class APIIngress:
    def __init__(self, diffusion_model_handle) -> None:
        self.handle = diffusion_model_handle

    @app.get(
        "/imagine",
        responses={200: {"content": {"image/png": {}}}},
        response_class=Response,
    )
    async def generate(self, prompt: str, img_size: int = 776):
        assert len(prompt), "prompt parameter cannot be empty"

        image = await self.handle.generate.remote(prompt, img_size=img_size)

        file_stream = BytesIO()
        image.save(file_stream, "PNG")
        return Response(content=file_stream.getvalue(), media_type="image/png")
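Because the endpoint is a GET route, `prompt` and `img_size` travel as URL query parameters. The sketch below shows what such a request URL looks like once encoded, using only the standard library. The function name `build_imagine_url` is hypothetical, and the base URL assumes the local address that this template deploys to.

```python
from urllib.parse import urlencode


def build_imagine_url(
    prompt: str,
    img_size: int = 776,
    base_url: str = "http://localhost:8000/imagine",  # assumed local endpoint
) -> str:
    """Return the full request URL for a prompt (hypothetical helper)."""
    if not prompt:
        raise ValueError("prompt parameter cannot be empty")
    # urlencode escapes spaces and special characters in the prompt.
    query = urlencode({"prompt": prompt, "img_size": img_size})
    return f"{base_url}?{query}"


print(build_imagine_url("twin peaks sf", 640))
# http://localhost:8000/imagine?prompt=twin+peaks+sf&img_size=640
```

An HTTP client library such as `requests` performs this encoding itself when given a `params` dictionary, so the helper is mainly useful for pasting a URL into a browser or `curl`.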
Now, we deploy the Ray Serve application locally at http://localhost:8000!
entrypoint = APIIngress.bind(StableDiffusionV2.bind())
# Shutdown any existing Serve replicas, if they're still around.
serve.shutdown()
serve.run(entrypoint, name="serving_stable_diffusion_template")
print("Done setting up replicas! Now accepting requests...")
Next, we'll build a simple client to submit prompts as HTTP requests to the local endpoint at http://localhost:8000/imagine.
Start the client script in the next few cells, and generate your first image!
endpoint = "http://localhost:8000/imagine"
@ray.remote(num_cpus=0)
def generate_image(prompt, image_size):
    req = {"prompt": prompt, "img_size": image_size}
    resp = requests.get(endpoint, params=req)
    return resp.content


def show_images(filenames):
    fig, axs = plt.subplots(1, len(filenames), figsize=(4 * len(filenames), 4))
    for i, filename in enumerate(filenames):
        ax = axs if len(filenames) == 1 else axs[i]
        ax.imshow(plt.imread(filename))
        ax.axis("off")
    plt.show()
def main(
    interactive: bool = False,
    prompt: str = "twin peaks sf in basquiat painting style",
    num_images: int = 4,
    image_size: int = 640,
):
    # Fail early if the endpoint is unreachable.
    try:
        requests.get(endpoint, timeout=0.1)
    except Exception as e:
        raise RuntimeWarning(
            "Did you set up the Ray Serve model replicas with `serve.run` "
            "in a previous cell?"
        ) from e

    generation_times = []
    while True:
        prompt = (
            prompt
            if not interactive
            else input("\nEnter a prompt (or 'q' to quit): ")
        )
        if prompt.lower() == "q":
            break

        print("\nGenerating image(s)...\n")
        start = time.time()

        # Make `num_images` requests to the endpoint at once!
        images = ray.get(
            [generate_image.remote(prompt, image_size) for _ in range(num_images)]
        )

        # Save each image into a fresh, randomly named directory.
        dirname = f"{uuid.uuid4().hex[:8]}"
        os.makedirs(dirname)
        filenames = []
        for i, image in enumerate(images):
            filename = os.path.join(dirname, f"{i}.png")
            with open(filename, "wb") as f:
                f.write(image)
            filenames.append(filename)

        elapsed = time.time() - start
        generation_times.append(elapsed)
        print(
            f"\nGenerated {len(images)} image(s) in {elapsed:.2f} seconds to "
            f"the directory: {dirname}\n"
        )
        show_images(filenames)
        if not interactive:
            break
    # Mean time per round of generation, or -1 if no round completed.
    return np.mean(generation_times) if generation_times else -1
Once the stable diffusion model finishes generating your image(s), they will be included in the HTTP response body. The client saves all the images in a local directory for you to view, and they'll also show up in the notebook cell!
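The client writes each response body to disk without inspecting it, so a failed request could end up saved as a file with a `.png` extension that is not actually an image. A lightweight safeguard is to check for the 8-byte PNG file signature before saving. The sketch below assumes that error responses do not start with that signature; the helper name `looks_like_png` is hypothetical.

```python
# Every PNG file begins with this fixed 8-byte signature.
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def looks_like_png(payload: bytes) -> bool:
    """Return True if `payload` starts with the PNG file signature.

    Hypothetical helper for illustration only.
    """
    return payload.startswith(PNG_SIGNATURE)


print(looks_like_png(PNG_SIGNATURE + b"...image data..."))  # True
print(looks_like_png(b'{"detail": "error"}'))               # False
```

This only confirms the file header, not that the image decodes correctly, but it is enough to catch an error body that was returned in place of an image.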
✂️ Replace this value to change the number of images to generate per prompt.
Each image will be generated starting from a different set of random noise, so you'll be able to see multiple options per prompt!
Try starting with `NUM_IMAGES_PER_PROMPT` equal to `NUM_REPLICAS` from earlier.
You can choose to run this interactively, or submit a single `PROMPT`.
NUM_IMAGES_PER_PROMPT: int = NUM_REPLICAS
# Control the output size: (IMAGE_SIZE, IMAGE_SIZE)
# The stable diffusion model requires `IMAGE_SIZE` to be a multiple of 8.
# NOTE: Generated image quality degrades rapidly if you reduce the size too much.
IMAGE_SIZE: int = 640
INTERACTIVE: bool = False
PROMPT = "twin peaks sf in basquiat painting style"
mean_generation_time = main(
    interactive=INTERACTIVE,
    prompt=PROMPT,
    num_images=NUM_IMAGES_PER_PROMPT,
    image_size=IMAGE_SIZE,
)
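To build intuition for why matching the number of images to the number of replicas is a reasonable starting point, here is an idealized back-of-the-envelope model. It assumes that each replica handles one request at a time and that every image takes the same time to generate, and it ignores HTTP and scheduling overhead. The function name `estimated_seconds` is hypothetical, and the 8.75 second figure is borrowed from the sample output at the top of this template purely as an illustrative input.

```python
import math


def estimated_seconds(num_images: int, num_replicas: int, seconds_per_image: float) -> float:
    """Estimate the wall-clock time for one round of image generation.

    Idealized model for illustration only: requests are served in
    "waves" of at most `num_replicas` images each.
    """
    if num_replicas < 1:
        raise ValueError("num_replicas must be at least 1")
    if num_images < 0:
        raise ValueError("num_images cannot be negative")
    waves = math.ceil(num_images / num_replicas)
    return waves * seconds_per_image


print(estimated_seconds(4, 4, 8.75))  # 8.75 (one wave)
print(estimated_seconds(5, 4, 8.75))  # 17.5 (a second wave for one extra image)
print(estimated_seconds(8, 4, 8.75))  # 17.5 (two full waves)
```

Under these assumptions, requesting one image more than the replica count roughly doubles the wait, while requesting fewer leaves replicas idle.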
You've successfully served a stable diffusion model! You can modify this template and iterate on your model deployment directly on your cluster within your Anyscale Workspace, testing with the local endpoint.
# Shut down the model replicas once you're done!
serve.shutdown()
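This template used Ray Serve to serve many replicas of a stable diffusion model.
At a high level, this template showed how to:
- Define a Ray Serve deployment that loads a pretrained model and performs inference with it.
- Expose the model through an API endpoint and deploy the application locally.
- Query the endpoint with a simple client that submits several HTTP requests at once.
See the Ray Serve getting started guide for a more detailed walkthrough.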