Latent Diffusion Models
Latent diffusion models use an auto-encoder to map between image space and latent space. The diffusion model works on the latent space, which makes it a lot cheaper to train. It is based on the paper High-Resolution Image Synthesis with Latent Diffusion Models.
They use a pre-trained auto-encoder and train the diffusion U-Net on its latent space.
For a simpler diffusion implementation, refer to our DDPM implementation. We use the same notation for the $\alpha_t$, $\beta_t$ schedules, etc.
```python
from typing import List

import torch
import torch.nn as nn

from labml_nn.diffusion.stable_diffusion.model.autoencoder import Autoencoder
from labml_nn.diffusion.stable_diffusion.model.clip_embedder import CLIPTextEmbedder
from labml_nn.diffusion.stable_diffusion.model.unet import UNetModel
```
This is an empty wrapper class around the U-Net. We keep this to have the same model structure as CompVis/stable-diffusion so that we do not have to map the checkpoint weights explicitly.
```python
class DiffusionWrapper(nn.Module):
    def __init__(self, diffusion_model: UNetModel):
        super().__init__()
        self.diffusion_model = diffusion_model

    def forward(self, x: torch.Tensor, time_steps: torch.Tensor, context: torch.Tensor):
        return self.diffusion_model(x, time_steps, context)
```
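Because of this wrapper, the state dict keys of a LatentDiffusion instance line up with the original checkpoints. As a rough sketch of what this buys us (the checkpoint filename is hypothetical, and we assume a CompVis-style checkpoint that nests the weights under a state_dict key):

```python
# Hedged sketch of loading CompVis/stable-diffusion weights without key mapping.
# 'sd-v1-4.ckpt' is a hypothetical path, and `ldm` is assumed to be an already
# constructed LatentDiffusion instance (neither appears in this file).
checkpoint = torch.load('sd-v1-4.ckpt', map_location='cpu')
# U-Net weights land under keys like 'model.diffusion_model.*', which is
# exactly where this class keeps them thanks to the wrapper above.
ldm.load_state_dict(checkpoint['state_dict'], strict=False)
```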
This is the latent diffusion model. It contains the following components:
```python
class LatentDiffusion(nn.Module):
    model: DiffusionWrapper
    first_stage_model: Autoencoder
    cond_stage_model: CLIPTextEmbedder
```
- unet_model is the U-Net that predicts noise $\epsilon_\text{cond}(x_t, c)$ in latent space
- autoencoder is the auto-encoder
- clip_embedder is the CLIP embeddings generator
- latent_scaling_factor is the scaling factor for the latent space. The encodings of the autoencoder are scaled by this before feeding into the U-Net.
- n_steps is the number of diffusion steps $T$
- linear_start is the start of the $\beta$ schedule
- linear_end is the end of the $\beta$ schedule

```python
def __init__(self,
             unet_model: UNetModel,
             autoencoder: Autoencoder,
             clip_embedder: CLIPTextEmbedder,
             latent_scaling_factor: float,
             n_steps: int,
             linear_start: float,
             linear_end: float,
             ):
```
```python
super().__init__()
```
Wrap the U-Net to keep the same model structure as CompVis/stable-diffusion.
```python
self.model = DiffusionWrapper(unet_model)
```
Auto-encoder and scaling factor
```python
self.first_stage_model = autoencoder
self.latent_scaling_factor = latent_scaling_factor
```
CLIP embeddings generator

```python
self.cond_stage_model = clip_embedder
```
Number of steps $T$
```python
self.n_steps = n_steps
```
$\beta$ schedule
```python
beta = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_steps, dtype=torch.float64) ** 2
self.beta = nn.Parameter(beta.to(torch.float32), requires_grad=False)
```
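Note that the schedule is linear in $\sqrt{\beta}$ space rather than in $\beta$ itself. A minimal standalone check, assuming the linear_start/linear_end values commonly used in Stable Diffusion configs (an assumption, not taken from this file):

```python
import torch

# Values commonly used in Stable Diffusion configs (an assumption)
linear_start, linear_end, n_steps = 0.00085, 0.0120, 1000

# Linear in sqrt(beta) space, then squared back
beta = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_steps, dtype=torch.float64) ** 2

# The endpoints recover linear_start and linear_end up to float rounding
assert abs(beta[0].item() - linear_start) < 1e-12
assert abs(beta[-1].item() - linear_end) < 1e-12
```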
$\alpha_t = 1 - \beta_t$
```python
alpha = 1. - beta
```
$\bar\alpha_t = \prod_{s=1}^t \alpha_s$
```python
alpha_bar = torch.cumprod(alpha, dim=0)
self.alpha_bar = nn.Parameter(alpha_bar.to(torch.float32), requires_grad=False)
```
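With these buffers, the DDPM forward process $q(x_t \vert x_0) = \mathcal{N}\big(\sqrt{\bar\alpha_t} x_0, (1 - \bar\alpha_t)\mathbf{I}\big)$ can be sampled in closed form. A minimal sketch, assuming ldm is a constructed LatentDiffusion and x0 is a batch of clean latents on the same device:

```python
# Sample x_t ~ q(x_t | x_0) using the alpha_bar buffer.
# `ldm` and `x0` are assumptions: a LatentDiffusion instance and clean latents.
t = torch.randint(0, ldm.n_steps, (x0.shape[0],), device=x0.device)
eps = torch.randn_like(x0)
alpha_bar_t = ldm.alpha_bar[t].view(-1, 1, 1, 1)
xt = alpha_bar_t.sqrt() * x0 + (1. - alpha_bar_t).sqrt() * eps
```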
Get the model device.

```python
@property
def device(self):
    return next(iter(self.model.parameters())).device
```
Get CLIP embeddings (the conditioning context) for a list of text prompts.

```python
def get_text_conditioning(self, prompts: List[str]):
    return self.cond_stage_model(prompts)
```
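For example (a hedged usage sketch; the [1, 77, 768] shape assumes the CLIP ViT-L/14 text encoder used by Stable Diffusion v1, which this file does not itself pin down):

```python
# Hedged usage sketch; `ldm` is an assumed LatentDiffusion instance.
context = ldm.get_text_conditioning(['a photograph of an astronaut riding a horse'])
# With SD v1's CLIP ViT-L/14 text encoder this is [1, 77, 768] (an assumption)
print(context.shape)
```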
The encoder output is a distribution. We sample from that and multiply by the scaling factor.
```python
def autoencoder_encode(self, image: torch.Tensor):
    return self.latent_scaling_factor * self.first_stage_model.encode(image).sample()
```
We scale down by the scaling factor and then decode.
```python
def autoencoder_decode(self, z: torch.Tensor):
    return self.first_stage_model.decode(z / self.latent_scaling_factor)
```
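Together, the two methods give a round trip through latent space. A minimal sketch, assuming ldm is a constructed LatentDiffusion and that the autoencoder follows Stable Diffusion's usual 8x spatial downsampling into 4 latent channels (an assumption about the autoencoder config):

```python
# Hedged round-trip sketch; `ldm` is an assumed LatentDiffusion instance.
image = torch.randn(1, 3, 512, 512)         # dummy image batch in [B, 3, H, W]
z = ldm.autoencoder_encode(image)           # [1, 4, 64, 64] for SD's usual VAE config (assumed)
reconstruction = ldm.autoencoder_decode(z)  # back to [1, 3, 512, 512]
```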
Predict noise given the latent representation $x_t$, time step $t$, and the conditioning context $c$:

$\epsilon_\text{cond}(x_t, c)$
```python
def forward(self, x: torch.Tensor, t: torch.Tensor, context: torch.Tensor):
    return self.model(x, t, context)
```
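Putting the pieces together, one training step of the noise-prediction objective $\big\Vert \epsilon - \epsilon_\text{cond}(x_t, c) \big\Vert^2$ might look like the following. This is a hedged sketch, not code from this file: ldm, images, and prompts are assumed to come from a constructed model and a data loader, and the noising step is the one sketched after the $\bar\alpha_t$ buffer above.

```python
import torch.nn.functional as F

# Hedged sketch of one training step; `ldm`, `images`, `prompts` are assumptions.
with torch.no_grad():
    x0 = ldm.autoencoder_encode(images)           # clean latents
    context = ldm.get_text_conditioning(prompts)  # CLIP text embeddings

# Sample a random time step and noise the latents (DDPM forward process)
t = torch.randint(0, ldm.n_steps, (x0.shape[0],), device=ldm.device)
eps = torch.randn_like(x0)
alpha_bar_t = ldm.alpha_bar[t].view(-1, 1, 1, 1)
xt = alpha_bar_t.sqrt() * x0 + (1. - alpha_bar_t).sqrt() * eps

# Predict the noise and take a simple MSE loss against the true noise
loss = F.mse_loss(ldm(xt, t, context), eps)
loss.backward()
```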