This is the code to generate images and create interpolations between given images.
```python
import numpy as np
import torch
from matplotlib import pyplot as plt
from torchvision.transforms.functional import to_pil_image, resize

from labml import experiment, monit
from labml_nn.diffusion.ddpm import DenoiseDiffusion, gather
from labml_nn.diffusion.ddpm.experiment import Configs
```
```python
class Sampler:
```
* `diffusion` is the `DenoiseDiffusion` instance
* `image_channels` is the number of channels in the image
* `image_size` is the image size
* `device` is the device of the model

```python
    def __init__(self, diffusion: DenoiseDiffusion, image_channels: int, image_size: int, device: torch.device):
        self.device = device
        self.image_size = image_size
        self.image_channels = image_channels
        self.diffusion = diffusion

        # T, the total number of diffusion steps
        self.n_steps = diffusion.n_steps
        # The noise-prediction model ε_θ(x_t, t)
        self.eps_model = diffusion.eps_model
        # β_t
        self.beta = diffusion.beta
        # α_t
        self.alpha = diffusion.alpha
        # ᾱ_t
        self.alpha_bar = diffusion.alpha_bar
        # ᾱ_{t-1}: ᾱ shifted right by one step, with ᾱ_0 = 1 prepended
        alpha_bar_tm1 = torch.cat([self.alpha_bar.new_ones((1,)), self.alpha_bar[:-1]])
```
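As a quick illustration of the shift above (a standalone sketch with made-up ᾱ values, not from a trained schedule):

```python
import torch

# Hypothetical ᾱ for T = 3 steps
alpha_bar = torch.tensor([0.9, 0.7, 0.4])
# Prepend ᾱ_0 = 1 and drop the last entry to get ᾱ_{t-1}
alpha_bar_tm1 = torch.cat([alpha_bar.new_ones((1,)), alpha_bar[:-1]])
print(alpha_bar_tm1)  # tensor([1.0000, 0.9000, 0.7000])
```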
To calculate $q(x_{t-1} \vert x_t, x_0)$:

$$\begin{aligned}
q(x_{t-1} \vert x_t, x_0) &= \mathcal{N}\Big(x_{t-1};\ \tilde\mu_t(x_t, x_0),\ \tilde\beta_t \mathbf{I}\Big) \\
\tilde\mu_t(x_t, x_0) &= \frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1 - \bar\alpha_t} x_0
  + \frac{\sqrt{\alpha_t}\,(1 - \bar\alpha_{t-1})}{1 - \bar\alpha_t} x_t \\
\tilde\beta_t &= \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t} \beta_t
\end{aligned}$$
```python
        # β̃_t = β_t (1 - ᾱ_{t-1}) / (1 - ᾱ_t)
        self.beta_tilde = self.beta * (1 - alpha_bar_tm1) / (1 - self.alpha_bar)
        # √ᾱ_{t-1} β_t / (1 - ᾱ_t), the coefficient of x_0 in μ̃_t
        self.mu_tilde_coef1 = self.beta * (alpha_bar_tm1 ** 0.5) / (1 - self.alpha_bar)
        # √α_t (1 - ᾱ_{t-1}) / (1 - ᾱ_t), the coefficient of x_t in μ̃_t
        self.mu_tilde_coef2 = (self.alpha ** 0.5) * (1 - alpha_bar_tm1) / (1 - self.alpha_bar)
        # σ² = β
        self.sigma2 = self.beta
```
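For intuition, here is a minimal standalone sketch, assuming the linear β schedule of the DDPM paper (β from 10⁻⁴ to 0.02 over T = 1000 steps; the trained run may use a different schedule). It computes β̃_t and shows that the ratio β̃_t / β_t approaches 1 as t grows, which is why σ² = β is a reasonable choice of sampling variance:

```python
import torch

T = 1000
beta = torch.linspace(1e-4, 0.02, T)  # assumed linear schedule
alpha = 1. - beta
alpha_bar = torch.cumprod(alpha, dim=0)
alpha_bar_tm1 = torch.cat([alpha_bar.new_ones((1,)), alpha_bar[:-1]])

# Posterior variance β̃_t = β_t (1 - ᾱ_{t-1}) / (1 - ᾱ_t)
beta_tilde = beta * (1 - alpha_bar_tm1) / (1 - alpha_bar)
# The ratio β̃_t / β_t tends to 1 for large t
print((beta_tilde[10] / beta[10]).item(), (beta_tilde[-1] / beta[-1]).item())
```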
Helper function to display an image.

```python
    def show_image(self, img, title=""):
        img = img.clip(0, 1)
        img = img.cpu().numpy()
        plt.imshow(img.transpose(1, 2, 0))
        plt.title(title)
        plt.show()
```
Helper function to create a video from a list of frames.

```python
    def make_video(self, frames, path="video.mp4"):
        import imageio
        # Pick the frame rate so that the video is 20 seconds long
        writer = imageio.get_writer(path, fps=len(frames) // 20)
        # Add each image
        for f in frames:
            f = f.clip(0, 1)
            f = to_pil_image(resize(f, [368, 368]))
            writer.append_data(np.array(f))
        writer.close()
```
We sample an image step-by-step using $p_\theta(x_{t-1} \vert x_t)$, and at each step show the estimate

$$x_0 \approx \hat{x}_0 = \frac{1}{\sqrt{\bar\alpha_t}} \Big( x_t - \sqrt{1 - \bar\alpha_t}\, \epsilon_\theta(x_t, t) \Big)$$
```python
    def sample_animation(self, n_frames: int = 1000, create_video: bool = True):
        # x_T ~ p(x_T) = N(x_T; 0, I)
        xt = torch.randn([1, self.image_channels, self.image_size, self.image_size], device=self.device)
        # Interval at which to log x̂_0
        interval = self.n_steps // n_frames
        # Frames for the video
        frames = []
        # Sample T steps
        for t_inv in monit.iterate('Denoise', self.n_steps):
            # t
            t_ = self.n_steps - t_inv - 1
            # t as a tensor
            t = xt.new_full((1,), t_, dtype=torch.long)
            # ε_θ(x_t, t)
            eps_theta = self.eps_model(xt, t)
            if t_ % interval == 0:
                # Get x̂_0 and add it to the frames
                x0 = self.p_x0(xt, t, eps_theta)
                frames.append(x0[0])
                if not create_video:
                    self.show_image(x0[0], f"{t_}")
            # Sample from p_θ(x_{t-1}|x_t)
            xt = self.p_sample(xt, t, eps_theta)
        # Make a video from the frames
        if create_video:
            self.make_video(frames)
```
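For example (hypothetical usage; a `Sampler` instance is constructed in `main()` below):

```python
# Show ten x̂_0 estimates as still images instead of rendering a video
sampler.sample_animation(n_frames=10, create_video=False)
```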
We get $x_t \sim q(x_t \vert x_0)$ and $x'_t \sim q(x'_t \vert x'_0)$.

Then we interpolate to

$$\bar{x}_t = (1 - \lambda) x_t + \lambda x'_t$$

Then we get

$$\bar{x}_0 \sim p_\theta(x_0 \vert \bar{x}_t)$$
* `x1` is $x_0$
* `x2` is $x'_0$
* `lambda_` is $\lambda$
* `t_` is $t$

```python
    def interpolate(self, x1: torch.Tensor, x2: torch.Tensor, lambda_: float, t_: int = 100):
        # Number of samples
        n_samples = x1.shape[0]
        # t tensor
        t = torch.full((n_samples,), t_, device=self.device)
        # x̄_t = (1 - λ) x_t + λ x'_t
        xt = (1 - lambda_) * self.diffusion.q_sample(x1, t) + lambda_ * self.diffusion.q_sample(x2, t)
        # Sample x̄_0 ~ p_θ(x_0|x̄_t)
        return self._sample_x0(xt, t_)
```
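A minimal usage sketch (assuming `sampler` is an initialized `Sampler` and `data` is a batch of images from the data loader, as in `main()` below):

```python
# Interpolate halfway between two images, after diffusing both to step t = 100
x1, x2 = data[0:1], data[1:2]  # slices keep the batch dimension
x_mix = sampler.interpolate(x1, x2, lambda_=0.5, t_=100)
sampler.show_image(x_mix[0], "λ = 0.5")
```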
* `x1` is $x_0$
* `x2` is $x'_0$
* `n_frames` is the number of frames for the image
* `t_` is $t$
* `create_video` specifies whether to make a video or to show each frame

```python
    def interpolate_animate(self, x1: torch.Tensor, x2: torch.Tensor, n_frames: int = 100, t_: int = 100,
                            create_video=True):
        # Show the original images
        self.show_image(x1, "x1")
        self.show_image(x2, "x2")
        # Add a batch dimension
        x1 = x1[None, :, :, :]
        x2 = x2[None, :, :, :]
        # t tensor
        t = torch.full((1,), t_, device=self.device)
        # x_t ~ q(x_t|x_0)
        x1t = self.diffusion.q_sample(x1, t)
        # x'_t ~ q(x'_t|x'_0)
        x2t = self.diffusion.q_sample(x2, t)

        frames = []
        # Get frames with different λ
        for i in monit.iterate('Interpolate', n_frames + 1, is_children_silent=True):
            # λ
            lambda_ = i / n_frames
            # x̄_t = (1 - λ) x_t + λ x'_t
            xt = (1 - lambda_) * x1t + lambda_ * x2t
            # Sample x̄_0 ~ p_θ(x_0|x̄_t)
            x0 = self._sample_x0(xt, t_)
            # Add to frames
            frames.append(x0[0])
            # Show the frame
            if not create_video:
                self.show_image(x0[0], f"{lambda_ :.2f}")
        # Make a video from the frames
        if create_video:
            self.make_video(frames)
```
* `xt` is $x_t$
* `n_steps` is $t$

```python
    def _sample_x0(self, xt: torch.Tensor, n_steps: int):
        # Number of samples
        n_samples = xt.shape[0]
        # Iterate until t steps
        for t_ in monit.iterate('Denoise', n_steps):
            t = n_steps - t_ - 1
            # Sample from p_θ(x_{t-1}|x_t)
            xt = self.diffusion.p_sample(xt, xt.new_full((n_samples,), t, dtype=torch.long))
        # Return x_0
        return xt
```
```python
    def sample(self, n_samples: int = 16):
        # x_T ~ p(x_T) = N(x_T; 0, I)
        xt = torch.randn([n_samples, self.image_channels, self.image_size, self.image_size], device=self.device)
        # x_0 ~ p_θ(x_0|x_T)
        x0 = self._sample_x0(xt, self.n_steps)
        # Show the images
        for i in range(n_samples):
            self.show_image(x0[i])
```
$$\begin{aligned}
p_\theta(x_{t-1} \vert x_t) &= \mathcal{N}\big(x_{t-1};\ \mu_\theta(x_t, t),\ \sigma_t^2 \mathbf{I}\big) \\
\mu_\theta(x_t, t) &= \frac{1}{\sqrt{\alpha_t}} \Big( x_t - \frac{\beta_t}{\sqrt{1 - \bar\alpha_t}} \epsilon_\theta(x_t, t) \Big)
\end{aligned}$$
```python
    def p_sample(self, xt: torch.Tensor, t: torch.Tensor, eps_theta: torch.Tensor):
        # Gather ᾱ_t
        alpha_bar = gather(self.alpha_bar, t)
        # α_t
        alpha = gather(self.alpha, t)
        # β_t / √(1 - ᾱ_t); note that 1 - α_t = β_t
        eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5
        # μ_θ(x_t, t) = (1/√α_t) (x_t - β_t/√(1 - ᾱ_t) · ε_θ(x_t, t))
        mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)
        # σ²
        var = gather(self.sigma2, t)
        # ε ~ N(0, I)
        eps = torch.randn(xt.shape, device=xt.device)
        # Sample x_{t-1} = μ_θ + σ ε
        return mean + (var ** .5) * eps
```
$$x_0 \approx \hat{x}_0 = \frac{1}{\sqrt{\bar\alpha_t}} \Big( x_t - \sqrt{1 - \bar\alpha_t}\, \epsilon_\theta(x_t, t) \Big)$$
```python
    def p_x0(self, xt: torch.Tensor, t: torch.Tensor, eps: torch.Tensor):
        # Gather ᾱ_t
        alpha_bar = gather(self.alpha_bar, t)
        # x̂_0 = (x_t - √(1 - ᾱ_t) ε) / √ᾱ_t
        return (xt - (1 - alpha_bar) ** 0.5 * eps) / (alpha_bar ** 0.5)
```
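As a sanity check, here is a standalone sketch (made-up values, independent of the trained model): if $x_t$ was produced from $x_0$ with known noise $\epsilon$ by the forward process $x_t = \sqrt{\bar\alpha_t} x_0 + \sqrt{1 - \bar\alpha_t}\,\epsilon$, the formula above recovers $x_0$ exactly:

```python
import torch

alpha_bar_t = torch.tensor(0.6)  # hypothetical ᾱ_t
x0 = torch.randn(1, 3, 8, 8)     # a fake "image"
eps = torch.randn_like(x0)       # the true noise
# Forward diffusion: x_t = √ᾱ_t · x_0 + √(1 - ᾱ_t) · ε
xt = (alpha_bar_t ** 0.5) * x0 + ((1 - alpha_bar_t) ** 0.5) * eps
# Invert with the p_x0 formula
x0_hat = (xt - (1 - alpha_bar_t) ** 0.5 * eps) / (alpha_bar_t ** 0.5)
print(torch.allclose(x0_hat, x0, atol=1e-5))  # True
```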
Generate samples.

```python
def main():
    # UUID of the training run to evaluate
    run_uuid = "a44333ea251411ec8007d1a1762ed686"
    # Start an evaluation
    experiment.evaluate()
    # Create configs
    configs = Configs()
    # Load the custom configuration of the training run
    configs_dict = experiment.load_configs(run_uuid)
    # Set configurations
    experiment.configs(configs, configs_dict)
    # Initialize
    configs.init()
    # Set PyTorch modules for saving and loading
    experiment.add_pytorch_models({'eps_model': configs.eps_model})
    # Load the training experiment
    experiment.load(run_uuid)
    # Create the sampler
    sampler = Sampler(diffusion=configs.diffusion,
                      image_channels=configs.image_channels,
                      image_size=configs.image_size,
                      device=configs.device)
    # Start the evaluation
    with experiment.start():
        # No gradients
        with torch.no_grad():
            # Sample an image with a denoising animation
            sampler.sample_animation()

            if False:
                # Get some images from the data
                data = next(iter(configs.data_loader)).to(configs.device)
                # Create an interpolation animation
                sampler.interpolate_animate(data[0], data[1])


if __name__ == '__main__':
    main()
```