Back to Peft

Will raise an error if the minimum required version of diffusers is not installed. Remove at your own risk.

examples/boft_dreambooth/dreambooth_inference.ipynb

0.19.1 · 3.3 KB
Original Source
python
import os

import torch
from accelerate.logging import get_logger
from diffusers import StableDiffusionPipeline
from diffusers.utils import check_min_version

from peft import PeftModel

# Fail fast when the installed diffusers is older than the minimum supported
# version. Remove this check at your own risk.
check_min_version("0.10.0.dev0")

logger = get_logger(__name__)

# Base Stable Diffusion checkpoint used for inference.
MODEL_NAME = "stabilityai/stable-diffusion-2-1"
# Alternative base model:
# MODEL_NAME = "runwayml/stable-diffusion-v1-5"

# Adapter settings. In this notebook they are only used to build RUN_NAME and
# OUTPUT_DIR -- presumably they must mirror the values used during training
# so that the names resolve to the saved checkpoints (verify against the
# training script).
PEFT_TYPE = "boft"
BLOCK_NUM = 8
BLOCK_SIZE = 0
N_BUTTERFLY_FACTOR = 1
SELECTED_SUBJECT = "backpack"
EPOCH_IDX = 200

# Derived names: project label, adapter/run name, and checkpoint root directory.
PROJECT_NAME = f"dreambooth_{PEFT_TYPE}"
RUN_NAME = f"{SELECTED_SUBJECT}_{PEFT_TYPE}_{BLOCK_NUM}{BLOCK_SIZE}{N_BUTTERFLY_FACTOR}"
OUTPUT_DIR = f"./data/output/{PEFT_TYPE}"
python
def get_boft_sd_pipeline(
    ckpt_dir, base_model_name_or_path=None, epoch=None, dtype=torch.float32, device="auto", adapter_name="default"
):
    """Build a Stable Diffusion pipeline with a saved adapter attached.

    Args:
        ckpt_dir: Root directory containing ``unet/<epoch>/<adapter_name>`` and,
            optionally, ``text_encoder/<epoch>/<adapter_name>`` checkpoints.
        base_model_name_or_path: Model id or path handed to
            ``StableDiffusionPipeline.from_pretrained``. Required.
        epoch: Checkpoint epoch used to build the adapter sub-directory names.
            Required (the default only exists because it follows an optional
            positional parameter).
        dtype: Torch dtype for the pipeline weights.
        device: Target device, or ``"auto"`` to pick an available accelerator
            and fall back to CPU when none is usable.
        adapter_name: Name under which the adapter is registered and the last
            component of the checkpoint path.

    Returns:
        The pipeline, moved to ``device``, with the adapter loaded.

    Raises:
        ValueError: If ``base_model_name_or_path`` or ``epoch`` is missing.
    """
    # Validate cheap arguments before touching any device or downloading weights.
    if base_model_name_or_path is None:
        raise ValueError("Please specify the base model name or path")
    if epoch is None:
        raise ValueError("Please specify the checkpoint epoch to load")

    if device == "auto":
        if hasattr(torch, "accelerator") and torch.accelerator.is_available():
            device = torch.accelerator.current_accelerator().type
        elif torch.cuda.is_available():
            # Older torch builds have no ``torch.accelerator`` namespace.
            device = "cuda"
        else:
            device = "cpu"

    pipe = StableDiffusionPipeline.from_pretrained(
        base_model_name_or_path, torch_dtype=dtype, requires_safety_checker=False
    ).to(device)

    load_adapter(pipe, ckpt_dir, epoch, adapter_name)

    if dtype in (torch.float16, torch.bfloat16):
        # Re-cast the adapter-wrapped modules to the requested reduced precision.
        # Casting to ``dtype`` (rather than always float16) keeps them consistent
        # with the rest of the pipeline when bfloat16 is requested.
        pipe.unet.to(dtype=dtype)
        pipe.text_encoder.to(dtype=dtype)

    # Move again so that freshly loaded adapter weights end up on the target device.
    pipe.to(device)
    return pipe


