Latent Diffusion Models

Latent diffusion models use an auto-encoder to map between image space and latent space. The diffusion model works on the latent space, which makes it much easier to train. It is based on the paper High-Resolution Image Synthesis with Latent Diffusion Models.

They use a pre-trained auto-encoder and train the diffusion U-Net on the latent space of the pre-trained auto-encoder.

For a simpler diffusion implementation refer to our DDPM implementation. We use the same notation for the $\alpha_t$, $\beta_t$ schedules, etc.
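As a quick reminder of that notation, the DDPM forward process noises a sample $x_0$ according to

$$q(x_t \mid x_0) = \mathcal{N}\big(x_t;\ \sqrt{\bar{\alpha}_t}\, x_0,\ (1 - \bar{\alpha}_t) \mathbf{I}\big), \qquad \alpha_t = 1 - \beta_t, \qquad \bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s$$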

```python
from typing import List

import torch
import torch.nn as nn

from labml_nn.diffusion.stable_diffusion.model.autoencoder import Autoencoder
from labml_nn.diffusion.stable_diffusion.model.clip_embedder import CLIPTextEmbedder
from labml_nn.diffusion.stable_diffusion.model.unet import UNetModel
```


This is an empty wrapper class around the U-Net. We keep this to have the same model structure as CompVis/stable-diffusion so that we do not have to map the checkpoint weights explicitly.

```python
class DiffusionWrapper(nn.Module):
    def __init__(self, diffusion_model: UNetModel):
        super().__init__()
        self.diffusion_model = diffusion_model

    def forward(self, x: torch.Tensor, time_steps: torch.Tensor, context: torch.Tensor):
        return self.diffusion_model(x, time_steps, context)
```

Latent diffusion model

This contains the following components:

  • the U-Net that predicts noise (wrapped in DiffusionWrapper)
  • the autoencoder that maps between image space and latent space
  • the CLIP embeddings generator for text conditioning

```python
class LatentDiffusion(nn.Module):
    model: DiffusionWrapper
    first_stage_model: Autoencoder
    cond_stage_model: CLIPTextEmbedder
```

  • unet_model is the U-Net that predicts noise $\epsilon_\text{cond}(x_t, c)$ in latent space
  • autoencoder is the AutoEncoder
  • clip_embedder is the CLIP embeddings generator
  • latent_scaling_factor is the scaling factor for the latent space. The encodings of the autoencoder are scaled by this before feeding into the U-Net.
  • n_steps is the number of diffusion steps $T$
  • linear_start is the start of the $\beta$ schedule
  • linear_end is the end of the $\beta$ schedule (typical values for these appear in the construction sketch below)
```python
    def __init__(self,
                 unet_model: UNetModel,
                 autoencoder: Autoencoder,
                 clip_embedder: CLIPTextEmbedder,
                 latent_scaling_factor: float,
                 n_steps: int,
                 linear_start: float,
                 linear_end: float,
                 ):
```
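Before walking through the body, here is a hedged construction sketch. The numeric values are assumptions taken from the CompVis Stable Diffusion v1 configuration, not from this file:

```python
# Hypothetical construction; unet_model, autoencoder, and clip_embedder are
# assumed to be built elsewhere. Numeric values match the SD v1 config.
ldm = LatentDiffusion(
    unet_model=unet_model,
    autoencoder=autoencoder,
    clip_embedder=clip_embedder,
    latent_scaling_factor=0.18215,  # latent scale factor used by SD v1
    n_steps=1000,                   # number of diffusion steps T
    linear_start=0.00085,           # start of the beta schedule
    linear_end=0.0120,              # end of the beta schedule
)
```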

```python
        super().__init__()
```

Wrap the U-Net to keep the same model structure as CompVis/stable-diffusion.

```python
        self.model = DiffusionWrapper(unet_model)
```

Auto-encoder and scaling factor

```python
        self.first_stage_model = autoencoder
        self.latent_scaling_factor = latent_scaling_factor
```

CLIP embeddings generator

```python
        self.cond_stage_model = clip_embedder
```

Number of steps $T$

```python
        self.n_steps = n_steps
```

$\beta$ schedule

```python
        beta = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_steps, dtype=torch.float64) ** 2
        self.beta = nn.Parameter(beta.to(torch.float32), requires_grad=False)
```
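That is, the schedule is linear in $\sqrt{\beta}$ space; for $t = 1, \dots, T$,

$$\beta_t = \left( \sqrt{\beta_1} + \frac{t - 1}{T - 1} \left( \sqrt{\beta_T} - \sqrt{\beta_1} \right) \right)^2$$

where $\beta_1$ is linear_start and $\beta_T$ is linear_end.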

$\alpha_t = 1 - \beta_t$

```python
        alpha = 1. - beta
```

$\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s$

```python
        alpha_bar = torch.cumprod(alpha, dim=0)
        self.alpha_bar = nn.Parameter(alpha_bar.to(torch.float32), requires_grad=False)
```
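These buffers are what a sampler or training loop uses to noise latents. As a minimal sketch (not part of this class; the helper q_sample is hypothetical), the forward process $q(x_t \mid x_0)$ can be sampled as:

```python
def q_sample(x0: torch.Tensor, t: torch.Tensor, alpha_bar: torch.Tensor):
    # Sample x_t ~ q(x_t | x_0) = N(sqrt(alpha_bar_t) x_0, (1 - alpha_bar_t) I)
    noise = torch.randn_like(x0)
    # Gather alpha_bar_t per batch element and broadcast over (C, H, W)
    ab = alpha_bar[t].view(-1, 1, 1, 1)
    return ab.sqrt() * x0 + (1. - ab).sqrt() * noise, noise
```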

Get model device

```python
    @property
    def device(self):
        return next(iter(self.model.parameters())).device
```

Get CLIP embeddings for a list of text prompts

```python
    def get_text_conditioning(self, prompts: List[str]):
        return self.cond_stage_model(prompts)
```
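For example, a hedged usage sketch (with the CLIP text embedder used by Stable Diffusion v1 the result would be a tensor of shape (batch_size, 77, 768), though the exact shape depends on the embedder):

```python
# Hypothetical usage; ldm is the LatentDiffusion instance from the sketch above
context = ldm.get_text_conditioning(["a photograph of an astronaut riding a horse"])
```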

Get scaled latent space representation of the image

The encoder output is a distribution. We sample from that and multiply by the scaling factor.

```python
    def autoencoder_encode(self, image: torch.Tensor):
        return self.latent_scaling_factor * self.first_stage_model.encode(image).sample()
```

Get image from the latent representation

We scale down by the scaling factor and then decode.

```python
    def autoencoder_decode(self, z: torch.Tensor):
        return self.first_stage_model.decode(z / self.latent_scaling_factor)
```
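A hedged round-trip sketch; the shapes assume the Stable Diffusion v1 autoencoder, which has 4 latent channels and an 8× spatial downsampling factor:

```python
# Hypothetical usage; images is a (B, 3, H, W) batch in image space
z = ldm.autoencoder_encode(images)  # scaled latents of shape (B, 4, H/8, W/8)
recon = ldm.autoencoder_decode(z)   # back to (B, 3, H, W); approximate, since encode samples
```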

Predict noise

Predict noise given the latent representation $x_t$, time step $t$, and the conditioning context $c$:

$$\epsilon_\text{cond}(x_t, c)$$

```python
    def forward(self, x: torch.Tensor, t: torch.Tensor, context: torch.Tensor):
        return self.model(x, t, context)
```
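Putting the pieces together, here is a hedged sketch of a single training step. It reuses the hypothetical ldm instance and q_sample helper from the sketches above; this illustrates the standard noise-prediction objective, not training code from this repository:

```python
import torch.nn.functional as F

def training_step(ldm: LatentDiffusion, images: torch.Tensor, prompts: List[str]):
    # Map images to scaled latents and embed the text prompts
    x0 = ldm.autoencoder_encode(images)
    context = ldm.get_text_conditioning(prompts)
    # Pick a random time step for each sample and noise the latents
    t = torch.randint(0, ldm.n_steps, (x0.shape[0],), device=ldm.device)
    xt, noise = q_sample(x0, t, ldm.alpha_bar)
    # Train the U-Net to predict the noise that was added
    return F.mse_loss(ldm(xt, t, context), noise)
```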
