
Utility functions for stable diffusion

import os
import random
from pathlib import Path

import PIL
import numpy as np
import torch
from PIL import Image

from labml import monit
from labml.logger import inspect
from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
from labml_nn.diffusion.stable_diffusion.model.autoencoder import Encoder, Decoder, Autoencoder
from labml_nn.diffusion.stable_diffusion.model.clip_embedder import CLIPTextEmbedder
from labml_nn.diffusion.stable_diffusion.model.unet import UNetModel

Set random seeds

def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
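As a quick sanity check (a minimal sketch run in this same module, so set_seed and torch are already in scope), re-seeding with the same value reproduces the same noise tensor, which is what makes sampling runs repeatable:

set_seed(42)
noise_a = torch.randn(1, 4, 64, 64)

set_seed(42)
noise_b = torch.randn(1, 4, 64, 64)

# Both draws are identical because all generators were re-seeded with the same value.
assert torch.equal(noise_a, noise_b)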

Load LatentDiffusion model

def load_model(path: Path = None) -> LatentDiffusion:

Initialize the autoencoder

    with monit.section('Initialize autoencoder'):
        encoder = Encoder(z_channels=4,
                          in_channels=3,
                          channels=128,
                          channel_multipliers=[1, 2, 4, 4],
                          n_resnet_blocks=2)

        decoder = Decoder(out_channels=3,
                          z_channels=4,
                          channels=128,
                          channel_multipliers=[1, 2, 4, 4],
                          n_resnet_blocks=2)

        autoencoder = Autoencoder(emb_channels=4,
                                  encoder=encoder,
                                  decoder=decoder,
                                  z_channels=4)

Initialize the CLIP text embedder

    with monit.section('Initialize CLIP Embedder'):
        clip_text_embedder = CLIPTextEmbedder()

Initialize the U-Net

    with monit.section('Initialize U-Net'):
        unet_model = UNetModel(in_channels=4,
                               out_channels=4,
                               channels=320,
                               attention_levels=[0, 1, 2],
                               n_res_blocks=2,
                               channel_multipliers=[1, 2, 4, 4],
                               n_heads=8,
                               tf_layers=1,
                               d_cond=768)

Initialize the Latent Diffusion model

    with monit.section('Initialize Latent Diffusion model'):
        model = LatentDiffusion(linear_start=0.00085,
                                linear_end=0.0120,
                                n_steps=1000,
                                latent_scaling_factor=0.18215,
                                autoencoder=autoencoder,
                                clip_embedder=clip_text_embedder,
                                unet_model=unet_model)

Load the checkpoint

    with monit.section(f"Loading model from {path}"):
        checkpoint = torch.load(path, map_location="cpu")

Set model state

    with monit.section('Load state'):
        missing_keys, extra_keys = model.load_state_dict(checkpoint["state_dict"], strict=False)

Debugging output

    inspect(global_step=checkpoint.get('global_step', -1), missing_keys=missing_keys, extra_keys=extra_keys,
            _expand=True)

Set the model to evaluation mode and return it

    model.eval()
    return model
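A typical call might look like the sketch below. The checkpoint path is hypothetical; any Stable Diffusion v1-style checkpoint whose file contains a state_dict entry should load, and moving the model to a GPU is optional.

from pathlib import Path

import torch

# Hypothetical location; point this at wherever the checkpoint actually lives.
checkpoint_path = Path('checkpoints/sd-v1-4.ckpt')

model = load_model(checkpoint_path)

# Move to a GPU for sampling if one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)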

Load an image

This loads an image from a file and returns a PyTorch tensor.

• path is the path of the image

def load_img(path: str):

Open Image

    image = Image.open(path).convert("RGB")

Get image size

    w, h = image.size

Resize to a multiple of 32, so the dimensions divide evenly through the autoencoder's downsampling

    w = w - w % 32
    h = h - h % 32
    image = image.resize((w, h), resample=PIL.Image.LANCZOS)

Convert to numpy and map from [0, 255] to [-1, 1]

    image = np.array(image).astype(np.float32) * (2. / 255.0) - 1

Transpose to shape [batch_size, channels, height, width]

    image = image[None].transpose(0, 3, 1, 2)

Convert to torch

    return torch.from_numpy(image)
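For example (the file path below is hypothetical), the returned tensor has a batch dimension of one, channels first, spatial dimensions rounded down to multiples of 32, and values in [-1, 1]:

image = load_img('inputs/photo.png')

# For a 770x512 (width x height) input this prints torch.Size([1, 3, 512, 768]),
# with minimum and maximum values inside [-1, 1].
print(image.shape, image.min().item(), image.max().item())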

Save images

• images is the tensor with images of shape [batch_size, channels, height, width]
• dest_path is the folder to save images in
• prefix is the prefix to add to file names
• img_format is the image format

def save_images(images: torch.Tensor, dest_path: str, prefix: str = '', img_format: str = 'jpeg'):

Create the destination folder

    os.makedirs(dest_path, exist_ok=True)

Map images to [0, 1] space and clip

    images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)

Transpose to [batch_size, height, width, channels] and convert to numpy

    images = images.cpu().permute(0, 2, 3, 1).numpy()

Save images

    for i, img in enumerate(images):
        img = Image.fromarray((255. * img).astype(np.uint8))
        img.save(os.path.join(dest_path, f"{prefix}{i:05}.{img_format}"), format=img_format)
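A minimal round-trip sketch tying the two helpers together (paths are hypothetical): load an image into the [-1, 1] tensor format and write it straight back out as a JPEG.

images = load_img('inputs/photo.png')

# Writes outputs/roundtrip-00000.jpeg (and so on for larger batches).
save_images(images, dest_path='outputs', prefix='roundtrip-', img_format='jpeg')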
