# Sampling algorithms for stable diffusion
[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/diffusion/stable_diffusion/sampler/__init__.py)
We have implemented the following sampling algorithms:

* [DDPM Sampling](ddpm.html)
* [DDIM Sampling](ddim.html)
```python
from typing import Optional, List

import torch

from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
```
```python
class DiffusionSampler:
    model: LatentDiffusion
```
* `model` is the model to predict noise $\epsilon_\text{cond}(x_t, c)$

```python
    def __init__(self, model: LatentDiffusion):
        super().__init__()
        # Set the model $\epsilon_\text{cond}(x_t, c)$
        self.model = model
        # Get the number of steps the model was trained with, $T$
        self.n_steps = model.n_steps
```
* `x` is $x_t$ of shape `[batch_size, channels, height, width]`
* `t` is $t$ of shape `[batch_size]`
* `c` is the conditional embeddings $c$ of shape `[batch_size, emb_size]`
* `uncond_scale` is the unconditional guidance scale $s$. This is used for $\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) - (s - 1)\epsilon_\text{cond}(x_t, c_u)$
* `uncond_cond` is the conditional embedding for the empty prompt $c_u$

```python
    def get_eps(self, x: torch.Tensor, t: torch.Tensor, c: torch.Tensor, *,
                uncond_scale: float, uncond_cond: Optional[torch.Tensor]):
```
```python
        # When the scale $s = 1$, $\epsilon_\theta(x_t, c) = \epsilon_\text{cond}(x_t, c)$
        if uncond_cond is None or uncond_scale == 1.:
            return self.model(x, t, c)

        # Duplicate $x_t$ and $t$
        x_in = torch.cat([x] * 2)
        t_in = torch.cat([t] * 2)
        # Concatenate $c_u$ and $c$
        c_in = torch.cat([uncond_cond, c])
        # Get $\epsilon_\text{cond}(x_t, c_u)$ and $\epsilon_\text{cond}(x_t, c)$
        e_t_uncond, e_t_cond = self.model(x_in, t_in, c_in).chunk(2)
        # Calculate
        # $\epsilon_\theta(x_t, c) = \epsilon_\text{cond}(x_t, c_u) + s \big(\epsilon_\text{cond}(x_t, c) - \epsilon_\text{cond}(x_t, c_u)\big)$
        e_t = e_t_uncond + uncond_scale * (e_t_cond - e_t_uncond)

        return e_t
```
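The guidance combination above can be checked in isolation. This is a minimal sketch, assuming only PyTorch; `guided_eps` is a hypothetical helper, not part of the library:

```python
import torch


def guided_eps(eps_cond: torch.Tensor, eps_uncond: torch.Tensor, s: float) -> torch.Tensor:
    # eps_uncond + s * (eps_cond - eps_uncond) == s * eps_cond - (s - 1) * eps_uncond,
    # the same formula get_eps computes after chunking the batched outputs.
    return eps_uncond + s * (eps_cond - eps_uncond)


eps_c = torch.tensor([1.0, 2.0])
eps_u = torch.tensor([0.5, 0.5])
# With s = 1 the guided prediction collapses to the conditional one
print(guided_eps(eps_c, eps_u, 1.0))  # tensor([1., 2.])
# Larger s pushes the prediction further in the conditional direction
print(guided_eps(eps_c, eps_u, 7.5))
```

Note that `s > 1` extrapolates past the conditional prediction, which is why large guidance scales sharpen prompt adherence at the cost of diversity.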
* `shape` is the shape of the generated images in the form `[batch_size, channels, height, width]`
* `cond` is the conditional embeddings $c$
* `repeat_noise` specifies whether to use the same noise for all samples in the batch
* `temperature` is the noise temperature (random noise gets multiplied by this)
* `x_last` is $x_T$. If not provided, random noise will be used.
* `uncond_scale` is the unconditional guidance scale $s$. This is used for $\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) - (s - 1)\epsilon_\text{cond}(x_t, c_u)$
* `uncond_cond` is the conditional embedding for the empty prompt $c_u$
* `skip_steps` is the number of time steps to skip

```python
    def sample(self,
               shape: List[int],
               cond: torch.Tensor,
               repeat_noise: bool = False,
               temperature: float = 1.,
               x_last: Optional[torch.Tensor] = None,
               uncond_scale: float = 1.,
               uncond_cond: Optional[torch.Tensor] = None,
               skip_steps: int = 0,
               ):
        raise NotImplementedError()
```
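Concrete subclasses (DDPM, DDIM) fill in `sample()` with their own update rule. The overall loop structure they share can be sketched with a toy stand-in; everything below (`ToySampler`, the zero noise predictor, the simplistic update) is illustrative only, not the library's implementation:

```python
import torch


class ToySampler:
    """Illustrative stand-in showing the reverse-loop shape of `sample()`."""

    def __init__(self, n_steps: int):
        self.n_steps = n_steps

    def get_eps(self, x: torch.Tensor, t: int) -> torch.Tensor:
        # Stand-in noise predictor; a real sampler queries the latent
        # diffusion model here, with classifier-free guidance applied.
        return torch.zeros_like(x)

    def sample(self, shape, temperature: float = 1.,
               x_last: torch.Tensor = None, skip_steps: int = 0) -> torch.Tensor:
        # Start from x_T: supplied latents, or random noise scaled by temperature
        x = x_last if x_last is not None else temperature * torch.randn(shape)
        # Walk the time steps in reverse, optionally skipping the first few
        for t in reversed(range(self.n_steps - skip_steps)):
            eps = self.get_eps(x, t)
            # A real DDPM/DDIM sampler applies its own update rule here;
            # this toy update merely nudges x against the predicted noise.
            x = x - eps / self.n_steps
        return x


x0 = ToySampler(n_steps=10).sample([2, 3, 8, 8])
print(x0.shape)  # torch.Size([2, 3, 8, 8])
```

The output keeps the requested `[batch_size, channels, height, width]` shape throughout the loop, since each step maps latents to latents of the same shape.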
* `x` is $x_{T'}$ of shape `[batch_size, channels, height, width]`
* `cond` is the conditional embeddings $c$
* `t_start` is the sampling step to start from, $T'$
* `orig` is the original image in latent space that we are inpainting
* `mask` is the mask to keep the original image
* `orig_noise` is fixed noise to be added to the original image
* `uncond_scale` is the unconditional guidance scale $s$. This is used for $\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) - (s - 1)\epsilon_\text{cond}(x_t, c_u)$
* `uncond_cond` is the conditional embedding for the empty prompt $c_u$

```python
    def paint(self, x: torch.Tensor, cond: torch.Tensor, t_start: int, *,
              orig: Optional[torch.Tensor] = None,
              mask: Optional[torch.Tensor] = None, orig_noise: Optional[torch.Tensor] = None,
              uncond_scale: float = 1.,
              uncond_cond: Optional[torch.Tensor] = None,
              ):
        raise NotImplementedError()
```
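The core of inpainting is that, at each step, latents under the mask are pinned to the (noised) original image while the rest is freely sampled. A minimal sketch of that blend, assuming a binary mask where `1` marks regions to keep (`blend_with_original` is a hypothetical helper name):

```python
import torch


def blend_with_original(x: torch.Tensor, orig_t: torch.Tensor,
                        mask: torch.Tensor) -> torch.Tensor:
    # Keep the original image (noised to the current step t) where mask == 1,
    # and the freshly sampled latents where mask == 0.
    return orig_t * mask + x * (1. - mask)


x = torch.zeros(1, 1, 2, 2)          # freshly sampled latents
orig_t = torch.ones(1, 1, 2, 2)      # original latents, noised to step t
mask = torch.tensor([[[[1., 0.],
                       [0., 1.]]]])  # keep the diagonal of the original
print(blend_with_original(x, orig_t, mask))
```

Applying this blend after every denoising step is what keeps the unmasked region consistent with the original image as sampling proceeds.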
* `x0` is $x_0$ of shape `[batch_size, channels, height, width]`
* `index` is the time step $t$ index
* `noise` is the noise, $\epsilon$

```python
    def q_sample(self, x0: torch.Tensor, index: int, noise: Optional[torch.Tensor] = None):
        raise NotImplementedError()
```
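Subclasses implement `q_sample` as the closed-form forward diffusion $q(x_t \mid x_0) = \mathcal{N}\big(x_t; \sqrt{\bar\alpha_t}\, x_0, (1 - \bar\alpha_t)\mathbf{I}\big)$, i.e. $x_t = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\, \epsilon$. A standalone sketch, with the simplifying assumption that $\bar\alpha_t$ is passed directly instead of being looked up by step index from the $\beta$ schedule:

```python
import torch


def q_sample(x0: torch.Tensor, alpha_bar_t: float,
             noise: torch.Tensor = None) -> torch.Tensor:
    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
    if noise is None:
        noise = torch.randn_like(x0)
    return (alpha_bar_t ** 0.5) * x0 + ((1 - alpha_bar_t) ** 0.5) * noise


x0 = torch.tensor([2.0])
eps = torch.tensor([1.0])
xt = q_sample(x0, alpha_bar_t=0.25, noise=eps)
print(xt)  # 0.5 * 2 + sqrt(0.75) * 1 ≈ 1.866
```

As $\bar\alpha_t \to 0$ (large $t$), the signal term vanishes and $x_t$ becomes pure noise, which is why sampling can start from $x_T \sim \mathcal{N}(0, \mathbf{I})$.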