Denoising Diffusion Implicit Models (DDIM) Sampling


This implements DDIM sampling from the paper Denoising Diffusion Implicit Models

from typing import Optional, List

import numpy as np
import torch

from labml import monit
from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
from labml_nn.diffusion.stable_diffusion.sampler import DiffusionSampler

DDIM Sampler

This extends the DiffusionSampler base class.

DDIM samples images by repeatedly removing noise, stepping through the sub-sequence with

$$x_{\tau_{i-1}} = \sqrt{\alpha_{\tau_{i-1}}}\Bigg(\frac{x_{\tau_i} - \sqrt{1 - \alpha_{\tau_i}}\,\epsilon_\theta(x_{\tau_i})}{\sqrt{\alpha_{\tau_i}}}\Bigg) + \sqrt{1 - \alpha_{\tau_{i-1}} - \sigma_{\tau_i}^2}\;\epsilon_\theta(x_{\tau_i}) + \sigma_{\tau_i}\,\epsilon_{\tau_i}$$

where $\epsilon_{\tau_i}$ is random noise, $\tau$ is a subsequence of $[1, 2, \dots, T]$ of length $S$, and

$$\sigma_{\tau_i} = \eta\sqrt{\frac{1 - \alpha_{\tau_{i-1}}}{1 - \alpha_{\tau_i}}}\sqrt{1 - \frac{\alpha_{\tau_i}}{\alpha_{\tau_{i-1}}}}$$

Note that $\alpha_t$ in the DDIM paper refers to $\bar\alpha_t$ from DDPM.
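As a concrete illustration of the update rule, the sketch below (hypothetical values for $\alpha_{\tau_i}$, $\alpha_{\tau_{i-1}}$ and the predicted noise; plain numpy rather than the implementation) applies one deterministic DDIM step with $\eta = 0$:

```python
import numpy as np

# Hypothetical alpha-bar values at two consecutive sub-sequence steps
alpha, alpha_prev = 0.5, 0.8
eta = 0.0  # eta = 0 makes sigma = 0, so the step is deterministic

# sigma_tau_i from the formula above
sigma = eta * np.sqrt((1 - alpha_prev) / (1 - alpha) * (1 - alpha / alpha_prev))

x_t = np.array([1.0])   # current latent x_{tau_i}
eps = np.array([0.25])  # predicted noise epsilon_theta(x_{tau_i})

# Predicted x_0
pred_x0 = (x_t - np.sqrt(1 - alpha) * eps) / np.sqrt(alpha)
# Direction pointing to x_t
dir_xt = np.sqrt(1 - alpha_prev - sigma ** 2) * eps
# One DDIM step; the random-noise term vanishes because sigma = 0
x_prev = np.sqrt(alpha_prev) * pred_x0 + dir_xt + sigma * np.random.randn(1)
```

With $\eta = 0$ the same starting latent always yields the same `x_prev`, which is what makes DDIM sampling deterministic.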

class DDIMSampler(DiffusionSampler):

model: LatentDiffusion

  • model is the model to predict noise $\epsilon_\text{cond}(x_t, c)$
  • n_steps is the number of DDIM sampling steps, $S$
  • ddim_discretize specifies how to extract $\tau$ from $[1, 2, \dots, T]$. It can be either uniform or quad.
  • ddim_eta is $\eta$ used to calculate $\sigma_{\tau_i}$. $\eta = 0$ makes the sampling process deterministic.

def __init__(self, model: LatentDiffusion, n_steps: int, ddim_discretize: str = "uniform", ddim_eta: float = 0.):

super().__init__(model)

Number of steps, $T$

self.n_steps = model.n_steps

Calculate $\tau$ to be uniformly distributed across $[1, 2, \dots, T]$

if ddim_discretize == 'uniform':
    c = self.n_steps // n_steps
    self.time_steps = np.asarray(list(range(0, self.n_steps, c))) + 1

Calculate $\tau$ to be quadratically distributed across $[1, 2, \dots, T]$

elif ddim_discretize == 'quad':
    self.time_steps = ((np.linspace(0, np.sqrt(self.n_steps * .8), n_steps)) ** 2).astype(int) + 1
else:
    raise NotImplementedError(ddim_discretize)

with torch.no_grad():
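For intuition, here is what the two schedules look like for hypothetical values $T = 1000$ and $S = 10$ (a standalone numpy sketch mirroring the expressions above):

```python
import numpy as np

T, S = 1000, 10  # hypothetical: T total steps, S DDIM steps

# Uniform: every T // S steps, offset by one
c = T // S
uniform = np.asarray(list(range(0, T, c))) + 1

# Quadratic: denser near t = 1, sparser near t = T
quad = ((np.linspace(0, np.sqrt(T * .8), S)) ** 2).astype(int) + 1

print(uniform)  # 1, 101, 201, ..., 901
print(quad)     # starts at 1, spacing grows quadratically up to ~800
```

The quadratic schedule spends more of the sampling budget on the low-noise end of the process.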

Get $\bar\alpha_t$

alpha_bar = self.model.alpha_bar

$\alpha_{\tau_i}$

self.ddim_alpha = alpha_bar[self.time_steps].clone().to(torch.float32)

$\sqrt{\alpha_{\tau_i}}$

self.ddim_alpha_sqrt = torch.sqrt(self.ddim_alpha)

$\alpha_{\tau_{i-1}}$

self.ddim_alpha_prev = torch.cat([alpha_bar[0:1], alpha_bar[self.time_steps[:-1]]])

$\sigma_{\tau_i} = \eta\sqrt{\frac{1 - \alpha_{\tau_{i-1}}}{1 - \alpha_{\tau_i}}}\sqrt{1 - \frac{\alpha_{\tau_i}}{\alpha_{\tau_{i-1}}}}$

self.ddim_sigma = (ddim_eta *
                   ((1 - self.ddim_alpha_prev) / (1 - self.ddim_alpha) *
                    (1 - self.ddim_alpha / self.ddim_alpha_prev)) ** .5)

$\sqrt{1 - \alpha_{\tau_i}}$

self.ddim_sqrt_one_minus_alpha = (1. - self.ddim_alpha) ** .5
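A quick sanity check on the $\sigma_{\tau_i}$ schedule, using made-up $\bar\alpha$ values in numpy: $\eta = 0$ zeroes every $\sigma_{\tau_i}$ (deterministic DDIM), while $\eta = 1$ gives a stochastic, DDPM-like sampler.

```python
import numpy as np

# Made-up alpha-bar values at the selected time steps (decreasing in t)
ddim_alpha = np.array([0.9, 0.6, 0.3])
ddim_alpha_prev = np.array([0.99, 0.9, 0.6])

def ddim_sigma(eta):
    # Same expression as the implementation above
    return (eta *
            ((1 - ddim_alpha_prev) / (1 - ddim_alpha) *
             (1 - ddim_alpha / ddim_alpha_prev)) ** .5)

print(ddim_sigma(0.))  # all zeros: deterministic sampling
print(ddim_sigma(1.))  # strictly positive: stochastic sampling
```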

Sampling Loop

  • shape is the shape of the generated images in the form [batch_size, channels, height, width]
  • cond is the conditional embeddings $c$
  • temperature is the noise temperature (random noise gets multiplied by this)
  • x_last is $x_{\tau_S}$. If not provided, random noise will be used.
  • uncond_scale is the unconditional guidance scale $s$. This is used for $\epsilon_\theta(x_t, c) = s\,\epsilon_\text{cond}(x_t, c) - (s - 1)\,\epsilon_\text{cond}(x_t, c_u)$
  • uncond_cond is the conditional embedding for empty prompt $c_u$
  • skip_steps is the number of time steps to skip, $i'$. We start sampling from $S - i'$, and x_last is then $x_{\tau_{S - i'}}$.
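The guidance combination from the uncond_scale bullet can be sketched on its own; the values below are hypothetical stand-ins for the two model outputs, and the two algebraic forms are equivalent: $s\,\epsilon_\text{cond} - (s - 1)\,\epsilon_\text{uncond} = \epsilon_\text{uncond} + s\,(\epsilon_\text{cond} - \epsilon_\text{uncond})$.

```python
import numpy as np

s = 7.5  # guidance scale; 7.5 is a common choice for Stable Diffusion
e_cond = np.array([0.4])    # hypothetical epsilon_cond(x_t, c)
e_uncond = np.array([0.1])  # hypothetical epsilon_cond(x_t, c_u)

# Move from the unconditional prediction towards the conditional one,
# overshooting for s > 1
e_t = e_uncond + s * (e_cond - e_uncond)
print(e_t)  # 2.35
```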
@torch.no_grad()
def sample(self,
           shape: List[int],
           cond: torch.Tensor,
           repeat_noise: bool = False,
           temperature: float = 1.,
           x_last: Optional[torch.Tensor] = None,
           uncond_scale: float = 1.,
           uncond_cond: Optional[torch.Tensor] = None,
           skip_steps: int = 0,
           ):

Get device and batch size

device = self.model.device
bs = shape[0]

Get $x_{\tau_S}$

x = x_last if x_last is not None else torch.randn(shape, device=device)

Time steps to sample at $\tau_{S - i'}, \tau_{S - i' - 1}, \dots, \tau_1$

time_steps = np.flip(self.time_steps)[skip_steps:]

for i, step in monit.enum('Sample', time_steps):

Index $i$ in the list $[\tau_1, \tau_2, \dots, \tau_S]$

index = len(time_steps) - i - 1

Time step $\tau_i$

ts = x.new_full((bs,), step, dtype=torch.long)

Sample $x_{\tau_{i-1}}$

x, pred_x0, e_t = self.p_sample(x, cond, ts, step, index=index,
                                repeat_noise=repeat_noise,
                                temperature=temperature,
                                uncond_scale=uncond_scale,
                                uncond_cond=uncond_cond)

Return $x_0$

return x
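The loop's bookkeeping of step against index is the easiest part to get wrong; this standalone sketch with a toy schedule checks that index always points back at the current $\tau_i$:

```python
import numpy as np

time_steps = np.asarray([1, 101, 201, 301, 401])  # toy [tau_1, ..., tau_S], S = 5
skip_steps = 1  # skip tau_S and start from tau_{S-1}

steps = np.flip(time_steps)[skip_steps:]  # [301, 201, 101, 1]
for i, step in enumerate(steps):
    index = len(steps) - i - 1        # index into [tau_1, ..., tau_S]
    assert time_steps[index] == step  # step is the time step at that index
```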

Sample $x_{\tau_{i-1}}$

  • x is $x_{\tau_i}$ of shape [batch_size, channels, height, width]
  • c is the conditional embeddings $c$ of shape [batch_size, emb_size]
  • t is $\tau_i$ of shape [batch_size]
  • step is the step $\tau_i$ as an integer
  • index is index $i$ in the list $[\tau_1, \tau_2, \dots, \tau_S]$
  • repeat_noise specifies whether the noise should be the same for all samples in the batch
  • temperature is the noise temperature (random noise gets multiplied by this)
  • uncond_scale is the unconditional guidance scale $s$. This is used for $\epsilon_\theta(x_t, c) = s\,\epsilon_\text{cond}(x_t, c) - (s - 1)\,\epsilon_\text{cond}(x_t, c_u)$
  • uncond_cond is the conditional embedding for empty prompt $c_u$
@torch.no_grad()
def p_sample(self, x: torch.Tensor, c: torch.Tensor, t: torch.Tensor, step: int, index: int, *,
             repeat_noise: bool = False,
             temperature: float = 1.,
             uncond_scale: float = 1.,
             uncond_cond: Optional[torch.Tensor] = None):

Get $\epsilon_\theta(x_{\tau_i})$

e_t = self.get_eps(x, t, c,
                   uncond_scale=uncond_scale,
                   uncond_cond=uncond_cond)

Calculate $x_{\tau_{i-1}}$ and predicted $x_0$

x_prev, pred_x0 = self.get_x_prev_and_pred_x0(e_t, index, x,
                                              temperature=temperature,
                                              repeat_noise=repeat_noise)

return x_prev, pred_x0, e_t

Sample $x_{\tau_{i-1}}$ given $\epsilon_\theta(x_{\tau_i})$

def get_x_prev_and_pred_x0(self, e_t: torch.Tensor, index: int, x: torch.Tensor, *,
                           temperature: float,
                           repeat_noise: bool):

$\alpha_{\tau_i}$

alpha = self.ddim_alpha[index]

$\alpha_{\tau_{i-1}}$

alpha_prev = self.ddim_alpha_prev[index]

$\sigma_{\tau_i}$

sigma = self.ddim_sigma[index]

$\sqrt{1 - \alpha_{\tau_i}}$

sqrt_one_minus_alpha = self.ddim_sqrt_one_minus_alpha[index]

Current prediction for $x_0$, $\frac{x_{\tau_i} - \sqrt{1 - \alpha_{\tau_i}}\,\epsilon_\theta(x_{\tau_i})}{\sqrt{\alpha_{\tau_i}}}$

pred_x0 = (x - sqrt_one_minus_alpha * e_t) / (alpha ** 0.5)

Direction pointing to $x_t$, $\sqrt{1 - \alpha_{\tau_{i-1}} - \sigma_{\tau_i}^2}\,\epsilon_\theta(x_{\tau_i})$

dir_xt = (1. - alpha_prev - sigma ** 2).sqrt() * e_t

No noise is added, when $\eta = 0$

if sigma == 0.:
    noise = 0.

If same noise is used for all samples in the batch

elif repeat_noise:
    noise = torch.randn((1, *x.shape[1:]), device=x.device)

Different noise for each sample

else:
    noise = torch.randn(x.shape, device=x.device)

Multiply noise by the temperature

noise = noise * temperature

$x_{\tau_{i-1}} = \sqrt{\alpha_{\tau_{i-1}}}\Big(\frac{x_{\tau_i} - \sqrt{1 - \alpha_{\tau_i}}\,\epsilon_\theta(x_{\tau_i})}{\sqrt{\alpha_{\tau_i}}}\Big) + \sqrt{1 - \alpha_{\tau_{i-1}} - \sigma_{\tau_i}^2}\;\epsilon_\theta(x_{\tau_i}) + \sigma_{\tau_i}\,\epsilon_{\tau_i}$

x_prev = (alpha_prev ** 0.5) * pred_x0 + dir_xt + sigma * noise

return x_prev, pred_x0
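A quick self-consistency check of the pred_x0 expression (standalone, toy scalars): if $x_{\tau_i}$ was built exactly as $\sqrt{\alpha_{\tau_i}}\,x_0 + \sqrt{1 - \alpha_{\tau_i}}\,\epsilon$ and the model predicted $\epsilon$ perfectly, the expression recovers $x_0$.

```python
import numpy as np

alpha = 0.6             # toy alpha_bar at tau_i
x0 = np.array([2.0])    # toy clean latent
eps = np.array([-0.5])  # toy noise

# Forward process sample at tau_i
x = np.sqrt(alpha) * x0 + np.sqrt(1 - alpha) * eps

# Reverse: current prediction for x_0, as in the code above
pred_x0 = (x - np.sqrt(1 - alpha) * eps) / (alpha ** 0.5)
print(pred_x0)  # recovers 2.0
```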

Sample from $q_{\sigma,\tau}(x_{\tau_i} \mid x_0)$

$$q_{\sigma,\tau}(x_t \mid x_0) = \mathcal{N}\big(x_t;\ \sqrt{\alpha_{\tau_i}}\, x_0,\ (1 - \alpha_{\tau_i})\,\mathbf{I}\big)$$

  • x0 is $x_0$ of shape [batch_size, channels, height, width]
  • index is the time step $\tau_i$ index $i$
  • noise is the noise, $\epsilon$

@torch.no_grad()
def q_sample(self, x0: torch.Tensor, index: int, noise: Optional[torch.Tensor] = None):

Random noise, if noise is not specified

if noise is None:
    noise = torch.randn_like(x0)

Sample from $q_{\sigma,\tau}(x_t \mid x_0) = \mathcal{N}\big(x_t;\ \sqrt{\alpha_{\tau_i}}\, x_0,\ (1 - \alpha_{\tau_i})\,\mathbf{I}\big)$

return self.ddim_alpha_sqrt[index] * x0 + self.ddim_sqrt_one_minus_alpha[index] * noise
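The two stored coefficient tables form a variance-preserving pair; with made-up $\bar\alpha$ values in numpy, their squares sum to one at every index, so unit-variance $x_0$ and $\epsilon$ yield a unit-variance sample:

```python
import numpy as np

ddim_alpha = np.array([0.9, 0.5, 0.1])  # made-up alpha_bar at selected steps
ddim_alpha_sqrt = np.sqrt(ddim_alpha)
ddim_sqrt_one_minus_alpha = np.sqrt(1. - ddim_alpha)

# alpha + (1 - alpha) = 1 at every index
check = ddim_alpha_sqrt ** 2 + ddim_sqrt_one_minus_alpha ** 2
print(check)  # [1. 1. 1.]
```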

Painting Loop

  • x is $x_{S'}$ of shape [batch_size, channels, height, width]
  • cond is the conditional embeddings $c$
  • t_start is the sampling step to start from, $S'$
  • orig is the original image in latent space that we are in-painting. If this is not provided, it will be an image-to-image transformation.
  • mask is the mask to keep the original image.
  • orig_noise is fixed noise to be added to the original image.
  • uncond_scale is the unconditional guidance scale $s$. This is used for $\epsilon_\theta(x_t, c) = s\,\epsilon_\text{cond}(x_t, c) - (s - 1)\,\epsilon_\text{cond}(x_t, c_u)$
  • uncond_cond is the conditional embedding for empty prompt $c_u$

@torch.no_grad()
def paint(self, x: torch.Tensor, cond: torch.Tensor, t_start: int, *,
          orig: Optional[torch.Tensor] = None,
          mask: Optional[torch.Tensor] = None, orig_noise: Optional[torch.Tensor] = None,
          uncond_scale: float = 1.,
          uncond_cond: Optional[torch.Tensor] = None,
          ):

Get batch size

bs = x.shape[0]

Time steps to sample at $\tau_{S'}, \tau_{S'-1}, \dots, \tau_1$

time_steps = np.flip(self.time_steps[:t_start])

for i, step in monit.enum('Paint', time_steps):

Index $i$ in the list $[\tau_1, \tau_2, \dots, \tau_S]$

index = len(time_steps) - i - 1

Time step $\tau_i$

ts = x.new_full((bs,), step, dtype=torch.long)

Sample $x_{\tau_{i-1}}$

x, _, _ = self.p_sample(x, cond, ts, step, index=index,
                        uncond_scale=uncond_scale,
                        uncond_cond=uncond_cond)

Replace the masked area with the original image

if orig is not None:

Get the $q_{\sigma,\tau}(x_{\tau_i} \mid x_0)$ for the original image in latent space

orig_t = self.q_sample(orig, index, noise=orig_noise)

Replace the masked area

x = orig_t * mask + x * (1 - mask)

return x
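The mask blend on its own, with toy 1-D arrays standing in for latents (mask = 1 marks pixels kept from the noised original):

```python
import numpy as np

orig_t = np.array([1., 1., 1., 1.])  # noised original latent (q_sample output)
x = np.array([9., 9., 9., 9.])       # freshly sampled latent
mask = np.array([1., 1., 0., 0.])    # keep the left half of the original

x = orig_t * mask + x * (1 - mask)
print(x)  # [1. 1. 9. 9.]
```

Blending after every step keeps the unmasked region consistent with the original image at the matching noise level, while the masked-out region is sampled freely.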
