For a simpler DDPM implementation, refer to our DDPM implementation. We use the same notation for the $\alpha_t$, $\beta_t$ schedules, etc.
```python
from typing import Optional, List

import numpy as np
import torch

from labml import monit
from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
from labml_nn.diffusion.stable_diffusion.sampler import DiffusionSampler
```
This extends the `DiffusionSampler` base class.
DDPM samples images by repeatedly removing noise, sampling step by step from $p_\theta(x_{t-1} | x_t)$:

$$p_\theta(x_{t-1} | x_t) = \mathcal{N}\big(x_{t-1};\ \mu_\theta(x_t, t),\ \tilde\beta_t \mathbf{I}\big)$$

$$\mu_t(x_t, t) = \frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1 - \bar\alpha_t} x_0 + \frac{\sqrt{\alpha_t}\,(1 - \bar\alpha_{t-1})}{1 - \bar\alpha_t} x_t$$

$$\tilde\beta_t = \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t}\,\beta_t$$

$$x_0 = \frac{1}{\sqrt{\bar\alpha_t}} x_t - \sqrt{\frac{1}{\bar\alpha_t} - 1}\;\epsilon_\theta$$
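As a sanity check on the derivation above, one can verify numerically that plugging the $\epsilon_\theta$-based estimate of $x_0$ into $\mu_t(x_t, t)$ reproduces the familiar DDPM update $\frac{1}{\sqrt{\alpha_t}}\big(x_t - \frac{\beta_t}{\sqrt{1 - \bar\alpha_t}}\epsilon_\theta\big)$. The linear $\beta$ schedule and toy tensors below are assumptions for illustration only:

```python
import numpy as np

# Toy linear beta schedule (an assumption for illustration; Stable Diffusion
# uses a scaled-linear schedule over 1000 steps).
T = 100
beta = np.linspace(1e-4, 2e-2, T)
alpha = 1. - beta
alpha_bar = np.cumprod(alpha)
alpha_bar_prev = np.concatenate(([1.], alpha_bar[:-1]))

t = 42
rng = np.random.default_rng(0)
x_t = rng.standard_normal(4)
eps = rng.standard_normal(4)

# x_0 estimated from the noise prediction
x0 = x_t / np.sqrt(alpha_bar[t]) - np.sqrt(1 / alpha_bar[t] - 1) * eps

# Posterior mean via the two coefficients the sampler precomputes
mean_x0_coef = beta[t] * np.sqrt(alpha_bar_prev[t]) / (1. - alpha_bar[t])
mean_xt_coef = (1. - alpha_bar_prev[t]) * np.sqrt(alpha[t]) / (1. - alpha_bar[t])
mean = mean_x0_coef * x0 + mean_xt_coef * x_t

# The same mean written directly in terms of eps (the familiar DDPM update)
mean_direct = (x_t - beta[t] / np.sqrt(1. - alpha_bar[t]) * eps) / np.sqrt(alpha[t])

assert np.allclose(mean, mean_direct)
```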
```python
class DDPMSampler(DiffusionSampler):
```
`model` is the model to predict noise $\epsilon_\text{cond}(x_t, c)$

```python
    model: LatentDiffusion

    def __init__(self, model: LatentDiffusion):
        super().__init__(model)
```
Sampling steps $1, 2, \dots, T$
```python
        self.time_steps = np.asarray(list(range(self.n_steps)))

        with torch.no_grad():
```
$\bar\alpha_t$

```python
            alpha_bar = self.model.alpha_bar
```

$\beta_t$ schedule

```python
            beta = self.model.beta
```

$\bar\alpha_{t-1}$

```python
            alpha_bar_prev = torch.cat([alpha_bar.new_tensor([1.]), alpha_bar[:-1]])
```
$\sqrt{\bar\alpha}$

```python
            self.sqrt_alpha_bar = alpha_bar ** .5
```

$\sqrt{1 - \bar\alpha}$

```python
            self.sqrt_1m_alpha_bar = (1. - alpha_bar) ** .5
```

$\frac{1}{\sqrt{\bar\alpha_t}}$

```python
            self.sqrt_recip_alpha_bar = alpha_bar ** -.5
```

$\sqrt{\frac{1}{\bar\alpha_t} - 1}$

```python
            self.sqrt_recip_m1_alpha_bar = (1 / alpha_bar - 1) ** .5
```
$\frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t}\,\beta_t$

```python
            variance = beta * (1. - alpha_bar_prev) / (1. - alpha_bar)
```

Clamped log of $\tilde\beta_t$

```python
            self.log_var = torch.log(torch.clamp(variance, min=1e-20))
```
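The clamp matters at the first step: since $\bar\alpha_0^{\text{prev}} = 1$, the posterior variance there is exactly zero and `log(0)` would be `-inf`. A minimal sketch with a toy linear $\beta$ schedule (an assumption for illustration):

```python
import torch

# Why the variance is clamped before taking the log: alpha_bar_prev[0] == 1,
# so the posterior variance at the first step is exactly zero.
beta = torch.linspace(1e-4, 2e-2, 100)
alpha_bar = torch.cumprod(1. - beta, dim=0)
alpha_bar_prev = torch.cat([alpha_bar.new_tensor([1.]), alpha_bar[:-1]])

variance = beta * (1. - alpha_bar_prev) / (1. - alpha_bar)
assert variance[0].item() == 0.         # log() of this would be -inf
log_var = torch.log(torch.clamp(variance, min=1e-20))
assert torch.isfinite(log_var).all()    # the clamp keeps everything finite
```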
$\frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1 - \bar\alpha_t}$

```python
            self.mean_x0_coef = beta * (alpha_bar_prev ** .5) / (1. - alpha_bar)
```

$\frac{\sqrt{\alpha_t}\,(1 - \bar\alpha_{t-1})}{1 - \bar\alpha_t}$

```python
            self.mean_xt_coef = (1. - alpha_bar_prev) * ((1 - beta) ** 0.5) / (1. - alpha_bar)
```
* `shape` is the shape of the generated images in the form `[batch_size, channels, height, width]`
* `cond` is the conditional embeddings $c$
* `temperature` is the noise temperature (random noise gets multiplied by this)
* `x_last` is $x_T$. If not provided, random noise will be used.
* `uncond_scale` is the unconditional guidance scale $s$. This is used for $\epsilon_\theta(x_t, c) = s\,\epsilon_\text{cond}(x_t, c) - (s - 1)\,\epsilon_\text{cond}(x_t, c_u)$
* `uncond_cond` is the conditional embedding for empty prompt $c_u$
* `skip_steps` is the number of time steps to skip, $t'$. We start sampling from $T - t'$, and `x_last` is then $x_{T - t'}$.

```python
    @torch.no_grad()
    def sample(self,
               shape: List[int],
               cond: torch.Tensor,
               repeat_noise: bool = False,
               temperature: float = 1.,
               x_last: Optional[torch.Tensor] = None,
               uncond_scale: float = 1.,
               uncond_cond: Optional[torch.Tensor] = None,
               skip_steps: int = 0,
               ):
```
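The guidance formula above can be checked against the more common way of writing classifier-free guidance, $\epsilon_u + s(\epsilon_c - \epsilon_u)$; the two are algebraically identical. Random tensors stand in for the two noise predictions here (an assumption for illustration — in the sampler they come from the model):

```python
import torch

torch.manual_seed(0)
e_cond = torch.randn(2, 4, 8, 8)    # stand-in for eps_cond(x_t, c)
e_uncond = torch.randn(2, 4, 8, 8)  # stand-in for eps_cond(x_t, c_u)
s = 7.5                             # a typical guidance scale

# The form used in the docstring above
e1 = s * e_cond - (s - 1) * e_uncond
# The "interpolate away from the unconditional prediction" form
e2 = e_uncond + s * (e_cond - e_uncond)
assert torch.allclose(e1, e2, atol=1e-5)
```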
Get device and batch size
```python
        device = self.model.device
        bs = shape[0]
```
Get $x_T$

```python
        x = x_last if x_last is not None else torch.randn(shape, device=device)
```
Time steps to sample at $T - t', T - t' - 1, \dots, 1$

```python
        time_steps = np.flip(self.time_steps)[skip_steps:]
```
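With toy values, the flip-and-slice above produces the expected descending step order (`T = 10` and `skip_steps = 3` are illustrative assumptions):

```python
import numpy as np

time_steps = np.asarray(list(range(10)))   # toy T = 10
skipped = np.flip(time_steps)[3:]          # skip the 3 noisiest steps
assert skipped.tolist() == [6, 5, 4, 3, 2, 1, 0]
```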
Sampling loop
```python
        for step in monit.iterate('Sample', time_steps):
```
Time step $t$

```python
            ts = x.new_full((bs,), step, dtype=torch.long)
```
Sample $x_{t-1}$

```python
            x, pred_x0, e_t = self.p_sample(x, cond, ts, step,
                                            repeat_noise=repeat_noise,
                                            temperature=temperature,
                                            uncond_scale=uncond_scale,
                                            uncond_cond=uncond_cond)
```
Return $x_0$

```python
        return x
```
* `x` is $x_t$ of shape `[batch_size, channels, height, width]`
* `c` is the conditional embeddings $c$ of shape `[batch_size, emb_size]`
* `t` is $t$ of shape `[batch_size]`
* `step` is the step $t$ as an integer
* `repeat_noise` specifies whether the noise should be the same for all samples in the batch
* `temperature` is the noise temperature (random noise gets multiplied by this)
* `uncond_scale` is the unconditional guidance scale $s$. This is used for $\epsilon_\theta(x_t, c) = s\,\epsilon_\text{cond}(x_t, c) - (s - 1)\,\epsilon_\text{cond}(x_t, c_u)$
* `uncond_cond` is the conditional embedding for empty prompt $c_u$

```python
    @torch.no_grad()
    def p_sample(self, x: torch.Tensor, c: torch.Tensor, t: torch.Tensor, step: int,
                 repeat_noise: bool = False,
                 temperature: float = 1.,
                 uncond_scale: float = 1., uncond_cond: Optional[torch.Tensor] = None):
```
Get $\epsilon_\theta$

```python
        e_t = self.get_eps(x, t, c,
                           uncond_scale=uncond_scale,
                           uncond_cond=uncond_cond)
```
Get batch size
```python
        bs = x.shape[0]
```
$\frac{1}{\sqrt{\bar\alpha_t}}$

```python
        sqrt_recip_alpha_bar = x.new_full((bs, 1, 1, 1), self.sqrt_recip_alpha_bar[step])
```

$\sqrt{\frac{1}{\bar\alpha_t} - 1}$

```python
        sqrt_recip_m1_alpha_bar = x.new_full((bs, 1, 1, 1), self.sqrt_recip_m1_alpha_bar[step])
```
Calculate $x_0$ with the current $\epsilon_\theta$:

$$x_0 = \frac{1}{\sqrt{\bar\alpha_t}} x_t - \sqrt{\frac{1}{\bar\alpha_t} - 1}\;\epsilon_\theta$$

```python
        x0 = sqrt_recip_alpha_bar * x - sqrt_recip_m1_alpha_bar * e_t
```
$\frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1 - \bar\alpha_t}$

```python
        mean_x0_coef = x.new_full((bs, 1, 1, 1), self.mean_x0_coef[step])
```

$\frac{\sqrt{\alpha_t}\,(1 - \bar\alpha_{t-1})}{1 - \bar\alpha_t}$

```python
        mean_xt_coef = x.new_full((bs, 1, 1, 1), self.mean_xt_coef[step])
```
Calculate $\mu_t(x_t, t)$:

$$\mu_t(x_t, t) = \frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1 - \bar\alpha_t} x_0 + \frac{\sqrt{\alpha_t}\,(1 - \bar\alpha_{t-1})}{1 - \bar\alpha_t} x_t$$

```python
        mean = mean_x0_coef * x0 + mean_xt_coef * x
```
$\log \tilde\beta_t$

```python
        log_var = x.new_full((bs, 1, 1, 1), self.log_var[step])
```
Do not add noise when $t = 1$ (the final step of the sampling process). Note that `step` is $0$ when $t = 1$.

```python
        if step == 0:
            noise = 0
```
If the same noise is used for all samples in the batch,

```python
        elif repeat_noise:
            noise = torch.randn((1, *x.shape[1:]), device=x.device)
```

Different noise for each sample

```python
        else:
            noise = torch.randn(x.shape, device=x.device)
```
Multiply noise by the temperature
```python
        noise = noise * temperature
```
Sample from

$$p_\theta(x_{t-1} | x_t) = \mathcal{N}\big(x_{t-1};\ \mu_\theta(x_t, t),\ \tilde\beta_t \mathbf{I}\big)$$

```python
        x_prev = mean + (0.5 * log_var).exp() * noise
```
```python
        return x_prev, x0, e_t
```
Get samples from

$$q(x_t | x_0) = \mathcal{N}\big(x_t;\ \sqrt{\bar\alpha_t}\, x_0,\ (1 - \bar\alpha_t) \mathbf{I}\big)$$

* `x0` is $x_0$ of shape `[batch_size, channels, height, width]`
* `index` is the time step $t$ index
* `noise` is the noise, $\epsilon$

```python
    @torch.no_grad()
    def q_sample(self, x0: torch.Tensor, index: int, noise: Optional[torch.Tensor] = None):
```
Random noise, if `noise` is not specified

```python
        if noise is None:
            noise = torch.randn_like(x0)
```
Sample from $\mathcal{N}\big(x_t;\ \sqrt{\bar\alpha_t}\, x_0,\ (1 - \bar\alpha_t) \mathbf{I}\big)$

```python
        return self.sqrt_alpha_bar[index] * x0 + self.sqrt_1m_alpha_bar[index] * noise
```
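A quick round trip ties `q_sample` to the $x_0$ formula in `p_sample`: noising $x_0$ with a known $\epsilon$ and then applying the reconstruction coefficients recovers $x_0$ exactly when $\epsilon_\theta$ equals the true noise. The toy schedule and tensors below are assumptions for illustration:

```python
import torch

torch.manual_seed(0)
# Toy linear beta schedule (illustrative; not the Stable Diffusion schedule)
beta = torch.linspace(1e-4, 2e-2, 100)
alpha_bar = torch.cumprod(1. - beta, dim=0)

t = 42
x0 = torch.randn(1, 3, 4, 4)
eps = torch.randn_like(x0)

# q(x_t | x_0) sample, as in q_sample
x_t = alpha_bar[t] ** .5 * x0 + (1. - alpha_bar[t]) ** .5 * eps

# Recover x_0 with the coefficients used in p_sample
x0_rec = alpha_bar[t] ** -.5 * x_t - (1 / alpha_bar[t] - 1) ** .5 * eps
assert torch.allclose(x0_rec, x0, atol=1e-5)
```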