Back to Annotated Deep Learning Paper Implementations

Denoising විසරණය සම්භාවිතාව ආකෘති (DDPM) නියැදීම

docs/si/diffusion/stable_diffusion/sampler/ddpm.html

latest10.3 KB
Original Source

homediffusionstable_diffusionsampler

View code on Github

#

Denoising විසරණය සම්භාවිතාව ආකෘති (DDPM) නියැදීම

සරල DDPM ක්රියාත්මක කිරීම සඳහා අපගේ DDPM ක්රියාත්මක කිරීම වෙත යොමු වන්න. βt​කාලසටහන්αt​ ආදිය සඳහා අපි එකම අංකන භාවිතා කරමු.

16fromtypingimportOptional,List1718importnumpyasnp19importtorch2021fromlabmlimportmonit22fromlabml\_nn.diffusion.stable\_diffusion.latent\_diffusionimportLatentDiffusion23fromlabml\_nn.diffusion.stable\_diffusion.samplerimportDiffusionSampler

#

ඩීඩීපීඑම් නියැදි

මෙය DiffusionSampler මූලික පන්තිය පුළුල් කරයි.

පියවරෙන් පියවර නියැදීමෙන් ශබ්දය නැවත නැවතත් ඉවත් කිරීමෙන් ඩීඩීපීඑම් සාම්පල රූපpθ​(xt−1​∣xt​),

pθ​(xt−1​∣xt​)μt​(xt​,t)βt​​x0​​=N(xt−1​;μθ​(xt​,t),βt​​I)=1−αt​ˉ​αˉt−1​​βt​​x0​+1−αt​ˉ​αt​​(1−αˉt−1​)​xt​=1−αt​ˉ​1−αˉt−1​​βt​=αt​ˉ​​1​xt​−(αt​ˉ​1​−1​)ϵθ​​

26classDDPMSampler(DiffusionSampler):

#

49model:LatentDiffusion

#

  • model ශබ්දය පුරෝකථනය කිරීමේ ආකෘතියයිϵcond​(xt​,c)
51def\_\_init\_\_(self,model:LatentDiffusion):

#

55super().\_\_init\_\_(model)

#

නියැදි පියවර1,2,…,T

58self.time\_steps=np.asarray(list(range(self.n\_steps)))5960withtorch.no\_grad():

#

αt​ˉ​

62alpha\_bar=self.model.alpha\_bar

#

βt​කාලසටහන

64beta=self.model.beta

#

αˉt−1​

66alpha\_bar\_prev=torch.cat([alpha\_bar.new\_tensor([1.]),alpha\_bar[:-1]])

#

αˉ​

69self.sqrt\_alpha\_bar=alpha\_bar\*\*.5

#

1−αˉ​

71self.sqrt\_1m\_alpha\_bar=(1.-alpha\_bar)\*\*.5

#

αt​ˉ​​1​

73self.sqrt\_recip\_alpha\_bar=alpha\_bar\*\*-.5

#

αt​ˉ​1​−1​

75self.sqrt\_recip\_m1\_alpha\_bar=(1/alpha\_bar-1)\*\*.5

#

1−αt​ˉ​1−αˉt−1​​βt​

78variance=beta\*(1.-alpha\_bar\_prev)/(1.-alpha\_bar)

#

කලම්ප ලඝු-සටහනβt​~​

80self.log\_var=torch.log(torch.clamp(variance,min=1e-20))

#

1−αt​ˉ​αˉt−1​​βt​​

82self.mean\_x0\_coef=beta\*(alpha\_bar\_prev\*\*.5)/(1.-alpha\_bar)

#

1−αt​ˉ​αt​​(1−αˉt−1​)​

84self.mean\_xt\_coef=(1.-alpha\_bar\_prev)\*((1-beta)\*\*0.5)/(1.-alpha\_bar)

#

නියැදි ලූප

  • shape ස්වරූපයෙන් ජනනය කරන ලද රූපවල හැඩය[batch_size, channels, height, width]
  • cond කොන්දේසි සහිත කාවැද්දීම් වේc
  • temperature යනු ශබ්දයේ උෂ්ණත්වය (අහඹු ශබ්දය මෙයින් ගුණ කරනු ලැබේ)
  • x_last වේxT​. සපයා නොමැති නම් අහඹු ශබ්දය භාවිතා කරනු ඇත.
  • uncond_scale යනු කොන්දේසි විරහිත මාර්ගෝපදේශs පරිමාණයයි. මෙය භාවිතා වේϵθ​(xt​,c)=sϵcond​(xt​,c)+(s−1)ϵcond​(xt​,cu​)
  • uncond_cond හිස් විමසුමක් සඳහා කොන්දේසි සහිත කාවැද්දීම වේcu​
  • skip_steps මඟ හැරීමට කාල පියවර ගණනt′ වේ. අපි නියැදීම ආරම්භ කරමුT−t′. එවිටx_last යxT−t′​.
[email protected]\_grad()87defsample(self,88shape:List[int],89cond:torch.Tensor,90repeat\_noise:bool=False,91temperature:float=1.,92x\_last:Optional[torch.Tensor]=None,93uncond\_scale:float=1.,94uncond\_cond:Optional[torch.Tensor]=None,95skip\_steps:int=0,96):

#

උපාංගය සහ කණ්ඩායම් ප්රමාණය ලබා ගන්න

113device=self.model.device114bs=shape[0]

#

ලබා ගන්නxT​

117x=x\_lastifx\_lastisnotNoneelsetorch.randn(shape,device=device)

#

නියැදි කිරීමට කාල පියවරT−t′,T−t′−1,…,1

120time\_steps=np.flip(self.time\_steps)[skip\_steps:]

#

නියැදි ලූපය