def load_adapter(pipe, ckpt_dir, epoch, adapter_name="default"):
    """Attach the adapter saved under ``ckpt_dir`` to the pipeline's modules.

    The UNet adapter is read from ``<ckpt_dir>/unet/<epoch>/<adapter_name>`` and
    is always loaded. The text-encoder adapter is read from
    ``<ckpt_dir>/text_encoder/<epoch>/<adapter_name>`` and is loaded only when
    that path exists. A module that is already a ``PeftModel`` gets the adapter
    added to it; otherwise the module is wrapped in a new ``PeftModel`` and the
    pipeline attribute is replaced.
    """
    unet_dir = os.path.join(ckpt_dir, f"unet/{epoch}", adapter_name)
    encoder_dir = os.path.join(ckpt_dir, f"text_encoder/{epoch}", adapter_name)

    unet = pipe.unet
    if not isinstance(unet, PeftModel):
        pipe.unet = PeftModel.from_pretrained(unet, unet_dir, adapter_name=adapter_name)
    else:
        unet.load_adapter(unet_dir, adapter_name=adapter_name)

    # The text encoder adapter is optional: nothing to do without a checkpoint.
    if not os.path.exists(encoder_dir):
        return

    encoder = pipe.text_encoder
    if not isinstance(encoder, PeftModel):
        pipe.text_encoder = PeftModel.from_pretrained(encoder, encoder_dir, adapter_name=adapter_name)
    else:
        encoder.load_adapter(encoder_dir, adapter_name=adapter_name)

def set_adapter(pipe, adapter_name):
    """Activate ``adapter_name`` on the UNet and, if PEFT-wrapped, the text encoder."""
    pipe.unet.set_adapter(adapter_name)

    encoder = pipe.text_encoder
    if not isinstance(encoder, PeftModel):
        # The text encoder carries no adapters, so there is nothing to switch.
        return
    encoder.set_adapter(adapter_name)
python
# Text prompt for generation and a negative prompt listing qualities to avoid;
# both are passed to the pipeline call in a later cell.
prompt = "a photo of sks backpack on a wooden floor"
negative_prompt = "low quality, blurry, unfinished"
python
%%time
# Build the pipeline from the base model and attach the adapter stored under
# OUTPUT_DIR for epoch EPOCH_IDX, registered under the name RUN_NAME.
pipe = get_boft_sd_pipeline(OUTPUT_DIR, MODEL_NAME, EPOCH_IDX, adapter_name=RUN_NAME)
python
%%time
# Run the pipeline on the prompt and keep the first returned image.
image = pipe(prompt, num_inference_steps=50, guidance_scale=7, negative_prompt=negative_prompt).images[0]
# Bare expression as the last line of the cell so the notebook renders the image.
image
python
# Load a second adapter into the existing pipeline and make it the active one.
# WARNING: requires training DreamBooth with `boft_bias=None`

SELECTED_SUBJECT = "dog"
EPOCH_IDX = 200
RUN_NAME = f"{SELECTED_SUBJECT}_{PEFT_TYPE}_{BLOCK_NUM}{BLOCK_SIZE}{N_BUTTERFLY_FACTOR}"

load_adapter(pipe, OUTPUT_DIR, epoch=EPOCH_IDX, adapter_name=RUN_NAME)
set_adapter(pipe, adapter_name=RUN_NAME)
python
%%time
# New prompt for the second subject; the negative prompt is unchanged.
prompt = "a photo of sks dog running on the beach"
negative_prompt = "low quality, blurry, unfinished"
# Run the pipeline on the new prompt and keep the first returned image.
image = pipe(prompt, num_inference_steps=50, guidance_scale=7, negative_prompt=negative_prompt).images[0]
# Bare expression as the last line of the cell so the notebook renders the image.
image