
Sampling algorithms for stable diffusion


[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/diffusion/stable_diffusion/sampler/__init__.py)


We have implemented the following sampling algorithms:

  • DDPM (Denoising Diffusion Probabilistic Models)
  • DDIM (Denoising Diffusion Implicit Models)

```python
from typing import Optional, List

import torch

from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
```

Base class for sampling algorithms

```python
class DiffusionSampler:

    model: LatentDiffusion
```

  • model is the model to predict noise $\epsilon_\text{cond}(x_t, c)$

```python
    def __init__(self, model: LatentDiffusion):
        super().__init__()

        # Set the model $\epsilon_\text{cond}(x_t, c)$
        self.model = model
        # Get the number of steps the model was trained with, $T$
        self.n_steps = model.n_steps
```

Get $\epsilon(x_t, c)$

  • x is $x_t$ of shape [batch_size, channels, height, width]
  • t is $t$ of shape [batch_size]
  • c is the conditional embeddings $c$ of shape [batch_size, emb_size]
  • uncond_scale is the unconditional guidance scale $s$. This is used for $\epsilon_\theta(x_t, c) = \epsilon_\text{cond}(x_t, c_u) + s \big( \epsilon_\text{cond}(x_t, c) - \epsilon_\text{cond}(x_t, c_u) \big)$
  • uncond_cond is the conditional embedding for empty prompt $c_u$

```python
    def get_eps(self, x: torch.Tensor, t: torch.Tensor, c: torch.Tensor, *,
                uncond_scale: float, uncond_cond: Optional[torch.Tensor]):
        # When the scale $s = 1$, $\epsilon_\theta(x_t, c) = \epsilon_\text{cond}(x_t, c)$
        if uncond_cond is None or uncond_scale == 1.:
            return self.model(x, t, c)

        # Duplicate $x_t$ and $t$
        x_in = torch.cat([x] * 2)
        t_in = torch.cat([t] * 2)
        # Concatenate $c_u$ and $c$
        c_in = torch.cat([uncond_cond, c])
        # Get $\epsilon_\text{cond}(x_t, c_u)$ and $\epsilon_\text{cond}(x_t, c)$
        e_t_uncond, e_t_cond = self.model(x_in, t_in, c_in).chunk(2)
        # Calculate
        # $\epsilon_\theta(x_t, c) = \epsilon_\text{cond}(x_t, c_u) + s \big( \epsilon_\text{cond}(x_t, c) - \epsilon_\text{cond}(x_t, c_u) \big)$
        e_t = e_t_uncond + uncond_scale * (e_t_cond - e_t_uncond)

        return e_t
```
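The guidance combination above can be sketched with plain Python floats in place of tensors; the helper name and the toy noise values below are invented for illustration:

```python
def guided_eps(e_t_cond, e_t_uncond, uncond_scale):
    # Classifier-free guidance: start from the unconditional prediction and
    # move `uncond_scale` of the way towards the conditional one.
    return e_t_uncond + uncond_scale * (e_t_cond - e_t_uncond)

# With scale 1 the result is just the conditional prediction.
assert abs(guided_eps(0.8, 0.2, 1.0) - 0.8) < 1e-9

# With scale > 1 the prediction is pushed past the conditional one,
# away from the unconditional one: 0.2 + 7.5 * (0.8 - 0.2) = 4.7
assert abs(guided_eps(0.8, 0.2, 7.5) - 4.7) < 1e-9
```

A real call works on batched tensors from `self.model`, but the per-element arithmetic is the same combination.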

Sampling Loop

  • shape is the shape of the generated images in the form [batch_size, channels, height, width]
  • cond is the conditional embeddings $c$
  • repeat_noise specifies whether to use the same noise for all samples in the batch
  • temperature is the noise temperature (random noise gets multiplied by this)
  • x_last is $x_T$. If not provided, random noise will be used.
  • uncond_scale is the unconditional guidance scale $s$. This is used for $\epsilon_\theta(x_t, c) = \epsilon_\text{cond}(x_t, c_u) + s \big( \epsilon_\text{cond}(x_t, c) - \epsilon_\text{cond}(x_t, c_u) \big)$
  • uncond_cond is the conditional embedding for empty prompt $c_u$
  • skip_steps is the number of time steps to skip.

```python
    def sample(self,
               shape: List[int],
               cond: torch.Tensor,
               repeat_noise: bool = False,
               temperature: float = 1.,
               x_last: Optional[torch.Tensor] = None,
               uncond_scale: float = 1.,
               uncond_cond: Optional[torch.Tensor] = None,
               skip_steps: int = 0,
               ):
        raise NotImplementedError()
```
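To make the loop's contract concrete, here is a minimal, self-contained sketch of what a subclass's sampling loop typically does: a made-up linear $\beta$ schedule, a trivial stand-in noise predictor, and a DDPM-style update on a single scalar. Every name and schedule value here is an assumption for illustration, not this repository's implementation:

```python
import math
import random

def toy_sample(n_steps=10, x_last=None, temperature=1., skip_steps=0, seed=0):
    """Scalar DDPM-style reverse loop: start from noise, denoise step by step."""
    rng = random.Random(seed)
    # Made-up linear beta schedule; a real sampler reads this from the model.
    betas = [1e-4 + (0.02 - 1e-4) * i / (n_steps - 1) for i in range(n_steps)]
    alphas = [1. - b for b in betas]
    alpha_bars, prod = [], 1.
    for a in alphas:
        prod *= a
        alpha_bars.append(prod)

    # $x_T$: use `x_last` if given, otherwise random noise
    x = x_last if x_last is not None else rng.gauss(0., 1.)

    # Trivial stand-in for the noise prediction $\epsilon_\theta(x_t, c)$
    def eps_model(x_t, t):
        return x_t

    # Skip the first `skip_steps` (highest-$t$) steps
    for t in reversed(range(n_steps - skip_steps)):
        eps = eps_model(x, t)
        # DDPM mean of $p(x_{t-1} | x_t)$
        mean = (x - betas[t] / math.sqrt(1. - alpha_bars[t]) * eps) / math.sqrt(alphas[t])
        # Add noise (scaled by `temperature`) except at the final step
        noise = rng.gauss(0., 1.) * temperature if t > 0 else 0.
        x = mean + math.sqrt(betas[t]) * noise
    return x
```

With this stand-in predictor the loop simply contracts $x$ towards zero; a real subclass would call `self.get_eps` on batched tensors and honor `cond`, `repeat_noise`, `uncond_scale`, and `uncond_cond`.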

Painting Loop

  • x is $x_{T'}$ of shape [batch_size, channels, height, width]
  • cond is the conditional embeddings $c$
  • t_start is the sampling step to start from, $T'$
  • orig is the original image in latent space, which we are in-painting
  • mask is the mask to keep the original image
  • orig_noise is fixed noise to be added to the original image
  • uncond_scale is the unconditional guidance scale $s$. This is used for $\epsilon_\theta(x_t, c) = \epsilon_\text{cond}(x_t, c_u) + s \big( \epsilon_\text{cond}(x_t, c) - \epsilon_\text{cond}(x_t, c_u) \big)$
  • uncond_cond is the conditional embedding for empty prompt $c_u$

```python
    def paint(self, x: torch.Tensor, cond: torch.Tensor, t_start: int, *,
              orig: Optional[torch.Tensor] = None,
              mask: Optional[torch.Tensor] = None, orig_noise: Optional[torch.Tensor] = None,
              uncond_scale: float = 1.,
              uncond_cond: Optional[torch.Tensor] = None,
              ):
        raise NotImplementedError()
```
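The core of an in-painting loop, as the `orig` and `mask` arguments suggest, is to overwrite the in-progress latents with the (noised) original wherever the mask keeps it. A plain-Python sketch of that blending step, with an invented helper name and toy values:

```python
def keep_masked_original(x, orig, mask):
    # x = orig * mask + x * (1 - mask): keep `orig` where mask is 1,
    # keep the sampler's latents `x` where mask is 0.
    return [o * m + xi * (1. - m) for xi, o, m in zip(x, orig, mask)]

x = [0.5, -0.3, 0.9]    # latents produced by the sampler at some step
orig = [1.0, 1.0, 1.0]  # original-image latents (noised to the same step in a real loop)
mask = [1., 0., 1.]     # 1 = keep original, 0 = region being in-painted

assert keep_masked_original(x, orig, mask) == [1.0, -0.3, 1.0]
```

In a real loop the original would first be noised to the current step (this is where `orig_noise` gives reproducible noise), so the kept region matches the noise level of the region being generated.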

Sample from $q(x_t \vert x_0)$

  • x0 is $x_0$ of shape [batch_size, channels, height, width]
  • index is the time step $t$ index
  • noise is the noise, $\epsilon$

```python
    def q_sample(self, x0: torch.Tensor, index: int, noise: Optional[torch.Tensor] = None):
        raise NotImplementedError()
```
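A subclass typically implements this as the standard diffusion forward process, $x_t = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\, \epsilon$. A scalar sketch (the function name and the $\bar\alpha$ values below are made up for the example):

```python
import math

def q_sample_scalar(x0, alpha_bar, noise):
    # x_t = sqrt(alpha_bar) * x0 + sqrt(1 - alpha_bar) * noise
    return math.sqrt(alpha_bar) * x0 + math.sqrt(1. - alpha_bar) * noise

# Early in the schedule (alpha_bar near 1), x_t stays close to x0 ...
early = q_sample_scalar(1.0, 0.99, 0.5)
# ... late in the schedule (alpha_bar near 0), x_t is mostly noise.
late = q_sample_scalar(1.0, 0.01, 0.5)
assert abs(early - 1.0) < abs(late - 0.5)
```

The real method looks up $\bar\alpha_t$ from the schedule by `index` and draws `noise` if it is not given.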
