[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/diffusion/ddpm/__init__.py)
This is a PyTorch implementation/tutorial of the paper Denoising Diffusion Probabilistic Models.
In simple terms, we take an image from the data and add noise step by step. We then train a model to predict that noise at each step, and use the model to generate images.
The following definitions and derivations show how this works. For details please refer to the paper.
The forward process adds noise to the data $x_0 \sim q(x_0)$ for $T$ timesteps.

$$q(x_t \mid x_{t-1}) = \mathcal{N}\big(x_t; \sqrt{1-\beta_t}\, x_{t-1}, \beta_t \mathbf{I}\big)$$
$$q(x_{1:T} \mid x_0) = \prod_{t=1}^{T} q(x_t \mid x_{t-1})$$

where $\beta_1, \dots, \beta_T$ is the variance schedule.
We can sample $x_t$ at any timestep $t$ with

$$q(x_t \mid x_0) = \mathcal{N}\big(x_t; \sqrt{\bar\alpha_t}\, x_0, (1-\bar\alpha_t)\mathbf{I}\big)$$

where $\alpha_t = 1 - \beta_t$ and $\bar\alpha_t = \prod_{s=1}^{t} \alpha_s$.
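As a quick numerical sanity check (a sketch, not part of the original implementation), we can verify that composing the per-step Gaussians $q(x_t \mid x_{t-1})$ reproduces the closed-form scale $\sqrt{\bar\alpha_T}$ and variance $1-\bar\alpha_T$; the schedule values below mirror the implementation further down.

```python
import torch

# Illustrative schedule (same values as the implementation below)
n_steps = 100
beta = torch.linspace(0.0001, 0.02, n_steps)
alpha = 1. - beta
alpha_bar = torch.cumprod(alpha, dim=0)

# Compose the per-step Gaussians q(x_t | x_{t-1}) iteratively: the mean
# scale m and variance v of x_t given x_0 evolve as
#   m_t = sqrt(1 - beta_t) * m_{t-1},  v_t = (1 - beta_t) * v_{t-1} + beta_t
m, v = torch.tensor(1.), torch.tensor(0.)
for t in range(n_steps):
    m = (1 - beta[t]).sqrt() * m
    v = (1 - beta[t]) * v + beta[t]

# The iterated values match the closed form sqrt(alpha_bar_T) and 1 - alpha_bar_T
assert torch.allclose(m, alpha_bar[-1].sqrt(), atol=1e-6)
assert torch.allclose(v, 1 - alpha_bar[-1], atol=1e-6)
```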
The reverse process removes noise starting at $p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$ for $T$ timesteps.

$$\begin{aligned}
p_\theta(x_{t-1} \mid x_t) &= \mathcal{N}\big(x_{t-1}; \mu_\theta(x_t, t), \Sigma_\theta(x_t, t)\big) \\
p_\theta(x_{0:T}) &= p(x_T) \prod_{t=1}^{T} p_\theta(x_{t-1} \mid x_t) \\
p_\theta(x_0) &= \int p_\theta(x_{0:T}) \, dx_{1:T}
\end{aligned}$$

$\theta$ are the parameters we train.
We optimize the ELBO (from Jensen's inequality) on the negative log likelihood.

$$\mathbb{E}[-\log p_\theta(x_0)] \le \mathbb{E}_q\left[-\log \frac{p_\theta(x_{0:T})}{q(x_{1:T} \mid x_0)}\right] = L$$
The loss can be rewritten as follows.

$$\begin{aligned}
L &= \mathbb{E}_q\left[-\log \frac{p_\theta(x_{0:T})}{q(x_{1:T} \mid x_0)}\right] \\
&= \mathbb{E}_q\left[-\log p(x_T) - \sum_{t=1}^{T} \log \frac{p_\theta(x_{t-1} \mid x_t)}{q(x_t \mid x_{t-1})}\right] \\
&= \mathbb{E}_q\left[-\log \frac{p(x_T)}{q(x_T \mid x_0)} - \sum_{t=2}^{T} \log \frac{p_\theta(x_{t-1} \mid x_t)}{q(x_{t-1} \mid x_t, x_0)} - \log p_\theta(x_0 \mid x_1)\right] \\
&= \mathbb{E}_q\Big[D_{KL}\big(q(x_T \mid x_0) \,\|\, p(x_T)\big) + \sum_{t=2}^{T} D_{KL}\big(q(x_{t-1} \mid x_t, x_0) \,\|\, p_\theta(x_{t-1} \mid x_t)\big) - \log p_\theta(x_0 \mid x_1)\Big]
\end{aligned}$$

The third equality rewrites each term with Bayes' theorem, using $q(x_t \mid x_{t-1}) = q(x_t \mid x_{t-1}, x_0)$ since the forward process is Markovian.
$D_{KL}\big(q(x_T \mid x_0) \,\|\, p(x_T)\big)$ is constant since we keep $\beta_1, \dots, \beta_T$ constant.
The forward process posterior conditioned on $x_0$ is

$$\begin{aligned}
q(x_{t-1} \mid x_t, x_0) &= \mathcal{N}\big(x_{t-1}; \tilde\mu_t(x_t, x_0), \tilde\beta_t \mathbf{I}\big) \\
\tilde\mu_t(x_t, x_0) &= \frac{\sqrt{\bar\alpha_{t-1}}\, \beta_t}{1-\bar\alpha_t}\, x_0 + \frac{\sqrt{\alpha_t}\,(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}\, x_t \\
\tilde\beta_t &= \frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\, \beta_t
\end{aligned}$$
The paper sets $\Sigma_\theta(x_t, t) = \sigma_t^2 \mathbf{I}$, where $\sigma_t^2$ is set to the constant $\beta_t$ or $\tilde\beta_t$.

Then,

$$p_\theta(x_{t-1} \mid x_t) = \mathcal{N}\big(x_{t-1}; \mu_\theta(x_t, t), \sigma_t^2 \mathbf{I}\big)$$
For a given noise sample $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$, using $q(x_t \mid x_0)$,

$$\begin{aligned}
x_t(x_0, \epsilon) &= \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1-\bar\alpha_t}\, \epsilon \\
x_0 &= \frac{1}{\sqrt{\bar\alpha_t}} \big(x_t(x_0, \epsilon) - \sqrt{1-\bar\alpha_t}\, \epsilon\big)
\end{aligned}$$
This gives

$$\begin{aligned}
L_{t-1} &= D_{KL}\big(q(x_{t-1} \mid x_t, x_0) \,\|\, p_\theta(x_{t-1} \mid x_t)\big) \\
&= \mathbb{E}_q\left[\frac{1}{2\sigma_t^2} \big\|\tilde\mu_t(x_t, x_0) - \mu_\theta(x_t, t)\big\|^2\right] \\
&= \mathbb{E}_{x_0, \epsilon}\left[\frac{1}{2\sigma_t^2} \left\|\frac{1}{\sqrt{\alpha_t}} \Big(x_t(x_0, \epsilon) - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\, \epsilon\Big) - \mu_\theta\big(x_t(x_0, \epsilon), t\big)\right\|^2\right]
\end{aligned}$$
Re-parameterizing with a model to predict noise,

$$\mu_\theta(x_t, t) = \tilde\mu_t\left(x_t, \frac{1}{\sqrt{\bar\alpha_t}} \big(x_t - \sqrt{1-\bar\alpha_t}\, \epsilon_\theta(x_t, t)\big)\right) = \frac{1}{\sqrt{\alpha_t}} \left(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\, \epsilon_\theta(x_t, t)\right)$$

where $\epsilon_\theta$ is a learned function that predicts $\epsilon$ given $(x_t, t)$.
This gives

$$L_{t-1} = \mathbb{E}_{x_0, \epsilon}\left[\frac{\beta_t^2}{2\sigma_t^2\, \alpha_t (1-\bar\alpha_t)} \Big\|\epsilon - \epsilon_\theta\big(\sqrt{\bar\alpha_t}\, x_0 + \sqrt{1-\bar\alpha_t}\, \epsilon,\, t\big)\Big\|^2\right]$$
That is, we are training to predict the noise.

$$L_{\text{simple}}(\theta) = \mathbb{E}_{t, x_0, \epsilon}\left[\Big\|\epsilon - \epsilon_\theta\big(\sqrt{\bar\alpha_t}\, x_0 + \sqrt{1-\bar\alpha_t}\, \epsilon,\, t\big)\Big\|^2\right]$$
This minimizes $-\log p_\theta(x_0 \mid x_1)$ when $t = 1$, and $L_{t-1}$ for $t > 1$, discarding the weighting in $L_{t-1}$. Discarding the weights $\frac{\beta_t^2}{2\sigma_t^2\, \alpha_t (1-\bar\alpha_t)}$ increases the relative weight given to higher $t$ (which have higher noise levels), which the paper found improves sample quality.
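To make this concrete, here is a small numerical sketch (not part of the original implementation) computing the discarded weight with $\sigma_t^2 = \beta_t$ and the schedule values used below. The weight is by far the largest at small $t$ (low noise), so replacing it with $1$ shifts relative emphasis toward the noisier timesteps.

```python
import torch

n_steps = 1000
beta = torch.linspace(0.0001, 0.02, n_steps)
alpha = 1. - beta
alpha_bar = torch.cumprod(alpha, dim=0)
sigma2 = beta  # sigma_t^2 = beta_t

# Weight that L_simple discards: beta_t^2 / (2 sigma_t^2 alpha_t (1 - alpha_bar_t))
weight = beta ** 2 / (2 * sigma2 * alpha * (1 - alpha_bar))

# The weight at t=1 dominates the weight at t=T, so dropping it
# up-weights the high-noise timesteps relative to the low-noise ones
assert weight[0] > 10 * weight[-1]
```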
This file implements the loss calculation and a basic sampling method that we use to generate images during training.
Here is the UNet model that gives $\epsilon_\theta(x_t, t)$, and the training code. This file can generate samples and interpolations from a trained model.
```python
from typing import Tuple, Optional

import torch
import torch.nn.functional as F
import torch.utils.data
from torch import nn

from labml_nn.diffusion.ddpm.utils import gather
```
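`gather` is imported from `labml_nn.diffusion.ddpm.utils`. As a sketch of what it does (assuming 4-D image batches), a minimal equivalent picks the schedule constant for each sample's timestep and reshapes it so it broadcasts over the image dimensions:

```python
import torch

def gather(consts: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Pick consts[t] for each batch element and reshape to (batch, 1, 1, 1)."""
    c = consts.gather(-1, t)
    return c.reshape(-1, 1, 1, 1)

# e.g. selecting alpha_bar values for a batch of timesteps
consts = torch.linspace(0., 1., 10)
t = torch.tensor([0, 4, 9])
out = gather(consts, t)  # shape (3, 1, 1, 1), broadcasts against (3, C, H, W)
```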
```python
class DenoiseDiffusion:
```
* `eps_model` is the $\epsilon_\theta(x_t, t)$ model
* `n_steps` is $T$
* `device` is the device to place constants on

```python
    def __init__(self, eps_model: nn.Module, n_steps: int, device: torch.device):
```
```python
        super().__init__()
        self.eps_model = eps_model
```
Create the linearly increasing variance schedule $\beta_1, \dots, \beta_T$

```python
        self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)
```
$\alpha_t = 1 - \beta_t$

```python
        self.alpha = 1. - self.beta
```
$\bar\alpha_t = \prod_{s=1}^{t} \alpha_s$

```python
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
```
$T$

```python
        self.n_steps = n_steps
```
$\sigma_t^2 = \beta_t$

```python
        self.sigma2 = self.beta
```
Get $q(x_t \mid x_0) = \mathcal{N}\big(x_t; \sqrt{\bar\alpha_t}\, x_0, (1-\bar\alpha_t)\mathbf{I}\big)$

```python
    def q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
```
Gather $\bar\alpha_t$ and compute $\sqrt{\bar\alpha_t}\, x_0$

```python
        mean = gather(self.alpha_bar, t) ** 0.5 * x0
```
$(1-\bar\alpha_t)\mathbf{I}$

```python
        var = 1 - gather(self.alpha_bar, t)
```
```python
        return mean, var
```
Sample from $q(x_t \mid x_0) = \mathcal{N}\big(x_t; \sqrt{\bar\alpha_t}\, x_0, (1-\bar\alpha_t)\mathbf{I}\big)$

```python
    def q_sample(self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None):
```
$\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$

```python
        if eps is None:
            eps = torch.randn_like(x0)
```
Get $q(x_t \mid x_0)$

```python
        mean, var = self.q_xt_x0(x0, t)
```
Sample from $q(x_t \mid x_0)$

```python
        return mean + (var ** 0.5) * eps
```
Sample from

$$\begin{aligned}
p_\theta(x_{t-1} \mid x_t) &= \mathcal{N}\big(x_{t-1}; \mu_\theta(x_t, t), \sigma_t^2 \mathbf{I}\big) \\
\mu_\theta(x_t, t) &= \frac{1}{\sqrt{\alpha_t}} \left(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\, \epsilon_\theta(x_t, t)\right)
\end{aligned}$$

```python
    def p_sample(self, xt: torch.Tensor, t: torch.Tensor):
```
$\epsilon_\theta(x_t, t)$

```python
        eps_theta = self.eps_model(xt, t)
```
Gather $\bar\alpha_t$

```python
        alpha_bar = gather(self.alpha_bar, t)
```
$\alpha_t$

```python
        alpha = gather(self.alpha, t)
```
$\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}$

```python
        eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5
```
$\frac{1}{\sqrt{\alpha_t}} \left(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\, \epsilon_\theta(x_t, t)\right)$

```python
        mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)
```
$\sigma_t^2$

```python
        var = gather(self.sigma2, t)
```
$\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$

```python
        eps = torch.randn(xt.shape, device=xt.device)
```
Sample

```python
        return mean + (var ** .5) * eps
```
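The full generation loop lives outside this class; as a sketch of how `p_sample` would be used (the function name `sample_loop` and its signature are hypothetical, not from the original file), we start from pure noise $x_T \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ and denoise for $T$ steps:

```python
import torch

@torch.no_grad()
def sample_loop(diffusion, shape):
    """Hypothetical helper: run the reverse process from x_T ~ N(0, I)."""
    # Start from pure noise on the same device as the schedule constants
    xt = torch.randn(shape, device=diffusion.beta.device)
    # Remove noise for T steps: t = T-1, ..., 0
    for t_ in reversed(range(diffusion.n_steps)):
        t = xt.new_full((shape[0],), t_, dtype=torch.long)
        xt = diffusion.p_sample(xt, t)
    return xt
```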
Simplified loss

$$L_{\text{simple}}(\theta) = \mathbb{E}_{t, x_0, \epsilon}\left[\Big\|\epsilon - \epsilon_\theta\big(\sqrt{\bar\alpha_t}\, x_0 + \sqrt{1-\bar\alpha_t}\, \epsilon,\, t\big)\Big\|^2\right]$$

```python
    def loss(self, x0: torch.Tensor, noise: Optional[torch.Tensor] = None):
```
Get batch size

```python
        batch_size = x0.shape[0]
```
Get random $t$ for each sample in the batch

```python
        t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)
```
$\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$

```python
        if noise is None:
            noise = torch.randn_like(x0)
```
Sample $x_t$ from $q(x_t \mid x_0)$

```python
        xt = self.q_sample(x0, t, eps=noise)
```
Get $\epsilon_\theta\big(\sqrt{\bar\alpha_t}\, x_0 + \sqrt{1-\bar\alpha_t}\, \epsilon,\, t\big)$

```python
        eps_theta = self.eps_model(xt, t)
```
MSE loss

```python
        return F.mse_loss(noise, eps_theta)
```
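Putting it together, here is a minimal self-contained sketch of one training step on random data. `TinyEpsModel` is a hypothetical stand-in for the UNet (it ignores $t$, which a real $\epsilon_\theta$ must not), and the noising step inlines `q_sample` for illustration.

```python
import torch
import torch.nn.functional as F
from torch import nn

class TinyEpsModel(nn.Module):
    """Hypothetical stand-in for the UNet eps_model; ignores the timestep t."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3, padding=1)

    def forward(self, xt, t):
        return self.conv(xt)

n_steps = 100
beta = torch.linspace(0.0001, 0.02, n_steps)
alpha_bar = torch.cumprod(1. - beta, dim=0)

model = TinyEpsModel()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

x0 = torch.randn(4, 1, 8, 8)                 # a batch of toy "images"
t = torch.randint(0, n_steps, (4,))          # random timestep per sample
noise = torch.randn_like(x0)                 # eps ~ N(0, I)

# q_sample: x_t = sqrt(alpha_bar_t) x_0 + sqrt(1 - alpha_bar_t) eps
ab = alpha_bar[t].reshape(-1, 1, 1, 1)
xt = ab.sqrt() * x0 + (1 - ab).sqrt() * noise

# L_simple: MSE between the true noise and the predicted noise
loss = F.mse_loss(noise, model(xt, t))
opt.zero_grad()
loss.backward()
opt.step()
```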