123forstepinmonit.iterate('Sample',time\_steps):

#

පියවර වේලාවt

125ts=x.new\_full((bs,),step,dtype=torch.long)

#

නියැදියxt−1​

128x,pred\_x0,e\_t=self.p\_sample(x,cond,ts,step,129repeat\_noise=repeat\_noise,130temperature=temperature,131uncond\_scale=uncond\_scale,132uncond\_cond=uncond\_cond)

#

ආපසුx0​

135returnx

#

xt−1​වෙතින් නියැදියpθ​(xt−1​∣xt​)

  • x හැඩයෙන්xt​ යුක්ත වේ[batch_size, channels, height, width]
  • c හැඩයේ කොන්දේසි සහිතc කාවැද්දීම් වේ[batch_size, emb_size]
  • t හැඩයෙන්t යුක්ත වේ[batch_size]
  • step යනු සංඛ්යාංකයක්t ලෙස පියවරයි: පුනරාවර්ත_ශබ්දය: කණ්ඩායමේ සියලුම සාම්පල සඳහා ශබ්දය සමාන විය යුතුද යන්න නිශ්චිතව දක්වා ඇත
  • temperature යනු ශබ්දයේ උෂ්ණත්වය (අහඹු ශබ්දය මෙයින් ගුණ කරනු ලැබේ)
  • uncond_scale යනු කොන්දේසි විරහිත මාර්ගෝපදේශs පරිමාණයයි. මෙය භාවිතා වේϵθ​(xt​,c)=sϵcond​(xt​,c)+(s−1)ϵcond​(xt​,cu​)
  • uncond_cond හිස් විමසුමක් සඳහා කොන්දේසි සහිත කාවැද්දීම වේcu​
[email protected]\_grad()138defp\_sample(self,x:torch.Tensor,c:torch.Tensor,t:torch.Tensor,step:int,139repeat\_noise:bool=False,140temperature:float=1.,141uncond\_scale:float=1.,uncond\_cond:Optional[torch.Tensor]=None):

#

ලබා ගන්නϵθ​

157e\_t=self.get\_eps(x,t,c,158uncond\_scale=uncond\_scale,159uncond\_cond=uncond\_cond)

#

කණ්ඩායම් ප්රමාණය ලබා ගන්න

162bs=x.shape[0]

#

αt​ˉ​​1​

165sqrt\_recip\_alpha\_bar=x.new\_full((bs,1,1,1),self.sqrt\_recip\_alpha\_bar[step])

#

αt​ˉ​1​−1​

167sqrt\_recip\_m1\_alpha\_bar=x.new\_full((bs,1,1,1),self.sqrt\_recip\_m1\_alpha\_bar[step])

#

ධාරාවx0​ සමඟ ගණනය කරන්නϵθ​

x0​=αt​ˉ​​1​xt​−(αt​ˉ​1​−1​)ϵθ​

172x0=sqrt\_recip\_alpha\_bar\*x-sqrt\_recip\_m1\_alpha\_bar\*e\_t

#

1−αt​ˉ​αˉt−1​​βt​​

175mean\_x0\_coef=x.new\_full((bs,1,1,1),self.mean\_x0\_coef[step])

#

1−αt​ˉ​αt​​(1−αˉt−1​)​

177mean\_xt\_coef=x.new\_full((bs,1,1,1),self.mean\_xt\_coef[step])

#

ගණනය කරන්නμt​(xt​,t)

μt​(xt​,t)=1−αt​ˉ​αˉt−1​​βt​​x0​+1−αt​ˉ​αt​​(1−αˉt−1​)​xt​

183mean=mean\_x0\_coef\*x0+mean\_xt\_coef\*x

#

logβt​~​

185log\_var=x.new\_full((bs,1,1,1),self.log\_var[step])

#

t=1(අවසාන පියවර නියැදි ක්රියාවලිය) විට ශබ්දය එකතු නොකරන්න. step එය0 කවදාද යන්න සලකන්නt=1)

189ifstep==0:190noise=0

#

කණ්ඩායමේ සියලුම සාම්පල සඳහා එකම ශබ්දය භාවිතා කරන්නේ නම්

192elifrepeat\_noise:193noise=torch.randn((1,\*x.shape[1:]))

#

එක් එක් නියැදිය සඳහා විවිධ ශබ්ද

195else:196noise=torch.randn(x.shape)

#

උෂ්ණත්වය අනුව ශබ්දය ගුණ කරන්න

199noise=noise\*temperature

#

වෙතින් නියැදිය,

pθ​(xt−1​∣xt​)=N(xt−1​;μθ​(xt​,t),βt​~​I)

204x\_prev=mean+(0.5\*log\_var).exp()\*noise

#

207returnx\_prev,x0,e\_t

#

වෙතින් නියැදියq(xt​∣x0​)

q(xt​∣x0​)=N(xt​;αt​ˉ​​x0​,(1−αt​ˉ​)I)

  • x0 හැඩයෙන්x0​ යුක්ත වේ[batch_size, channels, height, width]
  • index යනු කාල පියවරt දර්ශකයයි
  • noise ශබ්දය,ϵ
[email protected]\_grad()210defq\_sample(self,x0:torch.Tensor,index:int,noise:Optional[torch.Tensor]=None):

#

අහඹු ශබ්දය, ශබ්දය නිශ්චිතව දක්වා නොමැති නම්

222ifnoiseisNone:223noise=torch.randn\_like(x0)

#

වෙතින් නියැදියN(xt​;αt​ˉ​​x0​,(1−αt​ˉ​)I)

226returnself.sqrt\_alpha\_bar[index]\*x0+self.sqrt\_1m\_alpha\_bar[index]\*noise

Trending Research Paperslabml.ai