Utility functions for stable diffusion
```python
import os
import random
from pathlib import Path

import PIL
import numpy as np
import torch
from PIL import Image

from labml import monit
from labml.logger import inspect
from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
from labml_nn.diffusion.stable_diffusion.model.autoencoder import Encoder, Decoder, Autoencoder
from labml_nn.diffusion.stable_diffusion.model.clip_embedder import CLIPTextEmbedder
from labml_nn.diffusion.stable_diffusion.model.unet import UNetModel
```
```python
def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
```
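A minimal usage sketch: seeding all random number generators before sampling makes runs repeatable.

```python
# Hypothetical usage: fix all random sources so that repeated runs
# with the same prompt and settings produce identical images.
set_seed(42)
```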
Load the `LatentDiffusion` model. `path` is the path to the model checkpoint.

```python
def load_model(path: Path = None) -> LatentDiffusion:
```
Initialize the autoencoder
```python
    with monit.section('Initialize autoencoder'):
        encoder = Encoder(z_channels=4,
                          in_channels=3,
                          channels=128,
                          channel_multipliers=[1, 2, 4, 4],
                          n_resnet_blocks=2)

        decoder = Decoder(out_channels=3,
                          z_channels=4,
                          channels=128,
                          channel_multipliers=[1, 2, 4, 4],
                          n_resnet_blocks=2)

        autoencoder = Autoencoder(emb_channels=4,
                                  encoder=encoder,
                                  decoder=decoder,
                                  z_channels=4)
```
Initialize the CLIP text embedder
```python
    with monit.section('Initialize CLIP Embedder'):
        clip_text_embedder = CLIPTextEmbedder()
```
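The embedder turns a batch of prompts into the conditioning tensor for the U-Net. A minimal sketch of how it is used (the prompt is an example; the `[1, 77, 768]` output shape assumes the default CLIP ViT-L/14 text encoder):

```python
# Sketch: embed a prompt with the CLIP text encoder. The default encoder
# pads/truncates prompts to 77 tokens and produces 768-dimensional
# embeddings, which is why the U-Net below is built with d_cond=768.
cond = clip_text_embedder(["a photograph of an astronaut riding a horse"])
print(cond.shape)  # torch.Size([1, 77, 768]) with the default encoder
```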
Initialize the U-Net
```python
    with monit.section('Initialize U-Net'):
        unet_model = UNetModel(in_channels=4,
                               out_channels=4,
                               channels=320,
                               attention_levels=[0, 1, 2],
                               n_res_blocks=2,
                               channel_multipliers=[1, 2, 4, 4],
                               n_heads=8,
                               tf_layers=1,
                               d_cond=768)
```
Initialize the Latent Diffusion model
```python
    with monit.section('Initialize Latent Diffusion model'):
        model = LatentDiffusion(linear_start=0.00085,
                                linear_end=0.0120,
                                n_steps=1000,
                                latent_scaling_factor=0.18215,

                                autoencoder=autoencoder,
                                clip_embedder=clip_text_embedder,
                                unet_model=unet_model)
```
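`linear_start` and `linear_end` set the endpoints of the variance schedule over the 1000 diffusion steps. A sketch of the betas these values imply, assuming the schedule Stable Diffusion uses (linear in the square root of beta):

```python
# Sketch, assuming the sqrt-linear schedule used by Stable Diffusion:
# interpolate linearly in sqrt-space between linear_start and linear_end.
beta = torch.linspace(0.00085 ** 0.5, 0.0120 ** 0.5, 1000) ** 2
alpha_bar = torch.cumprod(1. - beta, dim=0)  # cumulative alphas used when sampling
```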
Load the checkpoint
```python
    with monit.section(f"Loading model from {path}"):
        checkpoint = torch.load(path, map_location="cpu")
```
Set the model state from the checkpoint. `strict=False` tolerates mismatched keys, which are reported below.
```python
    with monit.section('Load state'):
        missing_keys, extra_keys = model.load_state_dict(checkpoint["state_dict"], strict=False)
```
Debugging output
```python
    inspect(global_step=checkpoint.get('global_step', -1), missing_keys=missing_keys, extra_keys=extra_keys,
            _expand=True)
```
```python
    model.eval()
    return model
```
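A usage sketch (the checkpoint file name is an assumption; pass the path of whatever checkpoint you downloaded):

```python
# Hypothetical usage: load a Stable Diffusion checkpoint and move the
# model to a CUDA device for sampling. The file name is an assumption.
model = load_model(Path('sd-v1-4.ckpt'))
model = model.to(torch.device('cuda'))
```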
This loads an image from a file and returns a PyTorch tensor.
`path` is the path of the image

```python
def load_img(path: str):
```
Open Image
```python
    image = Image.open(path).convert("RGB")
```
Get image size
```python
    w, h = image.size
```
Resize to a multiple of 32, so the dimensions stay divisible as the autoencoder and U-Net downsample the image
```python
    w = w - w % 32
    h = h - h % 32
    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
```
Convert to numpy and map from [0, 255] to [-1, 1]
```python
    image = np.array(image).astype(np.float32) * (2. / 255.0) - 1
```
Transpose to shape [batch_size, channels, height, width]
```python
    image = image[None].transpose(0, 3, 1, 2)
```
Convert to torch
```python
    return torch.from_numpy(image)
```
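A quick sketch of the result (the file name is hypothetical):

```python
# Hypothetical usage: load an image and inspect the returned tensor.
img = load_img('input.jpg')  # file name is an assumption
print(img.shape)             # e.g. torch.Size([1, 3, 512, 512])
print(img.min().item(), img.max().item())  # values lie in [-1, 1]
```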
`images` is the tensor with images of shape `[batch_size, channels, height, width]`. `dest_path` is the folder to save images in. `prefix` is the prefix to add to file names. `img_format` is the image format.

```python
def save_images(images: torch.Tensor, dest_path: str, prefix: str = '', img_format: str = 'jpeg'):
```
Create the destination folder
```python
    os.makedirs(dest_path, exist_ok=True)
```
Map images to [0, 1] space and clip
```python
    images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
```
Transpose to [batch_size, height, width, channels] and convert to numpy
```python
    images = images.cpu().permute(0, 2, 3, 1).numpy()
```
Save images
```python
    for i, img in enumerate(images):
        img = Image.fromarray((255. * img).astype(np.uint8))
        img.save(os.path.join(dest_path, f"{prefix}{i:05}.{img_format}"), format=img_format)
```
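Together with `load_img`, this gives a simple round trip; a sketch with hypothetical paths:

```python
# Hypothetical usage: load an image and save it back through the same
# [-1, 1] convention; writes outputs/demo_00000.jpeg.
img = load_img('input.jpg')
save_images(img, dest_path='outputs', prefix='demo_')
```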