docs/si/diffusion/ddpm/unet.html
ශබ්දයපුරෝකථනය කිරීම සඳහා මෙය යූ-නෙට් පදනම් කරගත් ආකෘතියකි ϵθ(xt,t).
යූ-නෙට්යනු ආදර්ශ රූප සටහනේ යූ හැඩයෙන් එය නම ලබා ගනී. විශේෂාංග සිතියම් විභේදනය ක්රමයෙන් අඩු කිරීමෙන් (අඩක්) සහ විභේදනය වැඩි කිරීමෙන් එය ලබා දී ඇති රූපයක් සකසනු ලැබේ. සෑම විභේදනයකදීම පාස්-හරහා සම්බන්ධතාවයක් ඇත.
මෙමක්රියාත්මක කිරීම මුල් යූ-නෙට් (අවශේෂ කුට්ටි, බහු-හිස අවධානය) සඳහා වෙනස් කිරීම් රාශියක් අඩංගු වන අතර කාල-පියවර කාවැද්දීම් එකතු කරයි t.
24importmath25fromtypingimportOptional,Tuple,Union,List2627importtorch28fromtorchimportnn2930fromlabml\_helpers.moduleimportModule
x⋅σ(x)
33classSwish(Module):
40defforward(self,x):41returnx\*torch.sigmoid(x)
44classTimeEmbedding(nn.Module):
n_channels කාවැද්දීමේ මානයන් ගණන වේ49def\_\_init\_\_(self,n\_channels:int):
53super().\_\_init\_\_()54self.n\_channels=n\_channels
පළමුරේඛීය ස්ථරය
56self.lin1=nn.Linear(self.n\_channels//4,self.n\_channels)
සක්රීයකිරීම
58self.act=Swish()
දෙවනරේඛීය ස්ථරය
60self.lin2=nn.Linear(self.n\_channels,self.n\_channels)
62defforward(self,t:torch.Tensor):
ට්රාන්ස්ෆෝමරයට සමානසයිනොසොයිඩල් ස්ථාන කාවැද්දීම් සාදන්න
PEt,i(1)PEt,i(2)=sin(10000d−1it)=cos(10000d−1it)
d කොහේද half_dim
72half\_dim=self.n\_channels//873emb=math.log(10\_000)/(half\_dim-1)74emb=torch.exp(torch.arange(half\_dim,device=t.device)\*-emb)75emb=t[:,None]\*emb[None,:]76emb=torch.cat((emb.sin(),emb.cos()),dim=1)
එම්එල්පීසමඟ පරිවර්තනය කරන්න
79emb=self.act(self.lin1(emb))80emb=self.lin2(emb)
83returnemb
අවශේෂකොටසකදී කණ්ඩායම් සාමාන්යකරණය සමග convolution ස්ථර දෙකක් ඇත. සෑම විභේදනයක්ම අවශේෂ කොටස් දෙකකින් සකසනු ලැබේ.
86classResidualBlock(Module):
in_channels ආදාන නාලිකා ගණන
out_channels ආදාන නාලිකා ගණන
time_channels කාලය පියවර (t) කාවැද්දීම් සංඛ්යාව නාලිකා වේ
n_groups කණ්ඩායම් සාමාන්යකරණය සඳහා කණ්ඩායම් සංඛ්යාව වේ
dropout හැලහැප්මේ අනුපාතය වේ
94def\_\_init\_\_(self,in\_channels:int,out\_channels:int,time\_channels:int,95n\_groups:int=32,dropout:float=0.1):
103super().\_\_init\_\_()
කණ්ඩායම්සාමාන්යකරණය සහ පළමු කැටි ගැසුණු ස්ථරය
105self.norm1=nn.GroupNorm(n\_groups,in\_channels)106self.act1=Swish()107self.conv1=nn.Conv2d(in\_channels,out\_channels,kernel\_size=(3,3),padding=(1,1))
කණ්ඩායම්සාමාන්යකරණය සහ දෙවන කැටි ගැසුණු ස්තරය
110self.norm2=nn.GroupNorm(n\_groups,out\_channels)111self.act2=Swish()112self.conv2=nn.Conv2d(out\_channels,out\_channels,kernel\_size=(3,3),padding=(1,1))
ආදානනාලිකා ගණන ප්රතිදාන නාලිකා ගණනට සමාන නොවේ නම් කෙටිමං සම්බන්ධතාවය ප්රක්ෂේපණය කළ යුතුය
116ifin\_channels!=out\_channels:117self.shortcut=nn.Conv2d(in\_channels,out\_channels,kernel\_size=(1,1))118else:119self.shortcut=nn.Identity()
කාලකාවැද්දීම් සඳහා රේඛීය ස්ථරය
122self.time\_emb=nn.Linear(time\_channels,out\_channels)123self.time\_act=Swish()124125self.dropout=nn.Dropout(dropout)
x හැඩය ඇත [batch_size, in_channels, height, width]t හැඩය ඇත [batch_size, time_channels]127defforward(self,x:torch.Tensor,t:torch.Tensor):
පළමුකැටි ගැසුණු ස්ථරය
133h=self.conv1(self.act1(self.norm1(x)))
කාලකාවැද්දීම් එකතු කරන්න
135h+=self.time\_emb(self.time\_act(t))[:,:,None,None]
දෙවනකැටි ගැසුණු ස්ථරය
137h=self.conv2(self.dropout(self.act2(self.norm2(h))))
කෙටිමංසම්බන්ධතාවය එකතු කර ආපසු යන්න
140returnh+self.shortcut(x)
මෙය ට්රාන්ස්ෆෝමර් බහු-හිස අවධානයටසමාන වේ.
143classAttentionBlock(Module):
n_channels යනු ආදානයේ නාලිකා ගණන
n_heads බහු-හිස අවධානය යොමු ප්රධානීන් සංඛ්යාව වේ
d_k එක් එක් හිසෙහි මානයන් ගණන
n_groups කණ්ඩායම් සාමාන්යකරණය සඳහා කණ්ඩායම්සංඛ්යාව වේ
150def\_\_init\_\_(self,n\_channels:int,n\_heads:int=1,d\_k:int=None,n\_groups:int=32):
157super().\_\_init\_\_()
පෙරනිමි d_k
160ifd\_kisNone:161d\_k=n\_channels
සාමාන්යකරණයස්ථරය
163self.norm=nn.GroupNorm(n\_groups,n\_channels)
විමසුම, යතුර සහ අගයන් සඳහා ප්රක්ෂේපණ
165self.projection=nn.Linear(n\_channels,n\_heads\*d\_k\*3)
අවසානපරිවර්තනය සඳහා රේඛීය ස්ථරය
167self.output=nn.Linear(n\_heads\*d\_k,n\_channels)
තිත්නිෂ්පාදන අවධානය සඳහා පරිමාණය
169self.scale=d\_k\*\*-0.5
171self.n\_heads=n\_heads172self.d\_k=d\_k
x හැඩය ඇත [batch_size, in_channels, height, width]t හැඩය ඇත [batch_size, time_channels]174defforward(self,x:torch.Tensor,t:Optional[torch.Tensor]=None):
t භාවිතා නොකෙරේ, නමුත් එය තර්කවල තබා ඇත්තේ අවධානය යොමු කිරීමේ ස්ථර ශ්රිතයේ අත්සන සමඟ ගැලපෙන බැවිනි ResidualBlock .
181\_=t
හැඩයලබා ගන්න
183batch\_size,n\_channels,height,width=x.shape
x හැඩයට වෙනස් කරන්න [batch_size, seq, n_channels]
185x=x.view(batch\_size,n\_channels,-1).permute(0,2,1)
විමසුමලබා ගන්න, යතුර, සහ අගයන් (concatenated) සහ එය හැඩගස්වා [batch_size, seq, n_heads, 3 * d_k]
187qkv=self.projection(x).view(batch\_size,-1,self.n\_heads,3\*self.d\_k)
විමසුම, යතුර සහ අගයන් බෙදීම්. ඔවුන් එක් එක් හැඩය ඇත [batch_size, seq, n_heads, d_k]
189q,k,v=torch.chunk(qkv,3,dim=-1)
පරිමාණතිත් නිෂ්පාදනයක් ගණනය කරන්න dkQK⊤
191attn=torch.einsum('bihd,bjhd-\>bijh',q,k)\*self.scale
අනුක්රමිකමානය ඔස්සේ සොෆ්ට්මැක්ස් seqsoftmax(dkQK⊤)
193attn=attn.softmax(dim=2)
අගයන්අනුව ගුණ කරන්න
195res=torch.einsum('bijh,bjhd-\>bihd',attn,v)
නැවතහැඩගස්වන්න [batch_size, seq, n_heads * d_k]
197res=res.view(batch\_size,-1,self.n\_heads\*self.d\_k)
බවටපරිවර්තනය කරන්න [batch_size, seq, n_channels]
199res=self.output(res)
මඟහැරීමේ සම්බන්ධතාවය එක් කරන්න
202res+=x
හැඩයටවෙනස් කරන්න [batch_size, in_channels, height, width]
205res=res.permute(0,2,1).view(batch\_size,n\_channels,height,width)
208returnres
මෙයඒකාබද්ධ ResidualBlock හා AttentionBlock . මෙම එක් එක් යෝජනාව දී U-Net පළමු භාගයේ දී භාවිතා වේ.
211classDownBlock(Module):
218def\_\_init\_\_(self,in\_channels:int,out\_channels:int,time\_channels:int,has\_attn:bool):219super().\_\_init\_\_()220self.res=ResidualBlock(in\_channels,out\_channels,time\_channels)221ifhas\_attn:222self.attn=AttentionBlock(out\_channels)223else:224self.attn=nn.Identity()
226defforward(self,x:torch.Tensor,t:torch.Tensor):227x=self.res(x,t)228x=self.attn(x)229returnx
මෙයඒකාබද්ධ ResidualBlock හා AttentionBlock . සෑම විභේදනයකදීම යූ-නෙට් හි දෙවන භාගයේදී මේවා භාවිතා වේ.
232classUpBlock(Module):
239def\_\_init\_\_(self,in\_channels:int,out\_channels:int,time\_channels:int,has\_attn:bool):240super().\_\_init\_\_()
ආදානයටඇත්තේ යූ-නෙට් හි පළමු භාගයේ සිට එකම විභේදනයේ ප්රතිදානය අප සංයුක්ත කරන in_channels + out_channels බැවිනි
243self.res=ResidualBlock(in\_channels+out\_channels,out\_channels,time\_channels)244ifhas\_attn:245self.attn=AttentionBlock(out\_channels)246else:247self.attn=nn.Identity()
249defforward(self,x:torch.Tensor,t:torch.Tensor):250x=self.res(x,t)251x=self.attn(x)252returnx
එයතවත් එකක් ResidualBlock සමඟ ඒකාබද්ධ ResidualBlock වේ. AttentionBlock මෙම කොටස U-Net හි අඩුම විභේදනයෙන් යොදනු ලැබේ.
255classMiddleBlock(Module):
263def\_\_init\_\_(self,n\_channels:int,time\_channels:int):264super().\_\_init\_\_()265self.res1=ResidualBlock(n\_channels,n\_channels,time\_channels)266self.attn=AttentionBlock(n\_channels)267self.res2=ResidualBlock(n\_channels,n\_channels,time\_channels)
269defforward(self,x:torch.Tensor,t:torch.Tensor):270x=self.res1(x,t)271x=self.attn(x)272x=self.res2(x,t)273returnx
276classUpsample(nn.Module):
281def\_\_init\_\_(self,n\_channels):282super().\_\_init\_\_()283self.conv=nn.ConvTranspose2d(n\_channels,n\_channels,(4,4),(2,2),(1,1))
285defforward(self,x:torch.Tensor,t:torch.Tensor):
t භාවිතා නොකෙරේ, නමුත් එය තර්කවල තබා ඇත්තේ අවධානය යොමු කිරීමේ ස්ථර ශ්රිතයේ අත්සන සමඟ ගැලපෙන බැවිනි ResidualBlock .
288\_=t289returnself.conv(x)
292classDownsample(nn.Module):
297def\_\_init\_\_(self,n\_channels):298super().\_\_init\_\_()299self.conv=nn.Conv2d(n\_channels,n\_channels,(3,3),(2,2),(1,1))
301defforward(self,x:torch.Tensor,t:torch.Tensor):
t භාවිතා නොකෙරේ, නමුත් එය තර්කවල තබා ඇත්තේ අවධානය යොමු කිරීමේ ස්ථර ශ්රිතයේ අත්සන සමඟ ගැලපෙන බැවිනි ResidualBlock .
304\_=t305returnself.conv(x)
308classUNet(Module):
image_channels යනු රූපයේ නාලිකා ගණන. 3 RGB සඳහා.n_channels ආරම්භක විශේෂාංග සිතියමේ නාලිකා ගණන අපි රූපය බවට පරිවර්තනය කරමුch_mults යනු එක් එක් විභේදනයේ නාලිකා අංක ලැයිස්තුවයි. නාලිකා ගණන වේ ch_mults[i] * n_channelsis_attn යනු එක් එක් විභේදනයේ දී අවධානය යොමු කළ යුතුද යන්න පෙන්නුම් කරන බූලියන් ලැයිස්තුවකිn_blocks එක් එක් යෝජනාව UpDownBlocks දී සංඛ්යාව වේ313def\_\_init\_\_(self,image\_channels:int=3,n\_channels:int=64,314ch\_mults:Union[Tuple[int,...],List[int]]=(1,2,2,4),315is\_attn:Union[Tuple[bool,...],List[int]]=(False,False,True,True),316n\_blocks:int=2):
324super().\_\_init\_\_()
යෝජනාගණන
327n\_resolutions=len(ch\_mults)
විශේෂාංගසිතියමට ව්යාපෘති රූපය
330self.image\_proj=nn.Conv2d(image\_channels,n\_channels,kernel\_size=(3,3),padding=(1,1))
කාලයකාවැද්දීම ස්ථරය. කාල කාවැද්දීම n_channels * 4 නාලිකා ඇත
333self.time\_emb=TimeEmbedding(n\_channels\*4)
336down=[]
නාලිකාගණන
338out\_channels=in\_channels=n\_channels
එක්එක් යෝජනාව සඳහා
340foriinrange(n\_resolutions):
මෙමවිභේදනයේ ප්රතිදාන නාලිකා ගණන
342out\_channels=in\_channels\*ch\_mults[i]
එකතුකරන්න n_blocks
344for\_inrange(n\_blocks):345down.append(DownBlock(in\_channels,out\_channels,n\_channels\*4,is\_attn[i]))346in\_channels=out\_channels
පසුගියහැර අනෙකුත් සියලු යෝජනා දී ආදර්ශ පහළ
348ifi\<n\_resolutions-1:349down.append(Downsample(in\_channels))
මොඩියුලකට්ටලය ඒකාබද්ධ කරන්න
352self.down=nn.ModuleList(down)
මැදකොටස
355self.middle=MiddleBlock(out\_channels,n\_channels\*4,)
358up=[]
නාලිකාගණන
360in\_channels=out\_channels
එක්එක් යෝජනාව සඳහා
362foriinreversed(range(n\_resolutions)):
n_blocks එකම විභේදනයේ
364out\_channels=in\_channels365for\_inrange(n\_blocks):366up.append(UpBlock(in\_channels,out\_channels,n\_channels\*4,is\_attn[i]))
නාලිකාගණන අඩු කිරීම සඳහා අවසාන කොටස
368out\_channels=in\_channels//ch\_mults[i]369up.append(UpBlock(in\_channels,out\_channels,n\_channels\*4,is\_attn[i]))370in\_channels=out\_channels
පසුගියහැර අනෙකුත් සියලු යෝජනා දී ආදර්ශ දක්වා
372ifi\>0:373up.append(Upsample(in\_channels))
මොඩියුලකට්ටලය ඒකාබද්ධ කරන්න
376self.up=nn.ModuleList(up)
අවසානසාමාන්යකරණය සහ කැටි ගැසුණු ස්ථරය
379self.norm=nn.GroupNorm(8,n\_channels)380self.act=Swish()381self.final=nn.Conv2d(in\_channels,image\_channels,kernel\_size=(3,3),padding=(1,1))
x හැඩය ඇත [batch_size, in_channels, height, width]t හැඩය ඇත [batch_size]383defforward(self,x:torch.Tensor,t:torch.Tensor):
කාල-පියවරකාවැද්දීම් ලබා ගන්න
390t=self.time\_emb(t)
රූපප්රක්ෂේපණය ලබා ගන්න
393x=self.image\_proj(x)
h මඟ හැරීමේ සම්බන්ධතාවය සඳහා එක් එක් විභේදනයේ ප්රතිදානයන් ගබඩා කරනු ඇත
396h=[x]
යූ-නෙට්හි පළමු භාගය
398forminself.down:399x=m(x,t)400h.append(x)
මැද(පහළ)
403x=self.middle(x,t)
යූ-නෙට්හි දෙවන භාගය
406forminself.up:407ifisinstance(m,Upsample):408x=m(x,t)409else:
යූ-නෙට්හි පළමු භාගයේ සිට මඟ හැරීමේ සම්බන්ධතාවය ලබාගෙන සංයුක්ත කරන්න
411s=h.pop()412x=torch.cat((x,s),dim=1)
414x=m(x,t)
අවසානසාමාන්යකරණය සහ කැටි කිරීම
417returnself.final(self.act(self.norm(x)))