Back to Peft

Will error if the minimum required version of diffusers is not installed. Remove at your own risk.

examples/hra_dreambooth/dreambooth_inference.ipynb

0.19.1 · 2.7 KB
Original Source
python
import os
from PIL import Image

import torch
from accelerate.logging import get_logger
from diffusers import StableDiffusionPipeline
from diffusers.utils import check_min_version

from peft import PeftModel


# Fail fast if the installed diffusers is older than the minimum version this notebook expects.
# Remove this check at your own risk.
check_min_version("0.10.0.dev0")

# Logger for this notebook, obtained through accelerate's logging helper.
logger = get_logger(__name__)

# Base Stable Diffusion checkpoint that the pipeline is built from.
MODEL_NAME = "stabilityai/stable-diffusion-2-1"

# Settings identifying which adapter run to load.
PEFT_TYPE = "hra"
HRA_R = 8  # presumably the HRA rank ``r`` used for training -- confirm against the training run
SELECTED_SUBJECT = "backpack"
EPOCH_IDX = 1000  # passed as ``epoch`` when loading, i.e. the checkpoint sub-directory name

# Names and paths derived from the settings above.
PROJECT_NAME = f"dreambooth_{PEFT_TYPE}"
RUN_NAME = f"{SELECTED_SUBJECT}_{PEFT_TYPE}_{HRA_R}"
OUTPUT_DIR = f"./data/output/{PEFT_TYPE}"
python
def get_hra_sd_pipeline(
    ckpt_dir, base_model_name_or_path=None, epoch=None, dtype=torch.float32, device="cuda", adapter_name="default"
):
    """Build a Stable Diffusion pipeline and attach a saved PEFT adapter to it.

    Args:
        ckpt_dir: Root checkpoint directory handed to ``load_adapter``.
        base_model_name_or_path: Identifier or path of the base model given to
            ``StableDiffusionPipeline.from_pretrained``. Required.
        epoch: Checkpoint identifier used by ``load_adapter`` to build the
            adapter sub-directory paths. Required.
        dtype: Torch dtype the pipeline is loaded with.
        device: Device the pipeline is moved to.
        adapter_name: Name under which the adapter is stored and registered.

    Returns:
        The pipeline with the adapter loaded, placed on ``device``.

    Raises:
        ValueError: If ``base_model_name_or_path`` or ``epoch`` is missing.
    """
    if base_model_name_or_path is None:
        raise ValueError("Please specify the base model name or path")

    # The default used to be the ``int`` type itself (a mistyped annotation), which
    # silently produced paths such as "unet/<class 'int'>". Reject both spellings
    # of "not provided" up front, before any model weights are loaded.
    if epoch is None or epoch is int:
        raise ValueError("Please specify the epoch of the checkpoint to load")

    pipe = StableDiffusionPipeline.from_pretrained(
        base_model_name_or_path, torch_dtype=dtype, requires_safety_checker=False
    ).to(device)

    load_adapter(pipe, ckpt_dir, epoch, adapter_name)

    # Re-cast after attaching the adapter (presumably the adapter weights may load
    # in a different precision -- confirm). Cast to the requested dtype rather than
    # calling ``.half()``, which would turn a bfloat16 request into float16.
    if dtype in (torch.float16, torch.bfloat16):
        pipe.unet.to(dtype=dtype)
        pipe.text_encoder.to(dtype=dtype)

    pipe.to(device)
    return pipe


def load_adapter(pipe, ckpt_dir, epoch, adapter_name="default"):
    """Attach the adapter saved for ``epoch`` to the pipeline's UNet and text encoder.

    Adapter directories are ``<ckpt_dir>/unet/<epoch>/<adapter_name>`` and
    ``<ckpt_dir>/text_encoder/<epoch>/<adapter_name>``. The UNet adapter is
    always loaded; the text-encoder adapter is loaded only when its directory
    exists. A component that is already a ``PeftModel`` gets the adapter added
    to it; otherwise the component is replaced by a new ``PeftModel`` wrapper.
    """
    unet_dir = os.path.join(ckpt_dir, f"unet/{epoch}", adapter_name)
    encoder_dir = os.path.join(ckpt_dir, f"text_encoder/{epoch}", adapter_name)

    unet = pipe.unet
    if not isinstance(unet, PeftModel):
        pipe.unet = PeftModel.from_pretrained(unet, unet_dir, adapter_name=adapter_name)
    else:
        unet.load_adapter(unet_dir, adapter_name=adapter_name)

    # The text-encoder adapter is optional: nothing to do when it was not saved.
    if not os.path.exists(encoder_dir):
        return

    encoder = pipe.text_encoder
    if not isinstance(encoder, PeftModel):
        pipe.text_encoder = PeftModel.from_pretrained(encoder, encoder_dir, adapter_name=adapter_name)
    else:
        encoder.load_adapter(encoder_dir, adapter_name=adapter_name)

def set_adapter(pipe, adapter_name):
    """Activate ``adapter_name`` on the pipeline's UNet and, if wrapped, its text encoder.

    The text encoder is only touched when it is a ``PeftModel``.
    """
    pipe.unet.set_adapter(adapter_name)

    encoder = pipe.text_encoder
    if not isinstance(encoder, PeftModel):
        return
    encoder.set_adapter(adapter_name)
python
# Text prompts used for generation.
prompt = "a purple qwe backpack."
negative_prompt = "low quality, blurry, unfinished"

# Pick the device to run on. ``torch.accelerator.current_accelerator()`` returns
# None when no accelerator is available, so guard before reading ``.type``.
# Without a usable accelerator API, fall back to CUDA if present, else CPU.
_accelerator = torch.accelerator.current_accelerator() if hasattr(torch, "accelerator") else None
if _accelerator is not None:
    device = _accelerator.type
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"
python
%%time
# Build the pipeline from the base model and attach the adapter stored under
# OUTPUT_DIR for checkpoint EPOCH_IDX, registered under the name RUN_NAME.
pipe = get_hra_sd_pipeline(
    OUTPUT_DIR,
    MODEL_NAME,
    EPOCH_IDX,
    adapter_name=RUN_NAME,
    device=device,
)
python
%%time
# Run the pipeline on the prompt and keep the first returned image.
image = pipe(
    prompt,
    num_inference_steps=50,
    guidance_scale=7.5,
    negative_prompt=negative_prompt,
).images[0]
# Bare expression so the notebook displays the image as the cell output.
image
python
# Load a previously saved example image from disk and display it as the cell output.
example_image = Image.open("./a_purple_qwe_backpack.png")
example_image