Back to Annotated Deep Learning Paper Implementations

විසරණ සම්භාවිතා ආකෘති නිරූපණය කිරීම සඳහා යූ-නෙට් ආකෘතිය (ඩීඩීපීඑම්)

docs/si/diffusion/ddpm/unet.html

latest19.9 KB
Original Source

homediffusionddpm

View code on Github

#

විසරණ සම්භාවිතා ආකෘති නිරූපණය කිරීම සඳහා යූ-නෙට් ආකෘතිය (ඩීඩීපීඑම්)

ශබ්දයපුරෝකථනය කිරීම සඳහා මෙය යූ-නෙට් පදනම් කරගත් ආකෘතියකි ϵθ​(xt​,t).

යූ-නෙට්යනු ආදර්ශ රූප සටහනේ යූ හැඩයෙන් එය නම ලබා ගනී. විශේෂාංග සිතියම් විභේදනය ක්රමයෙන් අඩු කිරීමෙන් (අඩක්) සහ විභේදනය වැඩි කිරීමෙන් එය ලබා දී ඇති රූපයක් සකසනු ලැබේ. සෑම විභේදනයකදීම පාස්-හරහා සම්බන්ධතාවයක් ඇත.

මෙමක්රියාත්මක කිරීම මුල් යූ-නෙට් (අවශේෂ කුට්ටි, බහු-හිස අවධානය) සඳහා වෙනස් කිරීම් රාශියක් අඩංගු වන අතර කාල-පියවර කාවැද්දීම් එකතු කරයි t.

24importmath25fromtypingimportOptional,Tuple,Union,List2627importtorch28fromtorchimportnn2930fromlabml\_helpers.moduleimportModule

#

ස්විස්ෂ්ක්රියාකාරී ශ්රිතය

x⋅σ(x)

33classSwish(Module):

#

40defforward(self,x):41returnx\*torch.sigmoid(x)

#

කාවැද්දීම්සඳහා t

44classTimeEmbedding(nn.Module):

#

  • n_channels කාවැද්දීමේ මානයන් ගණන වේ
49def\_\_init\_\_(self,n\_channels:int):

#

53super().\_\_init\_\_()54self.n\_channels=n\_channels

#

පළමුරේඛීය ස්ථරය

