docs/diffusion/stable_diffusion/sampler/ddim.html
This implements DDIM sampling from the paper Denoising Diffusion Implicit Models
```python
from typing import Optional, List

import numpy as np
import torch

from labml import monit
from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
from labml_nn.diffusion.stable_diffusion.sampler import DiffusionSampler
```
This extends the DiffusionSampler base class.
DDIM samples images by repeatedly removing noise, step by step, using

$$x_{\tau_{i-1}} = \sqrt{\alpha_{\tau_{i-1}}} \Bigg( \frac{x_{\tau_i} - \sqrt{1 - \alpha_{\tau_i}}\, \epsilon_\theta(x_{\tau_i})}{\sqrt{\alpha_{\tau_i}}} \Bigg) + \sqrt{1 - \alpha_{\tau_{i-1}} - \sigma_{\tau_i}^2} \cdot \epsilon_\theta(x_{\tau_i}) + \sigma_{\tau_i} \epsilon_{\tau_i}$$

where $\epsilon_{\tau_i}$ is random noise, $\tau$ is a subsequence of $[1, 2, \dots, T]$ of length $S$, and

$$\sigma_{\tau_i} = \eta \sqrt{\frac{1 - \alpha_{\tau_{i-1}}}{1 - \alpha_{\tau_i}}} \sqrt{1 - \frac{\alpha_{\tau_i}}{\alpha_{\tau_{i-1}}}}$$

Note that $\alpha_t$ in the DDIM paper refers to $\bar\alpha_t$ from DDPM.
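As a quick sanity check on the update rule (a NumPy sketch with hypothetical toy values, not part of the implementation): when $\eta = 0$, so $\sigma_{\tau_i} = 0$, and the model's noise prediction equals the exact noise $\epsilon$ used by the forward process, one DDIM step lands exactly on the forward-process marginal at $\tau_{i-1}$.

```python
import numpy as np

# Toy cumulative-product values (hypothetical; a real schedule comes from the model)
alpha = 0.5        # alpha_bar at tau_i
alpha_prev = 0.8   # alpha_bar at tau_{i-1} (earlier step, so larger)
eta = 0.0          # deterministic DDIM
sigma = eta * np.sqrt((1 - alpha_prev) / (1 - alpha)) * np.sqrt(1 - alpha / alpha_prev)

rng = np.random.default_rng(0)
x0 = rng.normal(size=4)   # a "clean" sample
eps = rng.normal(size=4)  # the exact noise used in the forward process

# Forward-process marginal: x_{tau_i} = sqrt(alpha) x0 + sqrt(1 - alpha) eps
x_t = np.sqrt(alpha) * x0 + np.sqrt(1 - alpha) * eps

# One DDIM step with the exact eps
pred_x0 = (x_t - np.sqrt(1 - alpha) * eps) / np.sqrt(alpha)
x_prev = np.sqrt(alpha_prev) * pred_x0 + np.sqrt(1 - alpha_prev - sigma ** 2) * eps

# With eta = 0 this is exactly the forward marginal at tau_{i-1}
expected = np.sqrt(alpha_prev) * x0 + np.sqrt(1 - alpha_prev) * eps
print(np.allclose(x_prev, expected))  # True
```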
```python
class DDIMSampler(DiffusionSampler):

    model: LatentDiffusion
```
- `model` is the model to predict noise $\epsilon_\text{cond}(x_t, c)$
- `n_steps` is the number of DDIM sampling steps, $S$
- `ddim_discretize` specifies how to extract $\tau$ from $[1, 2, \dots, T]$. It can be either `uniform` or `quad`.
- `ddim_eta` is $\eta$ used to calculate $\sigma_{\tau_i}$. $\eta = 0$ makes the sampling process deterministic.

```python
    def __init__(self, model: LatentDiffusion, n_steps: int, ddim_discretize: str = "uniform", ddim_eta: float = 0.):
```
```python
        super().__init__(model)
```
Number of steps, T
```python
        self.n_steps = model.n_steps
```
Calculate τ to be uniformly distributed across [1,2,…,T]
```python
        if ddim_discretize == 'uniform':
            c = self.n_steps // n_steps
            self.time_steps = np.asarray(list(range(0, self.n_steps, c))) + 1
```
Calculate τ to be quadratically distributed across [1,2,…,T]
```python
        elif ddim_discretize == 'quad':
            self.time_steps = ((np.linspace(0, np.sqrt(self.n_steps * .8), n_steps)) ** 2).astype(int) + 1
        else:
            raise NotImplementedError(ddim_discretize)

        with torch.no_grad():
```
Get $\bar\alpha_t$

```python
            alpha_bar = self.model.alpha_bar
```

$\alpha_{\tau_i}$

```python
            self.ddim_alpha = alpha_bar[self.time_steps].clone().to(torch.float32)
```

$\sqrt{\alpha_{\tau_i}}$

```python
            self.ddim_alpha_sqrt = torch.sqrt(self.ddim_alpha)
```

$\alpha_{\tau_{i-1}}$

```python
            self.ddim_alpha_prev = torch.cat([alpha_bar[0:1], alpha_bar[self.time_steps[:-1]]])
```

$$\sigma_{\tau_i} = \eta \sqrt{\frac{1 - \alpha_{\tau_{i-1}}}{1 - \alpha_{\tau_i}}} \sqrt{1 - \frac{\alpha_{\tau_i}}{\alpha_{\tau_{i-1}}}}$$

```python
            self.ddim_sigma = (ddim_eta *
                               ((1 - self.ddim_alpha_prev) / (1 - self.ddim_alpha) *
                                (1 - self.ddim_alpha / self.ddim_alpha_prev)) ** .5)
```

$\sqrt{1 - \alpha_{\tau_i}}$

```python
            self.ddim_sqrt_one_minus_alpha = (1. - self.ddim_alpha) ** .5
```
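To make the schedule concrete, here is a small NumPy sketch mirroring the constructor's logic for $T = 1000$ and $S = 10$. The linear $\bar\alpha$ schedule is a toy assumption; the real one comes from the model.

```python
import numpy as np

n_steps_total = 1000  # T, the DDPM training steps (assumed)
n_steps = 10          # S, the DDIM sampling steps

# 'uniform' discretization, as in the constructor
c = n_steps_total // n_steps
time_steps = np.asarray(list(range(0, n_steps_total, c))) + 1
print(time_steps)  # [  1 101 201 301 401 501 601 701 801 901]

# 'quad' discretization concentrates steps near t = 0
quad_steps = ((np.linspace(0, np.sqrt(n_steps_total * .8), n_steps)) ** 2).astype(int) + 1
print(quad_steps[:3])  # early steps are much closer together than uniform ones

# Toy alpha_bar schedule (hypothetical, decreasing in t)
alpha_bar = np.linspace(0.9999, 0.01, n_steps_total)
ddim_alpha = alpha_bar[time_steps]
ddim_alpha_prev = np.concatenate([alpha_bar[0:1], alpha_bar[time_steps[:-1]]])

# eta = 0 gives sigma = 0 everywhere, i.e. deterministic sampling
eta = 0.
ddim_sigma = eta * ((1 - ddim_alpha_prev) / (1 - ddim_alpha) *
                    (1 - ddim_alpha / ddim_alpha_prev)) ** .5
print(ddim_sigma.max())  # 0.0
```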
- `shape` is the shape of the generated images in the form `[batch_size, channels, height, width]`
- `cond` is the conditional embeddings $c$
- `repeat_noise` specifies whether the noise should be the same for all samples in the batch
- `temperature` is the noise temperature (random noise gets multiplied by this)
- `x_last` is $x_{\tau_S}$. If not provided, random noise will be used.
- `uncond_scale` is the unconditional guidance scale $s$. This is used for $\epsilon_\theta(x_t, c) = s \epsilon_\text{cond}(x_t, c) + (1 - s) \epsilon_\text{cond}(x_t, c_u)$
- `uncond_cond` is the conditional embedding for the empty prompt $c_u$
- `skip_steps` is the number of time steps to skip, $i'$. We start sampling from $S - i'$, and `x_last` is then $x_{\tau_{S - i'}}$.

```python
    @torch.no_grad()
    def sample(self,
               shape: List[int],
               cond: torch.Tensor,
               repeat_noise: bool = False,
               temperature: float = 1.,
               x_last: Optional[torch.Tensor] = None,
               uncond_scale: float = 1.,
               uncond_cond: Optional[torch.Tensor] = None,
               skip_steps: int = 0,
               ):
```
Get device and batch size

```python
        device = self.model.device
        bs = shape[0]
```

Get $x_{\tau_S}$

```python
        x = x_last if x_last is not None else torch.randn(shape, device=device)
```

Time steps to sample at: $\tau_{S - i'}, \tau_{S - i' - 1}, \dots, \tau_1$

```python
        time_steps = np.flip(self.time_steps)[skip_steps:]

        for i, step in monit.enum('Sample', time_steps):
```

Index $i$ in the list $[\tau_1, \tau_2, \dots, \tau_S]$

```python
            index = len(time_steps) - i - 1
```

Time step $\tau_i$

```python
            ts = x.new_full((bs,), step, dtype=torch.long)
```

Sample $x_{\tau_{i-1}}$

```python
            x, pred_x0, e_t = self.p_sample(x, cond, ts, step, index=index,
                                            repeat_noise=repeat_noise,
                                            temperature=temperature,
                                            uncond_scale=uncond_scale,
                                            uncond_cond=uncond_cond)
```

Return $x_0$

```python
        return x
```
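The `uncond_scale` combination is classifier-free guidance. A small NumPy sketch with toy values (the arrays here stand in for real noise predictions) shows that the scaled-combination form equals the more common $\epsilon_u + s(\epsilon_c - \epsilon_u)$ form:

```python
import numpy as np

rng = np.random.default_rng(1)
e_cond = rng.normal(size=8)    # noise prediction conditioned on the prompt (toy)
e_uncond = rng.normal(size=8)  # noise prediction for the empty prompt (toy)
s = 7.5                        # a typical guidance scale

# Scaled-combination form
e_theta = s * e_cond + (1 - s) * e_uncond

# Usual classifier-free guidance form
e_theta_cfg = e_uncond + s * (e_cond - e_uncond)

print(np.allclose(e_theta, e_theta_cfg))  # True
```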
- `x` is $x_{\tau_i}$ of shape `[batch_size, channels, height, width]`
- `c` is the conditional embeddings $c$ of shape `[batch_size, emb_size]`
- `t` is $\tau_i$ of shape `[batch_size]`
- `step` is the step $\tau_i$ as an integer
- `index` is the index $i$ in the list $[\tau_1, \tau_2, \dots, \tau_S]$
- `repeat_noise` specifies whether the noise should be the same for all samples in the batch
- `temperature` is the noise temperature (random noise gets multiplied by this)
- `uncond_scale` is the unconditional guidance scale $s$. This is used for $\epsilon_\theta(x_t, c) = s \epsilon_\text{cond}(x_t, c) + (1 - s) \epsilon_\text{cond}(x_t, c_u)$
- `uncond_cond` is the conditional embedding for the empty prompt $c_u$

```python
    @torch.no_grad()
    def p_sample(self, x: torch.Tensor, c: torch.Tensor, t: torch.Tensor, step: int, index: int, *,
                 repeat_noise: bool = False,
                 temperature: float = 1.,
                 uncond_scale: float = 1.,
                 uncond_cond: Optional[torch.Tensor] = None):
```
Get $\epsilon_\theta(x_{\tau_i})$

```python
        e_t = self.get_eps(x, t, c,
                           uncond_scale=uncond_scale,
                           uncond_cond=uncond_cond)
```

Calculate $x_{\tau_{i-1}}$ and the predicted $x_0$

```python
        x_prev, pred_x0 = self.get_x_prev_and_pred_x0(e_t, index, x,
                                                      temperature=temperature,
                                                      repeat_noise=repeat_noise)

        return x_prev, pred_x0, e_t
```

```python
    def get_x_prev_and_pred_x0(self, e_t: torch.Tensor, index: int, x: torch.Tensor, *,
                               temperature: float,
                               repeat_noise: bool):
```
$\alpha_{\tau_i}$

```python
        alpha = self.ddim_alpha[index]
```

$\alpha_{\tau_{i-1}}$

```python
        alpha_prev = self.ddim_alpha_prev[index]
```

$\sigma_{\tau_i}$

```python
        sigma = self.ddim_sigma[index]
```

$\sqrt{1 - \alpha_{\tau_i}}$

```python
        sqrt_one_minus_alpha = self.ddim_sqrt_one_minus_alpha[index]
```

Current prediction for $x_0$: $\dfrac{x_{\tau_i} - \sqrt{1 - \alpha_{\tau_i}}\, \epsilon_\theta(x_{\tau_i})}{\sqrt{\alpha_{\tau_i}}}$

```python
        pred_x0 = (x - sqrt_one_minus_alpha * e_t) / (alpha ** 0.5)
```

Direction pointing to $x_t$: $\sqrt{1 - \alpha_{\tau_{i-1}} - \sigma_{\tau_i}^2} \cdot \epsilon_\theta(x_{\tau_i})$

```python
        dir_xt = (1. - alpha_prev - sigma ** 2).sqrt() * e_t
```
No noise is added when $\eta = 0$

```python
        if sigma == 0.:
            noise = 0.
```

If the same noise is used for all samples in the batch

```python
        elif repeat_noise:
            noise = torch.randn((1, *x.shape[1:]), device=x.device)
```

Different noise for each sample

```python
        else:
            noise = torch.randn(x.shape, device=x.device)
```

Multiply noise by the temperature

```python
        noise = noise * temperature
```

$$x_{\tau_{i-1}} = \sqrt{\alpha_{\tau_{i-1}}} \Bigg( \frac{x_{\tau_i} - \sqrt{1 - \alpha_{\tau_i}}\, \epsilon_\theta(x_{\tau_i})}{\sqrt{\alpha_{\tau_i}}} \Bigg) + \sqrt{1 - \alpha_{\tau_{i-1}} - \sigma_{\tau_i}^2} \cdot \epsilon_\theta(x_{\tau_i}) + \sigma_{\tau_i} \epsilon_{\tau_i}$$

```python
        x_prev = (alpha_prev ** 0.5) * pred_x0 + dir_xt + sigma * noise

        return x_prev, pred_x0
```
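The noise branches above can be illustrated with a NumPy sketch (toy shapes and values are assumptions) to show what `repeat_noise` does: a single noise sample is drawn with a leading dimension of 1 and broadcast across the batch, so every sample receives identical noise.

```python
import numpy as np

rng = np.random.default_rng(2)
batch, ch, h, w = 2, 3, 4, 4
shape = (batch, ch, h, w)

sigma = 0.1        # nonzero, so noise is actually added (eta > 0)
temperature = 1.
repeat_noise = True

if sigma == 0.:
    noise = 0.
elif repeat_noise:
    # one noise sample, broadcast over the batch dimension
    noise = rng.normal(size=(1, ch, h, w))
else:
    noise = rng.normal(size=shape)
noise = noise * temperature

x = np.zeros(shape)
out = x + sigma * noise  # NumPy broadcasting repeats the noise over the batch

# With repeat_noise the two batch elements get identical noise
print(np.allclose(out[0], out[1]))  # True
```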
$$q_{\sigma, \tau}(x_t \mid x_0) = \mathcal{N}\big(x_t;\; \sqrt{\alpha_{\tau_i}}\, x_0,\; (1 - \alpha_{\tau_i}) \mathbf{I}\big)$$

- `x0` is $x_0$ of shape `[batch_size, channels, height, width]`
- `index` is the time step $\tau_i$ index $i$
- `noise` is the noise, $\epsilon$

```python
    @torch.no_grad()
    def q_sample(self, x0: torch.Tensor, index: int, noise: Optional[torch.Tensor] = None):
```
Random noise, if `noise` is not specified

```python
        if noise is None:
            noise = torch.randn_like(x0)
```

Sample from $q_{\sigma, \tau}(x_t \mid x_0) = \mathcal{N}\big(x_t;\; \sqrt{\alpha_{\tau_i}}\, x_0,\; (1 - \alpha_{\tau_i}) \mathbf{I}\big)$

```python
        return self.ddim_alpha_sqrt[index] * x0 + self.ddim_sqrt_one_minus_alpha[index] * noise
```
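A small NumPy sketch (toy values, not the real schedule) showing that this forward sampling and the `pred_x0` computation in `get_x_prev_and_pred_x0` are inverses when the exact noise is known:

```python
import numpy as np

rng = np.random.default_rng(3)
alpha = 0.36  # toy value standing in for ddim_alpha[index]
x0 = rng.normal(size=(2, 3))
noise = rng.normal(size=(2, 3))

# q_sample: x_t = sqrt(alpha) * x0 + sqrt(1 - alpha) * noise
x_t = np.sqrt(alpha) * x0 + np.sqrt(1 - alpha) * noise

# Inverting with the exact noise recovers x0 (the pred_x0 computation)
recovered = (x_t - np.sqrt(1 - alpha) * noise) / np.sqrt(alpha)
print(np.allclose(recovered, x0))  # True
```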
- `x` is $x_{S'}$ of shape `[batch_size, channels, height, width]`
- `cond` is the conditional embeddings $c$
- `t_start` is the sampling step to start from, $S'$
- `orig` is the original image in latent space that we are inpainting. If this is not provided, it'll be an image-to-image transformation.
- `mask` is the mask to keep the original image.
- `orig_noise` is fixed noise to be added to the original image.
- `uncond_scale` is the unconditional guidance scale $s$. This is used for $\epsilon_\theta(x_t, c) = s \epsilon_\text{cond}(x_t, c) + (1 - s) \epsilon_\text{cond}(x_t, c_u)$
- `uncond_cond` is the conditional embedding for the empty prompt $c_u$

```python
    @torch.no_grad()
    def paint(self, x: torch.Tensor, cond: torch.Tensor, t_start: int, *,
              orig: Optional[torch.Tensor] = None,
              mask: Optional[torch.Tensor] = None, orig_noise: Optional[torch.Tensor] = None,
              uncond_scale: float = 1.,
              uncond_cond: Optional[torch.Tensor] = None,
              ):
```
Get batch size

```python
        bs = x.shape[0]
```

Time steps to sample at: $\tau_{S'}, \tau_{S' - 1}, \dots, \tau_1$

```python
        time_steps = np.flip(self.time_steps[:t_start])

        for i, step in monit.enum('Paint', time_steps):
```

Index $i$ in the list $[\tau_1, \tau_2, \dots, \tau_S]$

```python
            index = len(time_steps) - i - 1
```

Time step $\tau_i$

```python
            ts = x.new_full((bs,), step, dtype=torch.long)
```

Sample $x_{\tau_{i-1}}$

```python
            x, _, _ = self.p_sample(x, cond, ts, step, index=index,
                                    uncond_scale=uncond_scale,
                                    uncond_cond=uncond_cond)
```

Replace the masked area with the original image

```python
            if orig is not None:
```

Get $q_{\sigma, \tau}(x_{\tau_i} \mid x_0)$ for the original image in latent space

```python
                orig_t = self.q_sample(orig, index, noise=orig_noise)
```

Replace the masked area

```python
                x = orig_t * mask + x * (1 - mask)
```

```python
        return x
```
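The mask blending in the inpainting loop can be illustrated with a toy NumPy sketch (the shapes and the half-image mask are assumptions): where the mask is 1 the noised original latent is kept, and where it is 0 the freshly sampled latent is kept.

```python
import numpy as np

h = w = 4
orig_t = np.ones((1, 1, h, w))   # stands in for the noised original latent
x = np.zeros((1, 1, h, w))       # stands in for the current sample

mask = np.zeros((1, 1, h, w))
mask[..., :2, :] = 1.            # keep the top half of the original

blended = orig_t * mask + x * (1 - mask)
print(blended[0, 0, 0, 0], blended[0, 0, 3, 0])  # 1.0 0.0
```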