doc/source/templates/05_dreambooth_finetuning/playground.ipynb
Use this notebook to interactively generate images after you have fine-tuned a Stable Diffusion model and have a model checkpoint available to load. See the README for instructions.
# TODO: Change this to the path of your fine-tuned model checkpoint!
# This is the $TUNED_MODEL_DIR variable defined in the run script.
TUNED_MODEL_PATH = "/tmp/model-tuned"
# TODO: Set the following variables if you fine-tuned with LoRA.
# Base (pre-fine-tuning) model snapshot; passed to get_pipeline together
# with LORA_WEIGHTS_DIR when loading a LoRA checkpoint (see on_lora_ft).
ORIG_MODEL_PATH = "/tmp/model-orig/models--CompVis--stable-diffusion-v1-4/snapshots/b95be7d6f134c3a9e62ee616f310733567f069ce/"
# Directory holding the trained LoRA weights.
LORA_WEIGHTS_DIR = "/tmp/model-tuned"
First, load the model checkpoint as a HuggingFace 🤗 pipeline. Load the model onto a GPU and define a function to generate images from a text prompt.
from os import environ
import torch
from diffusers import DiffusionPipeline
from dreambooth.generate_utils import load_lora_weights, get_pipeline
# Module-level handle to the currently loaded diffusion pipeline.
# Populated by on_full_ft / on_lora_ft; read by the GUI button handler.
pipeline = None
def on_full_ft():
    """Load the fully fine-tuned pipeline from TUNED_MODEL_PATH onto the GPU."""
    global pipeline
    pipeline = get_pipeline(TUNED_MODEL_PATH)
    pipeline.to("cuda")
def on_lora_ft():
    """Load the LoRA fine-tuned pipeline onto the GPU.

    Combines the base model at ORIG_MODEL_PATH with the LoRA weights in
    LORA_WEIGHTS_DIR and stores the result in the module-level `pipeline`.

    Raises:
        ValueError: if ORIG_MODEL_PATH or LORA_WEIGHTS_DIR is unset/empty.
    """
    # Explicit raises instead of `assert`: asserts are stripped when Python
    # runs with -O, which would silently skip this validation.
    if not ORIG_MODEL_PATH:
        raise ValueError("ORIG_MODEL_PATH must be set for LoRA fine-tuning.")
    if not LORA_WEIGHTS_DIR:
        raise ValueError("LORA_WEIGHTS_DIR must be set for LoRA fine-tuning.")
    global pipeline
    pipeline = get_pipeline(ORIG_MODEL_PATH, LORA_WEIGHTS_DIR)
    pipeline.to("cuda")
def generate(
    pipeline: DiffusionPipeline,
    prompt: str,
    img_size: int = 512,
    num_samples: int = 1,
) -> list:
    """Generate `num_samples` square images of side `img_size` for `prompt`.

    The prompt is duplicated once per sample and sent through the pipeline
    in a single batch; returns the pipeline output's `.images` list.
    """
    prompts = [prompt for _ in range(num_samples)]
    result = pipeline(prompts, height=img_size, width=img_size)
    return result.images
Now, play with your fine-tuned diffusion model through this simple GUI.
import time
import ipywidgets as widgets
from IPython.display import display, clear_output
# TODO: When giving prompts, make sure to include your subject's unique identifier,
# as well as its class name.
# For example, if your subject's unique identifier is "unqtkn" and is a dog,
# you can give the prompt "photo of a unqtkn dog on the beach".
# IPython GUI Layouts
# Shared output area: status messages and generated images render here.
output = widgets.Output()
# Toggle selecting which checkpoint flavor to load; handled by toggle_callback.
toggle_buttons = widgets.ToggleButtons(
options=["Full fine-tuning","LoRA fine-tuning"],
disabled=False,
button_style='', # 'success', 'info', 'warning', 'danger' or ''
value=None,  # start unselected so the user must explicitly pick a mode
# layout=widgets.Layout(width='100px')
)
def toggle_callback(change):
    """Load the newly selected pipeline variant when the toggle changes.

    `change["new"]` carries the freshly selected option label; any previous
    status output is cleared before loading.
    """
    with output:
        clear_output()
        loader = on_full_ft if change["new"] == "Full fine-tuning" else on_lora_ft
        loader()
# Re-load the pipeline whenever the user switches fine-tuning mode.
toggle_buttons.observe(toggle_callback, names="value")
# Prompt entry field, pre-filled with an example following the
# "<unique identifier> <class name>" pattern the model was tuned on.
input_text = widgets.Text(
value="photo of a unqtkn dog on the beach",
placeholder="",
description="Prompt:",
disabled=False,
layout=widgets.Layout(width="500px"),
)
# Button that kicks off image generation; its click handler is defined below.
button = widgets.Button(description="Generate!")
# Define button click event
def on_button_clicked(b):
    """Generate two sample images for the current prompt and display them.

    Runs inside the shared `output` widget so each click clears and
    replaces the previous results; also reports wall-clock time taken.
    """
    with output:
        clear_output()
        print("Generating images...")
        print(
            "(The output image may be completely black if it's filtered by "
            "HuggingFace diffusers safety checkers.)"
        )
        started = time.time()
        samples = generate(pipeline=pipeline, prompt=input_text.value, num_samples=2)
        display(*samples)
        elapsed = time.time() - started
        print(f"Completed in {elapsed} seconds.")
# Hook the handler up to the "Generate!" button.
button.on_click(on_button_clicked)
# Display the widgets
display(toggle_buttons, widgets.HBox([input_text, button]), output)
# release memory properly
# Drop the Python reference to the pipeline, then flush PyTorch's CUDA
# allocator cache so the GPU memory is actually returned to the device.
del pipeline
torch.cuda.empty_cache()