
Denoising Diffusion Probabilistic Models (DDPM)


[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/diffusion/ddpm/__init__.py)


This is a PyTorch implementation/tutorial of the paper Denoising Diffusion Probabilistic Models.

In simple terms, we take an image from the data and add noise to it step by step. We then train a model to predict that noise at each step, and use the model to generate images.

The following definitions and derivations show how this works. For details please refer to the paper.

Forward Process

The forward process adds noise to the data $x_0 \sim q(x_0)$, for $T$ timesteps.

$$\begin{aligned}
q(x_t | x_{t-1}) &= \mathcal{N}\big(x_t; \sqrt{1 - \beta_t}\, x_{t-1}, \beta_t \mathbf{I}\big) \\
q(x_{1:T} | x_0) &= \prod_{t=1}^T q(x_t | x_{t-1})
\end{aligned}$$

where $\beta_1, \dots, \beta_T$ is the variance schedule.

We can sample $x_t$ at any timestep $t$ with

$$q(x_t | x_0) = \mathcal{N}\big(x_t; \sqrt{\bar\alpha_t}\, x_0, (1 - \bar\alpha_t)\mathbf{I}\big)$$

where $\alpha_t = 1 - \beta_t$ and $\bar\alpha_t = \prod_{s=1}^t \alpha_s$.
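As a quick scalar sanity check (illustrative, not part of the original), iterating the one-step kernel $q(x_t|x_{t-1})$ on the mean coefficient and variance reproduces the closed-form statistics of $q(x_t|x_0)$:

```python
import math

# Track the coefficient of x_0 in the mean, and the accumulated variance,
# through the recursion x_t = sqrt(1 - beta_t) x_{t-1} + sqrt(beta_t) eps.
n_steps = 1000
betas = [0.0001 + (0.02 - 0.0001) * s / (n_steps - 1) for s in range(n_steps)]

coef, var = 1.0, 0.0   # x_0 has mean coefficient 1 and no added noise
alpha_bar = 1.0
for beta in betas:
    coef *= math.sqrt(1.0 - beta)    # mean scales by sqrt(1 - beta_t)
    var = (1.0 - beta) * var + beta  # variance scales and gains beta_t
    alpha_bar *= 1.0 - beta          # closed-form cumulative product

# Matches the closed form: mean coef = sqrt(alpha_bar), var = 1 - alpha_bar
assert math.isclose(coef, math.sqrt(alpha_bar), rel_tol=1e-6)
assert math.isclose(var, 1.0 - alpha_bar, rel_tol=1e-6)
```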

Reverse Process

The reverse process removes noise starting at $p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$ for $T$ time steps.

$$\begin{aligned}
p_\theta(x_{t-1} | x_t) &= \mathcal{N}\big(x_{t-1}; \mu_\theta(x_t, t), \Sigma_\theta(x_t, t)\big) \\
p_\theta(x_{0:T}) &= p(x_T) \prod_{t=1}^T p_\theta(x_{t-1} | x_t) \\
p_\theta(x_0) &= \int p_\theta(x_{0:T}) \, dx_{1:T}
\end{aligned}$$

$\theta$ are the parameters we train.

Loss

We optimize the ELBO (from Jensen's inequality) on the negative log likelihood.

$$\mathbb{E}\big[-\log p_\theta(x_0)\big] \le \mathbb{E}_q\Big[-\log \frac{p_\theta(x_{0:T})}{q(x_{1:T} | x_0)}\Big] = L$$

The loss can be rewritten as follows.

$$\begin{aligned}
L &= \mathbb{E}_q\Big[-\log \frac{p_\theta(x_{0:T})}{q(x_{1:T} | x_0)}\Big] \\
&= \mathbb{E}_q\Big[-\log p(x_T) - \sum_{t=1}^T \log \frac{p_\theta(x_{t-1} | x_t)}{q(x_t | x_{t-1})}\Big] \\
&= \mathbb{E}_q\Big[-\log \frac{p(x_T)}{q(x_T | x_0)} - \sum_{t=2}^T \log \frac{p_\theta(x_{t-1} | x_t)}{q(x_{t-1} | x_t, x_0)} - \log p_\theta(x_0 | x_1)\Big] \\
&= \mathbb{E}_q\Big[D_{KL}\big(q(x_T | x_0) \,\|\, p(x_T)\big) + \sum_{t=2}^T D_{KL}\big(q(x_{t-1} | x_t, x_0) \,\|\, p_\theta(x_{t-1} | x_t)\big) - \log p_\theta(x_0 | x_1)\Big]
\end{aligned}$$

$D_{KL}\big(q(x_T | x_0) \,\|\, p(x_T)\big)$ is constant since we keep $\beta_1, \dots, \beta_T$ constant.

Computing $L_{t-1} = D_{KL}\big(q(x_{t-1} | x_t, x_0) \,\|\, p_\theta(x_{t-1} | x_t)\big)$

The forward process posterior conditioned on $x_0$ is,

$$\begin{aligned}
q(x_{t-1} | x_t, x_0) &= \mathcal{N}\big(x_{t-1}; \tilde\mu_t(x_t, x_0), \tilde\beta_t \mathbf{I}\big) \\
\tilde\mu_t(x_t, x_0) &= \frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1 - \bar\alpha_t}\, x_0 + \frac{\sqrt{\alpha_t}\,(1 - \bar\alpha_{t-1})}{1 - \bar\alpha_t}\, x_t \\
\tilde\beta_t &= \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t}\, \beta_t
\end{aligned}$$

The paper sets $\Sigma_\theta(x_t, t) = \sigma_t^2 \mathbf{I}$ where $\sigma_t^2$ is set to the constant $\beta_t$ or $\tilde\beta_t$.

Then, $p_\theta(x_{t-1} | x_t) = \mathcal{N}\big(x_{t-1}; \mu_\theta(x_t, t), \sigma_t^2 \mathbf{I}\big)$

For given noise $\epsilon \sim \mathcal{N}(0, \mathbf{I})$, using $q(x_t | x_0)$,

$$\begin{aligned}
x_t(x_0, \epsilon) &= \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\,\epsilon \\
x_0 &= \frac{1}{\sqrt{\bar\alpha_t}}\big(x_t(x_0, \epsilon) - \sqrt{1 - \bar\alpha_t}\,\epsilon\big)
\end{aligned}$$
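A toy scalar check (not from the source) that the reparameterization and its inversion above are consistent:

```python
import math
import random

# Pick an arbitrary alpha_bar_t, data point, and noise sample.
alpha_bar_t = 0.3
x0 = random.gauss(0.0, 1.0)
eps = random.gauss(0.0, 1.0)

# Forward: x_t = sqrt(alpha_bar_t) x_0 + sqrt(1 - alpha_bar_t) eps
xt = math.sqrt(alpha_bar_t) * x0 + math.sqrt(1 - alpha_bar_t) * eps

# Inverse: recover x_0 from x_t and the same eps
x0_rec = (xt - math.sqrt(1 - alpha_bar_t) * eps) / math.sqrt(alpha_bar_t)

assert math.isclose(x0, x0_rec, abs_tol=1e-9)
```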

This gives,

$$\begin{aligned}
L_{t-1} &= D_{KL}\big(q(x_{t-1} | x_t, x_0) \,\|\, p_\theta(x_{t-1} | x_t)\big) \\
&= \mathbb{E}_q\Big[\frac{1}{2\sigma_t^2} \big\|\tilde\mu_t(x_t, x_0) - \mu_\theta(x_t, t)\big\|^2\Big] \\
&= \mathbb{E}_{x_0, \epsilon}\Big[\frac{1}{2\sigma_t^2} \Big\|\frac{1}{\sqrt{\alpha_t}}\Big(x_t(x_0, \epsilon) - \frac{\beta_t}{\sqrt{1 - \bar\alpha_t}}\epsilon\Big) - \mu_\theta(x_t(x_0, \epsilon), t)\Big\|^2\Big]
\end{aligned}$$

Re-parameterizing with a model to predict noise,

$$\begin{aligned}
\mu_\theta(x_t, t) &= \tilde\mu_t\Big(x_t, \frac{1}{\sqrt{\bar\alpha_t}}\big(x_t - \sqrt{1 - \bar\alpha_t}\,\epsilon_\theta(x_t, t)\big)\Big) \\
&= \frac{1}{\sqrt{\alpha_t}}\Big(x_t - \frac{\beta_t}{\sqrt{1 - \bar\alpha_t}}\epsilon_\theta(x_t, t)\Big)
\end{aligned}$$

where $\epsilon_\theta$ is a learned function that predicts $\epsilon$ given $(x_t, t)$.

This gives,

$$L_{t-1} = \mathbb{E}_{x_0, \epsilon}\Big[\frac{\beta_t^2}{2\sigma_t^2 \alpha_t (1 - \bar\alpha_t)} \big\|\epsilon - \epsilon_\theta(\sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\,\epsilon, t)\big\|^2\Big]$$

That is, we are training to predict the noise.

Simplified loss

$$L_{\text{simple}}(\theta) = \mathbb{E}_{t, x_0, \epsilon}\Big[\big\|\epsilon - \epsilon_\theta(\sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\,\epsilon, t)\big\|^2\Big]$$

This minimizes $-\log p_\theta(x_0 | x_1)$ when $t = 1$ and $L_{t-1}$ for $t > 1$, discarding the weighting in $L_{t-1}$. Discarding the weights $\frac{\beta_t^2}{2\sigma_t^2 \alpha_t (1 - \bar\alpha_t)}$ increases the weight given to higher $t$ (which have higher noise levels), therefore increasing sample quality.
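As a quick numeric illustration (a sketch, assuming $\sigma_t^2 = \beta_t$ and the linear variance schedule used later in this implementation), the discarded weight indeed shrinks as $t$ grows:

```python
import math

# Compute the weight beta_t^2 / (2 sigma_t^2 alpha_t (1 - alpha_bar_t))
# across all timesteps, with sigma_t^2 = beta_t.
n_steps = 1000
betas = [0.0001 + (0.02 - 0.0001) * s / (n_steps - 1) for s in range(n_steps)]

weights, alpha_bar = [], 1.0
for beta in betas:
    alpha = 1.0 - beta
    alpha_bar *= alpha
    sigma2 = beta  # sigma_t^2 = beta_t
    weights.append(beta ** 2 / (2 * sigma2 * alpha * (1 - alpha_bar)))

# The weight is largest at small t and smallest at large t, so dropping it
# relatively up-weights the noisier (higher t) steps.
assert weights[0] > weights[-1]
```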

This file implements the loss calculation and a basic sampling method that we use to generate images during training.

Here is the UNet model that gives $\epsilon_\theta(x_t, t)$ and the training code. This file can generate samples and interpolations from a trained model.

```python
from typing import Tuple, Optional

import torch
import torch.nn.functional as F
import torch.utils.data
from torch import nn

from labml_nn.diffusion.ddpm.utils import gather
```

Denoise Diffusion

```python
class DenoiseDiffusion:
```

- `eps_model` is the $\epsilon_\theta(x_t, t)$ model
- `n_steps` is $T$
- `device` is the device to place constants on

```python
    def __init__(self, eps_model: nn.Module, n_steps: int, device: torch.device):
```

```python
        super().__init__()
        self.eps_model = eps_model
```

Create $\beta_1, \dots, \beta_T$ linearly increasing variance schedule

```python
        self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)
```

$\alpha_t = 1 - \beta_t$

```python
        self.alpha = 1. - self.beta
```

$\bar\alpha_t = \prod_{s=1}^t \alpha_s$

```python
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
```

$T$

```python
        self.n_steps = n_steps
```

$\sigma_t^2 = \beta_t$

```python
        self.sigma2 = self.beta
```

Get $q(x_t | x_0)$ distribution

$$q(x_t | x_0) = \mathcal{N}\big(x_t; \sqrt{\bar\alpha_t}\, x_0, (1 - \bar\alpha_t)\mathbf{I}\big)$$

```python
    def q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
```

Gather $\bar\alpha_t$ and compute $\sqrt{\bar\alpha_t}\, x_0$

```python
        mean = gather(self.alpha_bar, t) ** 0.5 * x0
```

$(1 - \bar\alpha_t)\mathbf{I}$

```python
        var = 1 - gather(self.alpha_bar, t)
```

```python
        return mean, var
```

Sample from $q(x_t | x_0)$

$$q(x_t | x_0) = \mathcal{N}\big(x_t; \sqrt{\bar\alpha_t}\, x_0, (1 - \bar\alpha_t)\mathbf{I}\big)$$

```python
    def q_sample(self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None):
```

$\epsilon \sim \mathcal{N}(0, \mathbf{I})$

```python
        if eps is None:
            eps = torch.randn_like(x0)
```

Get $q(x_t | x_0)$

```python
        mean, var = self.q_xt_x0(x0, t)
```

Sample from $q(x_t | x_0)$

```python
        return mean + (var ** 0.5) * eps
```

Sample from $p_\theta(x_{t-1} | x_t)$

$$\begin{aligned}
p_\theta(x_{t-1} | x_t) &= \mathcal{N}\big(x_{t-1}; \mu_\theta(x_t, t), \sigma_t^2 \mathbf{I}\big) \\
\mu_\theta(x_t, t) &= \frac{1}{\sqrt{\alpha_t}}\Big(x_t - \frac{\beta_t}{\sqrt{1 - \bar\alpha_t}}\epsilon_\theta(x_t, t)\Big)
\end{aligned}$$

```python
    def p_sample(self, xt: torch.Tensor, t: torch.Tensor):
```

$\epsilon_\theta(x_t, t)$

```python
        eps_theta = self.eps_model(xt, t)
```

Gather $\bar\alpha_t$

```python
        alpha_bar = gather(self.alpha_bar, t)
```

$\alpha_t$

```python
        alpha = gather(self.alpha, t)
```

$\frac{\beta_t}{\sqrt{1 - \bar\alpha_t}}$

```python
        eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5
```

$\frac{1}{\sqrt{\alpha_t}}\Big(x_t - \frac{\beta_t}{\sqrt{1 - \bar\alpha_t}}\epsilon_\theta(x_t, t)\Big)$

```python
        mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)
```

$\sigma_t^2$

```python
        var = gather(self.sigma2, t)
```

$\epsilon \sim \mathcal{N}(0, \mathbf{I})$

```python
        eps = torch.randn(xt.shape, device=xt.device)
```

Sample

```python
        return mean + (var ** .5) * eps
```
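The reverse-process update can be chained into a full sampling loop. Here is a toy scalar sketch of that loop (an illustration, not the original code: it uses a placeholder `eps_theta` that always returns zero, where a real sampler would call a trained UNet on image tensors):

```python
import math
import random

# Linear variance schedule, as in __init__ above (scaled-down n_steps)
n_steps = 100
betas = [0.0001 + (0.02 - 0.0001) * s / (n_steps - 1) for s in range(n_steps)]
alphas = [1.0 - b for b in betas]
alpha_bars, prod = [], 1.0
for a in alphas:
    prod *= a
    alpha_bars.append(prod)

def eps_theta(xt: float, t: int) -> float:
    """Placeholder noise prediction; a trained model goes here."""
    return 0.0

xt = random.gauss(0.0, 1.0)           # start from x_T ~ N(0, I)
for t in reversed(range(n_steps)):    # t = T-1, ..., 0
    # Same arithmetic as p_sample: mean of p_theta(x_{t-1} | x_t)
    eps_coef = (1 - alphas[t]) / math.sqrt(1 - alpha_bars[t])
    mean = (xt - eps_coef * eps_theta(xt, t)) / math.sqrt(alphas[t])
    var = betas[t]                    # sigma_t^2 = beta_t
    xt = mean + math.sqrt(var) * random.gauss(0.0, 1.0)

assert math.isfinite(xt)
```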

Simplified Loss

$$L_{\text{simple}}(\theta) = \mathbb{E}_{t, x_0, \epsilon}\Big[\big\|\epsilon - \epsilon_\theta(\sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\,\epsilon, t)\big\|^2\Big]$$

```python
    def loss(self, x0: torch.Tensor, noise: Optional[torch.Tensor] = None):
```

Get batch size

```python
        batch_size = x0.shape[0]
```

Get random $t$ for each sample in the batch

```python
        t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)
```

$\epsilon \sim \mathcal{N}(0, \mathbf{I})$

```python
        if noise is None:
            noise = torch.randn_like(x0)
```

Sample $x_t$ from $q(x_t | x_0)$

```python
        xt = self.q_sample(x0, t, eps=noise)
```

Get $\epsilon_\theta(\sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\,\epsilon, t)$

```python
        eps_theta = self.eps_model(xt, t)
```

MSE loss

```python
        return F.mse_loss(noise, eps_theta)
```
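The steps of `loss` can be traced on a scalar toy problem (illustrative only; the real implementation samples a batch of timesteps and uses a UNet for the noise prediction):

```python
import math
import random

# Linear variance schedule, as in __init__ above (scaled-down n_steps)
n_steps = 100
betas = [0.0001 + (0.02 - 0.0001) * s / (n_steps - 1) for s in range(n_steps)]
alpha_bars, prod = [], 1.0
for b in betas:
    prod *= 1.0 - b
    alpha_bars.append(prod)

x0 = 0.7                           # a "data" point
t = random.randrange(n_steps)      # random timestep, as in loss()
noise = random.gauss(0.0, 1.0)     # eps ~ N(0, I)
# q_sample: x_t = sqrt(alpha_bar_t) x_0 + sqrt(1 - alpha_bar_t) eps
xt = math.sqrt(alpha_bars[t]) * x0 + math.sqrt(1 - alpha_bars[t]) * noise
eps_pred = 0.0                     # placeholder model prediction
loss = (noise - eps_pred) ** 2     # squared error; batches use F.mse_loss

assert loss >= 0.0
```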
