
Denoising Diffusion Probabilistic Models (DDPM) evaluation/sampling

This is the code to generate images and create interpolations between given images.

import numpy as np
import torch
from matplotlib import pyplot as plt
from torchvision.transforms.functional import to_pil_image, resize

from labml import experiment, monit
from labml_nn.diffusion.ddpm import DenoiseDiffusion, gather
from labml_nn.diffusion.ddpm.experiment import Configs

Sampler class

class Sampler:

• diffusion is the DenoiseDiffusion instance
• image_channels is the number of channels in the image
• image_size is the image size
• device is the device of the model

    def __init__(self, diffusion: DenoiseDiffusion, image_channels: int, image_size: int, device: torch.device):
        self.device = device
        self.image_size = image_size
        self.image_channels = image_channels
        self.diffusion = diffusion

$T$

        self.n_steps = diffusion.n_steps

$\epsilon_\theta(x_t, t)$

        self.eps_model = diffusion.eps_model

$\beta_t$

        self.beta = diffusion.beta

$\alpha_t$

        self.alpha = diffusion.alpha

$\bar{\alpha}_t$

        self.alpha_bar = diffusion.alpha_bar

$\bar{\alpha}_{t-1}$

        alpha_bar_tm1 = torch.cat([self.alpha_bar.new_ones((1,)), self.alpha_bar[:-1]])

To calculate

$$\begin{aligned}
q(x_{t-1} \mid x_t, x_0) &= \mathcal{N}\big(x_{t-1}; \tilde{\mu}_t(x_t, x_0), \tilde{\beta}_t \mathbf{I}\big) \\
\tilde{\mu}_t(x_t, x_0) &= \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1 - \bar{\alpha}_t} x_0
  + \frac{\sqrt{\alpha_t}\,(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t} x_t \\
\tilde{\beta}_t &= \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \beta_t
\end{aligned}$$
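These follow from applying Bayes' rule to the Gaussian forward process; a short sketch of the variance term:

$$q(x_{t-1} \mid x_t, x_0) \propto q(x_t \mid x_{t-1})\, q(x_{t-1} \mid x_0)$$

and completing the square in $x_{t-1}$ gives a Gaussian whose precision is

$$\frac{1}{\tilde{\beta}_t} = \frac{\alpha_t}{\beta_t} + \frac{1}{1 - \bar{\alpha}_{t-1}} = \frac{1 - \bar{\alpha}_t}{\beta_t (1 - \bar{\alpha}_{t-1})}$$

which is the $\tilde{\beta}_t$ above; the same algebra gives the two coefficients of $x_0$ and $x_t$ in $\tilde{\mu}_t$, computed next.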

$\tilde{\beta}_t = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \beta_t$

        self.beta_tilde = self.beta * (1 - alpha_bar_tm1) / (1 - self.alpha_bar)

$\frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1 - \bar{\alpha}_t}$

        self.mu_tilde_coef1 = self.beta * (alpha_bar_tm1 ** 0.5) / (1 - self.alpha_bar)

$\frac{\sqrt{\alpha_t}\,(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t}$

        self.mu_tilde_coef2 = (self.alpha ** 0.5) * (1 - alpha_bar_tm1) / (1 - self.alpha_bar)

$\sigma^2 = \beta$

        self.sigma2 = self.beta

Helper function to display an image

    def show_image(self, img, title=""):
        img = img.clip(0, 1)
        img = img.cpu().numpy()
        plt.imshow(img.transpose(1, 2, 0))
        plt.title(title)
        plt.show()

Helper function to create a video

    def make_video(self, frames, path="video.mp4"):
        import imageio

20 second video

        writer = imageio.get_writer(path, fps=len(frames) // 20)

Add each image

        for f in frames:
            f = f.clip(0, 1)
            f = to_pil_image(resize(f, [368, 368]))
            writer.append_data(np.array(f))

        writer.close()

Sample an image step-by-step using $p_\theta(x_{t-1} \mid x_t)$

We sample an image step-by-step using $p_\theta(x_{t-1} \mid x_t)$, and at each step show the estimate

$$x_0 \approx \hat{x}_0 = \frac{1}{\sqrt{\bar{\alpha}_t}} \Big(x_t - \sqrt{1 - \bar{\alpha}_t}\, \epsilon_\theta(x_t, t)\Big)$$

    def sample_animation(self, n_frames: int = 1000, create_video: bool = True):

$x_T \sim p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$

        xt = torch.randn([1, self.image_channels, self.image_size, self.image_size], device=self.device)

Interval to log $\hat{x}_0$

        interval = self.n_steps // n_frames

Frames for video

        frames = []

Sample $T$ steps

        for t_inv in monit.iterate('Denoise', self.n_steps):

$t$

            t_ = self.n_steps - t_inv - 1

$t$ in a tensor

            t = xt.new_full((1,), t_, dtype=torch.long)

$\epsilon_\theta(x_t, t)$

            eps_theta = self.eps_model(xt, t)
            if t_ % interval == 0:

Get $\hat{x}_0$ and add to frames

                x0 = self.p_x0(xt, t, eps_theta)
                frames.append(x0[0])
                if not create_video:
                    self.show_image(x0[0], f"{t_}")

Sample from $p_\theta(x_{t-1} \mid x_t)$

            xt = self.p_sample(xt, t, eps_theta)

Make video

        if create_video:
            self.make_video(frames)

Interpolate two images $x_0$ and $x_0'$

We get $x_t \sim q(x_t \mid x_0)$ and $x_t' \sim q(x_t' \mid x_0')$.

Then interpolate to $\bar{x}_t = (1 - \lambda) x_t + \lambda x_t'$

Then get $\bar{x}_0 \sim p_\theta(x_0 \mid \bar{x}_t)$

• x1 is $x_0$
• x2 is $x_0'$
• lambda_ is $\lambda$
• t_ is $t$

    def interpolate(self, x1: torch.Tensor, x2: torch.Tensor, lambda_: float, t_: int = 100):

Number of samples

        n_samples = x1.shape[0]

$t$ tensor

        t = torch.full((n_samples,), t_, device=self.device)

$\bar{x}_t = (1 - \lambda) x_t + \lambda x_t'$

        xt = (1 - lambda_) * self.diffusion.q_sample(x1, t) + lambda_ * self.diffusion.q_sample(x2, t)

$\bar{x}_0 \sim p_\theta(x_0 \mid \bar{x}_t)$

        return self._sample_x0(xt, t_)
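A minimal usage sketch, assuming a sampler and a data batch obtained as in main() below, with the batch dimension kept on the two images:

# Hypothetical example: blend two dataset images half-way in the noised space.
x1 = data[0:1]   # keep the batch dimension: shape [1, image_channels, image_size, image_size]
x2 = data[1:2]
x0_mix = sampler.interpolate(x1, x2, lambda_=0.5, t_=100)
sampler.show_image(x0_mix[0], "interpolation")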

Interpolate two images $x_0$ and $x_0'$ and make a video

• x1 is $x_0$
• x2 is $x_0'$
• n_frames is the number of frames for the animation
• t_ is $t$
• create_video specifies whether to make a video or to show each frame

    def interpolate_animate(self, x1: torch.Tensor, x2: torch.Tensor, n_frames: int = 100, t_: int = 100,
                            create_video=True):

Show original images

        self.show_image(x1, "x1")
        self.show_image(x2, "x2")

Add batch dimension

        x1 = x1[None, :, :, :]
        x2 = x2[None, :, :, :]

$t$ tensor

        t = torch.full((1,), t_, device=self.device)

$x_t \sim q(x_t \mid x_0)$

        x1t = self.diffusion.q_sample(x1, t)

$x_t' \sim q(x_t' \mid x_0')$

        x2t = self.diffusion.q_sample(x2, t)

        frames = []

Get frames with different $\lambda$

        for i in monit.iterate('Interpolate', n_frames + 1, is_children_silent=True):

$\lambda$

            lambda_ = i / n_frames

$\bar{x}_t = (1 - \lambda) x_t + \lambda x_t'$

            xt = (1 - lambda_) * x1t + lambda_ * x2t

$\bar{x}_0 \sim p_\theta(x_0 \mid \bar{x}_t)$

            x0 = self._sample_x0(xt, t_)

Add to frames

            frames.append(x0[0])

Show frame

            if not create_video:
                self.show_image(x0[0], f"{lambda_ :.2f}")

Make video

        if create_video:
            self.make_video(frames)

Sample an image using $p_\theta(x_{t-1} \mid x_t)$

• xt is $x_t$
• n_steps is $t$

    def _sample_x0(self, xt: torch.Tensor, n_steps: int):

Number of samples

        n_samples = xt.shape[0]

Iterate until $t$ steps

        for t_ in monit.iterate('Denoise', n_steps):
            t = n_steps - t_ - 1

Sample from $p_\theta(x_{t-1} \mid x_t)$

            xt = self.diffusion.p_sample(xt, xt.new_full((n_samples,), t, dtype=torch.long))

Return $x_0$

        return xt

Generate images

    def sample(self, n_samples: int = 16):

$x_T \sim p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$

        xt = torch.randn([n_samples, self.image_channels, self.image_size, self.image_size], device=self.device)

$x_0 \sim p_\theta(x_0 \mid x_T)$

        x0 = self._sample_x0(xt, self.n_steps)

Show images

        for i in range(n_samples):
            self.show_image(x0[i])

Sample from $p_\theta(x_{t-1} \mid x_t)$

$$\begin{aligned}
p_\theta(x_{t-1} \mid x_t) &= \mathcal{N}\big(x_{t-1}; \mu_\theta(x_t, t), \sigma_t^2 \mathbf{I}\big) \\
\mu_\theta(x_t, t) &= \frac{1}{\sqrt{\alpha_t}} \Big(x_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_\theta(x_t, t)\Big)
\end{aligned}$$
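This mean is the posterior mean $\tilde{\mu}_t(x_t, x_0)$ from above with $x_0$ replaced by the noise-based estimate $\hat{x}_0 = \frac{1}{\sqrt{\bar{\alpha}_t}}\big(x_t - \sqrt{1 - \bar{\alpha}_t}\,\epsilon_\theta(x_t, t)\big)$; substituting and simplifying with $\beta_t + \alpha_t(1 - \bar{\alpha}_{t-1}) = 1 - \bar{\alpha}_t$ gives the $\frac{1}{\sqrt{\alpha_t}}$ form, and the simple DDPM variant uses $\sigma_t^2 = \beta_t$ rather than $\tilde{\beta}_t$.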

    def p_sample(self, xt: torch.Tensor, t: torch.Tensor, eps_theta: torch.Tensor):

gather $\bar{\alpha}_t$

        alpha_bar = gather(self.alpha_bar, t)

$\alpha_t$

        alpha = gather(self.alpha, t)

$\frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}}$

        eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5

$\frac{1}{\sqrt{\alpha_t}} \Big(x_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_\theta(x_t, t)\Big)$

        mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)

$\sigma^2$

        var = gather(self.sigma2, t)

$\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$

        eps = torch.randn(xt.shape, device=xt.device)

Sample

        return mean + (var ** .5) * eps

Estimate $x_0$

$$x_0 \approx \hat{x}_0 = \frac{1}{\sqrt{\bar{\alpha}_t}} \Big(x_t - \sqrt{1 - \bar{\alpha}_t}\, \epsilon_\theta(x_t, t)\Big)$$
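This is the forward process $x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon$ solved for $x_0$, with the predicted noise $\epsilon_\theta(x_t, t)$ substituted for $\epsilon$.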

    def p_x0(self, xt: torch.Tensor, t: torch.Tensor, eps: torch.Tensor):

gather $\bar{\alpha}_t$

        alpha_bar = gather(self.alpha_bar, t)

$x_0 \approx \hat{x}_0 = \frac{1}{\sqrt{\bar{\alpha}_t}} \Big(x_t - \sqrt{1 - \bar{\alpha}_t}\, \epsilon_\theta(x_t, t)\Big)$

        return (xt - (1 - alpha_bar) ** 0.5 * eps) / (alpha_bar ** 0.5)

Generate samples

def main():

Training experiment run UUID

    run_uuid = "a44333ea251411ec8007d1a1762ed686"

Start an evaluation

    experiment.evaluate()

Create configs

    configs = Configs()

Load custom configuration of the training run

    configs_dict = experiment.load_configs(run_uuid)

Set configurations

    experiment.configs(configs, configs_dict)

Initialize

    configs.init()

Set PyTorch modules for saving and loading

    experiment.add_pytorch_models({'eps_model': configs.eps_model})

Load training experiment

    experiment.load(run_uuid)

Create sampler

    sampler = Sampler(diffusion=configs.diffusion,
                      image_channels=configs.image_channels,
                      image_size=configs.image_size,
                      device=configs.device)

Start evaluation

    with experiment.start():

No gradients

        with torch.no_grad():

Sample an image with a denoising animation

            sampler.sample_animation()

            if False:

Get some images from data

                data = next(iter(configs.data_loader)).to(configs.device)

Create an interpolation animation

                sampler.interpolate_animate(data[0], data[1])

if __name__ == '__main__':
    main()