56self.lin1=nn.Linear(self.n\_channels//4,self.n\_channels)

#

සක්‍රීයකිරීම

58self.act=Swish()

#

දෙවනරේඛීය ස්ථරය

60self.lin2=nn.Linear(self.n\_channels,self.n\_channels)

#

62defforward(self,t:torch.Tensor):

#

ට්රාන්ස්ෆෝමරයට සමානසයිනොසොයිඩල් ස්ථාන කාවැද්දීම් සාදන්න

PEt,i(1)​PEt,i(2)​​=sin(10000d−1i​t​)=cos(10000d−1i​t​)​

d කොහේද half_dim

72half\_dim=self.n\_channels//873emb=math.log(10\_000)/(half\_dim-1)74emb=torch.exp(torch.arange(half\_dim,device=t.device)\*-emb)75emb=t[:,None]\*emb[None,:]76emb=torch.cat((emb.sin(),emb.cos()),dim=1)

#

එම්එල්පීසමඟ පරිවර්තනය කරන්න

79emb=self.act(self.lin1(emb))80emb=self.lin2(emb)

#

83returnemb

#

අවශේෂකොටස

අවශේෂකොටසකදී කණ්ඩායම් සාමාන්යකරණය සමග convolution ස්ථර දෙකක් ඇත. සෑම විභේදනයක්ම අවශේෂ කොටස් දෙකකින් සකසනු ලැබේ.

86classResidualBlock(Module):

#

  • in_channels ආදාන නාලිකා ගණන

  • out_channels ආදාන නාලිකා ගණන

  • time_channels කාලය පියවර (t) කාවැද්දීම් සංඛ්යාව නාලිකා වේ

  • n_groups කණ්ඩායම් සාමාන්යකරණය සඳහා කණ්ඩායම් සංඛ්යාව වේ

  • dropout හැලහැප්මේ අනුපාතය වේ

94def\_\_init\_\_(self,in\_channels:int,out\_channels:int,time\_channels:int,95n\_groups:int=32,dropout:float=0.1):

#

103super().\_\_init\_\_()

#

කණ්ඩායම්සාමාන්යකරණය සහ පළමු කැටි ගැසුණු ස්ථරය

105self.norm1=nn.GroupNorm(n\_groups,in\_channels)106self.act1=Swish()107self.conv1=nn.Conv2d(in\_channels,out\_channels,kernel\_size=(3,3),padding=(1,1))

#

කණ්ඩායම්සාමාන්යකරණය සහ දෙවන කැටි ගැසුණු ස්තරය

110self.norm2=nn.GroupNorm(n\_groups,out\_channels)111self.act2=Swish()112self.conv2=nn.Conv2d(out\_channels,out\_channels,kernel\_size=(3,3),padding=(1,1))

#

ආදානනාලිකා ගණන ප්රතිදාන නාලිකා ගණනට සමාන නොවේ නම් කෙටිමං සම්බන්ධතාවය ප්රක්ෂේපණය කළ යුතුය

116ifin\_channels!=out\_channels:117self.shortcut=nn.Conv2d(in\_channels,out\_channels,kernel\_size=(1,1))118else:119self.shortcut=nn.Identity()

#

කාලකාවැද්දීම් සඳහා රේඛීය ස්ථරය

122self.time\_emb=nn.Linear(time\_channels,out\_channels)123self.time\_act=Swish()124125self.dropout=nn.Dropout(dropout)

#

  • x හැඩය ඇත [batch_size, in_channels, height, width]
  • t හැඩය ඇත [batch_size, time_channels]
127defforward(self,x:torch.Tensor,t:torch.Tensor):

#

පළමුකැටි ගැසුණු ස්ථරය

133h=self.conv1(self.act1(self.norm1(x)))

#

කාලකාවැද්දීම් එකතු කරන්න

135h+=self.time\_emb(self.time\_act(t))[:,:,None,None]

#

දෙවනකැටි ගැසුණු ස්ථරය

137h=self.conv2(self.dropout(self.act2(self.norm2(h))))

#

කෙටිමංසම්බන්ධතාවය එකතු කර ආපසු යන්න

140returnh+self.shortcut(x)

#

අවධානයවාරණ

මෙය ට්රාන්ස්ෆෝමර් බහු-හිස අවධානයටසමාන වේ.

143classAttentionBlock(Module):

#

  • n_channels යනු ආදානයේ නාලිකා ගණන

  • n_heads බහු-හිස අවධානය යොමු ප්රධානීන් සංඛ්යාව වේ

  • d_k එක් එක් හිසෙහි මානයන් ගණන

  • n_groups කණ්ඩායම් සාමාන්යකරණය සඳහා කණ්ඩායම්සංඛ්යාව වේ

150def\_\_init\_\_(self,n\_channels:int,n\_heads:int=1,d\_k:int=None,n\_groups:int=32):

#

157super().\_\_init\_\_()

#

පෙරනිමි d_k

160ifd\_kisNone:161d\_k=n\_channels

#

සාමාන්යකරණයස්ථරය

163self.norm=nn.GroupNorm(n\_groups,n\_channels)

#

විමසුම, යතුර සහ අගයන් සඳහා ප්රක්ෂේපණ

165self.projection=nn.Linear(n\_channels,n\_heads\*d\_k\*3)

#

අවසානපරිවර්තනය සඳහා රේඛීය ස්ථරය

167self.output=nn.Linear(n\_heads\*d\_k,n\_channels)

#

තිත්නිෂ්පාදන අවධානය සඳහා පරිමාණය

169self.scale=d\_k\*\*-0.5

#

171self.n\_heads=n\_heads172self.d\_k=d\_k

#

  • x හැඩය ඇත [batch_size, in_channels, height, width]
  • t හැඩය ඇත [batch_size, time_channels]
174defforward(self,x:torch.Tensor,t:Optional[torch.Tensor]=None):

#

t භාවිතා නොකෙරේ, නමුත් එය තර්කවල තබා ඇත්තේ අවධානය යොමු කිරීමේ ස්ථර ශ්රිතයේ අත්සන සමඟ ගැලපෙන බැවිනි ResidualBlock .

181\_=t

#

හැඩයලබා ගන්න

183batch\_size,n\_channels,height,width=x.shape

#

x හැඩයට වෙනස් කරන්න [batch_size, seq, n_channels]

185x=x.view(batch\_size,n\_channels,-1).permute(0,2,1)

#

විමසුමලබා ගන්න, යතුර, සහ අගයන් (concatenated) සහ එය හැඩගස්වා [batch_size, seq, n_heads, 3 * d_k]

187qkv=self.projection(x).view(batch\_size,-1,self.n\_heads,3\*self.d\_k)

#

විමසුම, යතුර සහ අගයන් බෙදීම්. ඔවුන් එක් එක් හැඩය ඇත [batch_size, seq, n_heads, d_k]

189q,k,v=torch.chunk(qkv,3,dim=-1)

#

පරිමාණතිත් නිෂ්පාදනයක් ගණනය කරන්න dk​​QK⊤​

191attn=torch.einsum('bihd,bjhd-\>bijh',q,k)\*self.scale

#

අනුක්රමිකමානය ඔස්සේ සොෆ්ට්මැක්ස් seqsoftmax​(dk​​QK⊤​)

193attn=attn.softmax(dim=2)

#

අගයන්අනුව ගුණ කරන්න

195res=torch.einsum('bijh,bjhd-\>bihd',attn,v)

#

නැවතහැඩගස්වන්න [batch_size, seq, n_heads * d_k]

197res=res.view(batch\_size,-1,self.n\_heads\*self.d\_k)

#

බවටපරිවර්තනය කරන්න [batch_size, seq, n_channels]

199res=self.output(res)

#

මඟහැරීමේ සම්බන්ධතාවය එක් කරන්න

202res+=x

#

හැඩයටවෙනස් කරන්න [batch_size, in_channels, height, width]

205res=res.permute(0,2,1).view(batch\_size,n\_channels,height,width)

#

208returnres

#

බ්ලොක්ඩවුන්

මෙයඒකාබද්ධ ResidualBlock හා AttentionBlock . මෙම එක් එක් යෝජනාව දී U-Net පළමු භාගයේ දී භාවිතා වේ.

211classDownBlock(Module):

#

218def\_\_init\_\_(self,in\_channels:int,out\_channels:int,time\_channels:int,has\_attn:bool):219super().\_\_init\_\_()220self.res=ResidualBlock(in\_channels,out\_channels,time\_channels)221ifhas\_attn:222self.attn=AttentionBlock(out\_channels)223else:224self.attn=nn.Identity()

#

226defforward(self,x:torch.Tensor,t:torch.Tensor):227x=self.res(x,t)228x=self.attn(x)229returnx

#

බ්ලොක්දක්වා

මෙයඒකාබද්ධ ResidualBlock හා AttentionBlock . සෑම විභේදනයකදීම යූ-නෙට් හි දෙවන භාගයේදී මේවා භාවිතා වේ.

232classUpBlock(Module):

#

239def\_\_init\_\_(self,in\_channels:int,out\_channels:int,time\_channels:int,has\_attn:bool):240super().\_\_init\_\_()

#

ආදානයටඇත්තේ යූ-නෙට් හි පළමු භාගයේ සිට එකම විභේදනයේ ප්රතිදානය අප සංයුක්ත කරන in_channels + out_channels බැවිනි

243self.res=ResidualBlock(in\_channels+out\_channels,out\_channels,time\_channels)244ifhas\_attn:245self.attn=AttentionBlock(out\_channels)246else:247self.attn=nn.Identity()

#

249defforward(self,x:torch.Tensor,t:torch.Tensor):250x=self.res(x,t)251x=self.attn(x)252returnx

#

මැදකොටස

එයතවත් එකක් ResidualBlock සමඟ ඒකාබද්ධ ResidualBlock වේ. AttentionBlock මෙම කොටස U-Net හි අඩුම විභේදනයෙන් යොදනු ලැබේ.

255classMiddleBlock(Module):

#

263def\_\_init\_\_(self,n\_channels:int,time\_channels:int):264super().\_\_init\_\_()265self.res1=ResidualBlock(n\_channels,n\_channels,time\_channels)266self.attn=AttentionBlock(n\_channels)267self.res2=ResidualBlock(n\_channels,n\_channels,time\_channels)

#

269defforward(self,x:torch.Tensor,t:torch.Tensor):270x=self.res1(x,t)271x=self.attn(x)272x=self.res2(x,t)273returnx

#

විශේෂාංගසිතියම පරිමාණය කරන්න 2×

276classUpsample(nn.Module):

#

281def\_\_init\_\_(self,n\_channels):282super().\_\_init\_\_()283self.conv=nn.ConvTranspose2d(n\_channels,n\_channels,(4,4),(2,2),(1,1))

#

285defforward(self,x:torch.Tensor,t:torch.Tensor):

#

t භාවිතා නොකෙරේ, නමුත් එය තර්කවල තබා ඇත්තේ අවධානය යොමු කිරීමේ ස්ථර ශ්රිතයේ අත්සන සමඟ ගැලපෙන බැවිනි ResidualBlock .

288\_=t289returnself.conv(x)

#

විශේෂාංගසිතියම පරිමාණය කරන්න 21​×

292classDownsample(nn.Module):

#

297def\_\_init\_\_(self,n\_channels):298super().\_\_init\_\_()299self.conv=nn.Conv2d(n\_channels,n\_channels,(3,3),(2,2),(1,1))

#

301defforward(self,x:torch.Tensor,t:torch.Tensor):

#

t භාවිතා නොකෙරේ, නමුත් එය තර්කවල තබා ඇත්තේ අවධානය යොමු කිරීමේ ස්ථර ශ්රිතයේ අත්සන සමඟ ගැලපෙන බැවිනි ResidualBlock .

304\_=t305returnself.conv(x)

#

යූ-නෙට්

308classUNet(Module):

#

  • image_channels යනු රූපයේ නාලිකා ගණන. 3 RGB සඳහා.
  • n_channels ආරම්භක විශේෂාංග සිතියමේ නාලිකා ගණන අපි රූපය බවට පරිවර්තනය කරමු
  • ch_mults යනු එක් එක් විභේදනයේ නාලිකා අංක ලැයිස්තුවයි. නාලිකා ගණන වේ ch_mults[i] * n_channels
  • is_attn යනු එක් එක් විභේදනයේ දී අවධානය යොමු කළ යුතුද යන්න පෙන්නුම් කරන බූලියන් ලැයිස්තුවකි
  • n_blocks එක් එක් යෝජනාව UpDownBlocks දී සංඛ්යාව වේ
313def\_\_init\_\_(self,image\_channels:int=3,n\_channels:int=64,314ch\_mults:Union[Tuple[int,...],List[int]]=(1,2,2,4),315is\_attn:Union[Tuple[bool,...],List[int]]=(False,False,True,True),316n\_blocks:int=2):

#

324super().\_\_init\_\_()

#

යෝජනාගණන

327n\_resolutions=len(ch\_mults)

#

විශේෂාංගසිතියමට ව්යාපෘති රූපය

330self.image\_proj=nn.Conv2d(image\_channels,n\_channels,kernel\_size=(3,3),padding=(1,1))

#

කාලයකාවැද්දීම ස්ථරය. කාල කාවැද්දීම n_channels * 4 නාලිකා ඇත

333self.time\_emb=TimeEmbedding(n\_channels\*4)

#

U-Netපළමු භාගය - අඩු යෝජනාව

336down=[]

#

නාලිකාගණන

338out\_channels=in\_channels=n\_channels

#

එක්එක් යෝජනාව සඳහා

340foriinrange(n\_resolutions):

#

මෙමවිභේදනයේ ප්රතිදාන නාලිකා ගණන

342out\_channels=in\_channels\*ch\_mults[i]

#

එකතුකරන්න n_blocks

344for\_inrange(n\_blocks):345down.append(DownBlock(in\_channels,out\_channels,n\_channels\*4,is\_attn[i]))346in\_channels=out\_channels

#

පසුගියහැර අනෙකුත් සියලු යෝජනා දී ආදර්ශ පහළ

348ifi\<n\_resolutions-1:349down.append(Downsample(in\_channels))

#

මොඩියුලකට්ටලය ඒකාබද්ධ කරන්න

352self.down=nn.ModuleList(down)

#

මැදකොටස

355self.middle=MiddleBlock(out\_channels,n\_channels\*4,)

#

යූ-නෙට්හි දෙවන භාගය - වැඩි කිරීමේ විභේදනය

358up=[]

#

නාලිකාගණන

360in\_channels=out\_channels

#

එක්එක් යෝජනාව සඳහා

362foriinreversed(range(n\_resolutions)):

#

n_blocks එකම විභේදනයේ

364out\_channels=in\_channels365for\_inrange(n\_blocks):366up.append(UpBlock(in\_channels,out\_channels,n\_channels\*4,is\_attn[i]))

#

නාලිකාගණන අඩු කිරීම සඳහා අවසාන කොටස

368out\_channels=in\_channels//ch\_mults[i]369up.append(UpBlock(in\_channels,out\_channels,n\_channels\*4,is\_attn[i]))370in\_channels=out\_channels

#

පසුගියහැර අනෙකුත් සියලු යෝජනා දී ආදර්ශ දක්වා

372ifi\>0:373up.append(Upsample(in\_channels))

#

මොඩියුලකට්ටලය ඒකාබද්ධ කරන්න

376self.up=nn.ModuleList(up)

#

අවසානසාමාන්යකරණය සහ කැටි ගැසුණු ස්ථරය

379self.norm=nn.GroupNorm(8,n\_channels)380self.act=Swish()381self.final=nn.Conv2d(in\_channels,image\_channels,kernel\_size=(3,3),padding=(1,1))

#

  • x හැඩය ඇත [batch_size, in_channels, height, width]
  • t හැඩය ඇත [batch_size]
383defforward(self,x:torch.Tensor,t:torch.Tensor):

#

කාල-පියවරකාවැද්දීම් ලබා ගන්න

390t=self.time\_emb(t)

#

රූපප්රක්ෂේපණය ලබා ගන්න

393x=self.image\_proj(x)

#

h මඟ හැරීමේ සම්බන්ධතාවය සඳහා එක් එක් විභේදනයේ ප්රතිදානයන් ගබඩා කරනු ඇත

396h=[x]

#

යූ-නෙට්හි පළමු භාගය

398forminself.down:399x=m(x,t)400h.append(x)

#

මැද(පහළ)

403x=self.middle(x,t)

#

යූ-නෙට්හි දෙවන භාගය

406forminself.up:407ifisinstance(m,Upsample):408x=m(x,t)409else:

#

යූ-නෙට්හි පළමු භාගයේ සිට මඟ හැරීමේ සම්බන්ධතාවය ලබාගෙන සංයුක්ත කරන්න

411s=h.pop()412x=torch.cat((x,s),dim=1)

#

414x=m(x,t)

#

අවසානසාමාන්යකරණය සහ කැටි කිරීම

417returnself.final(self.act(self.norm(x)))

Trending Research Paperslabml.ai