docs/si/diffusion/stable_diffusion/model/unet.html
homediffusionstable_diffusionmodel
මෙය ලබා දෙන යූ-නෙට් ක්රියාත්මක කරයිϵcond(xt,c)
අපි ආදර්ශ අර්ථ දැක්වීම තබා ඇති අතර කොම්විස්/ස්ථාවර විසරණ සිට නොවෙනස්ව නම් කිරීම අපට මුරපොලවල් කෙලින්ම පැටවිය හැකි වන පරිදි.
18importmath19fromtypingimportList2021importnumpyasnp22importtorch23importtorch.nnasnn24importtorch.nn.functionalasF2526fromlabml\_nn.diffusion.stable\_diffusion.model.unet\_attentionimportSpatialTransformer
29classUNetModel(nn.Module):
in_channels ආදාන විශේෂාංග සිතියමේ නාලිකා ගණන වේout_channels ප්රතිදාන විශේෂාංග සිතියමේ නාලිකා ගණන වේchannels ආකෘතිය සඳහා මූලික නාලිකා ගණන වේn_res_blocks එක් එක් මට්ටමේ අවශේෂ කුට්ටි ගණනattention_levels අවධානය යොමු කළ යුතු මට්ටම් වේchannel_multipliers එක් එක් මට්ටම් සඳහා නාලිකා ගණන සඳහා බහුකාර්ය සාධක වේn_heads ට්රාන්ස්ෆෝමර්වල අවධානය යොමු කිරීමේ හිස් සංඛ්යාව34def\_\_init\_\_(35self,\*,36in\_channels:int,37out\_channels:int,38channels:int,39n\_res\_blocks:int,40attention\_levels:List[int],41channel\_multipliers:List[int],42n\_heads:int,43tf\_layers:int=1,44d\_cond:int=768):
54super().\_\_init\_\_()55self.channels=channels
මට්ටම් ගණන
58levels=len(channel\_multipliers)
ප්රමාණ කාල කාවැද්දීම්
60d\_time\_emb=channels\*461self.time\_embed=nn.Sequential(62nn.Linear(channels,d\_time\_emb),63nn.SiLU(),64nn.Linear(d\_time\_emb,d\_time\_emb),65)
U-Net ආදාන අඩක්
68self.input\_blocks=nn.ModuleList()
ආදානය සිතියම්3×3 ගත කරන මූලික ව්යාවච්ඡාවchannels . විවිධ මොඩියුලවල විවිධ ඉදිරි ක්රියාකාරී අත්සන් ඇති බැවින් කුට්ටිTimestepEmbedSequential මොඩියුලයේ ඔතා ඇත; නිදසුනක් ලෙස, කැටි ගැසීමේදී විශේෂාංග සිතියම පමණක් පිළිගන්නා අතර අවශේෂ කොටස් විශේෂාංග සිතියම සහ වේලාව කාවැද්දීම පිළිගනී. TimestepEmbedSequential ඒ අනුව ඔවුන් අමතයි.
75self.input\_blocks.append(TimestepEmbedSequential(76nn.Conv2d(in\_channels,channels,3,padding=1)))
යූ-නෙට් හි ආදාන භාගයේ එක් එක් බ්ලොක් එකේ නාලිකා ගණන
78input\_block\_channels=[channels]
එක් එක් මට්ටමේ නාලිකා ගණන
80channels\_list=[channels\*mforminchannel\_multipliers]
මට්ටම් සකස් කරන්න
82foriinrange(levels):
අවශේෂ කුට්ටි සහ අවධානය එක් කරන්න
84for\_inrange(n\_res\_blocks):
පෙර නාලිකා සංඛ්යාවේ සිට වර්තමාන මට්ටමේ නාලිකා ගණන දක්වා අවශේෂ බ්ලොක් සිතියම්
87layers=[ResBlock(channels,d\_time\_emb,out\_channels=channels\_list[i])]88channels=channels\_list[i]
ට්රාන්ස්ෆෝමර් එකතු කරන්න
90ifiinattention\_levels:91layers.append(SpatialTransformer(channels,n\_heads,tf\_layers,d\_cond))
යූ-නෙට් හි ආදාන භාගයට ඒවා එක් කර එහි ප්රතිදානයේ නාලිකා ගණන නිරීක්ෂණය කරන්න
94self.input\_blocks.append(TimestepEmbedSequential(\*layers))95input\_block\_channels.append(channels)
අවසාන වශයෙන් හැර අනෙක් සියලුම මට්ටම්වල පහළ නියැදිය
97ifi!=levels-1:98self.input\_blocks.append(TimestepEmbedSequential(DownSample(channels)))99input\_block\_channels.append(channels)
යූ-නෙට් මැද
102self.middle\_block=TimestepEmbedSequential(103ResBlock(channels,d\_time\_emb),104SpatialTransformer(channels,n\_heads,tf\_layers,d\_cond),105ResBlock(channels,d\_time\_emb),106)
යූ-නෙට් හි දෙවන භාගය
109self.output\_blocks=nn.ModuleList([])
ප්රතිලෝම අනුපිළිවෙලින් මට්ටම් සකස් කරන්න
111foriinreversed(range(levels)):
අවශේෂ කුට්ටි සහ අවධානය එක් කරන්න
113forjinrange(n\_res\_blocks+1):
පෙර නාලිකා සංඛ්යාවෙන් අවශේෂ බ්ලොක් සිතියම් සහ යූ-නෙට් හි ආදාන භාගයේ සිට වත්මන් මට්ටමේ නාලිකා ගණන දක්වා මඟ හැරීමේ සම්බන්ධතා.
117layers=[ResBlock(channels+input\_block\_channels.pop(),d\_time\_emb,out\_channels=channels\_list[i])]118channels=channels\_list[i]
ට්රාන්ස්ෆෝමර් එකතු කරන්න
120ifiinattention\_levels:121layers.append(SpatialTransformer(channels,n\_heads,tf\_layers,d\_cond))
අන්තිම අවශේෂ කොටස හැර අවසාන අවශේෂ කොටසින් පසු සෑම මට්ටමකම ඉහළට නියැදිය. අපි ආපසු හැරවීමට පුනරාවර්තනය කරන බව සලකන්න; i.e. අවසානi == 0 වේ.
125ifi!=0andj==n\_res\_blocks:126layers.append(UpSample(channels))
යූ-නෙට් හි ප්රතිදාන භාගයට එක් කරන්න
128self.output\_blocks.append(TimestepEmbedSequential(\*layers))
අවසාන සාමාන්යකරණය සහ3×3 කැටි කිරීම
131self.out=nn.Sequential(132normalization(channels),133nn.SiLU(),134nn.Conv2d(channels,out\_channels,3,padding=1),135)
time_steps හැඩයේ කාල පියවර වේ[batch_size]max_period කාවැද්දීම් වල අවම සංඛ්යාතය පාලනය කරයි.137deftime\_step\_embedding(self,time\_steps:torch.Tensor,max\_period:int=10000):
2c; නාලිකා අඩක් පාපය වන අතර අනෙක් භාගය කෝස් වේ,
145half=self.channels//2
10000c2i1
147frequencies=torch.exp(148-math.log(max\_period)\*torch.arange(start=0,end=half,dtype=torch.float32)/half149).to(device=time\_steps.device)
10000c2it
151args=time\_steps[:,None].float()\*frequencies[None]
cos(10000c2it)සහsin(10000c2it)
153returntorch.cat([torch.cos(args),torch.sin(args)],dim=-1)
x හැඩයේ ආදාන විශේෂාංග සිතියමයි[batch_size, channels, width, height]time_steps හැඩයේ කාල පියවර වේ[batch_size]cond හැඩයේ කන්ඩිෂනේෂන්[batch_size, n_cond, d_cond]155defforward(self,x:torch.Tensor,time\_steps:torch.Tensor,cond:torch.Tensor):
මඟ හැරීමේ සම්බන්ධතා සඳහා ආදාන අර්ධ ප්රතිදානයන් ගබඩා කිරීම
162x\_input\_block=[]
කාලය පියවර කාවැද්දීම් ලබා ගන්න
165t\_emb=self.time\_step\_embedding(time\_steps)166t\_emb=self.time\_embed(t\_emb)
U-Net ආදාන අඩක්
169formoduleinself.input\_blocks:170x=module(x,t\_emb,cond)171x\_input\_block.append(x)
යූ-නෙට් මැද
173x=self.middle\_block(x,t\_emb,cond)
U-Net ප්රතිදාන අඩක්
175formoduleinself.output\_blocks:176x=torch.cat([x,x\_input\_block.pop()],dim=1)177x=module(x,t\_emb,cond)
අවසාන සාමාන්යකරණය සහ3×3 කැටි කිරීම
180returnself.out(x)
මෙම අනුක්රමික මොඩියුලයට විවිධ මොඩියුලයන් උරා බොනnn.Conv``SpatialTransformer අතර ගැලපෙන අත්සන් සමඟ ඒවා අමතන්නResBlock
183classTimestepEmbedSequential(nn.Sequential):
191defforward(self,x,t\_emb,cond=None):192forlayerinself:193ifisinstance(layer,ResBlock):194x=layer(x,t\_emb)195elifisinstance(layer,SpatialTransformer):196x=layer(x,cond)197else:198x=layer(x)199returnx
202classUpSample(nn.Module):
channels යනු නාලිකා ගණන207def\_\_init\_\_(self,channels:int):
211super().\_\_init\_\_()
3×3කැටි ගැසීමේ සිතියම්කරණය
213self.conv=nn.Conv2d(channels,channels,3,padding=1)
x හැඩය සහිත ආදාන විශේෂාංග සිතියමයි[batch_size, channels, height, width]215defforward(self,x:torch.Tensor):
සාධකයක් අනුව ඉහළ නියැදිය2
220x=F.interpolate(x,scale\_factor=2,mode="nearest")
කැටි ගැසිම යොදන්න
222returnself.conv(x)
225classDownSample(nn.Module):
channels යනු නාලිකා ගණන230def\_\_init\_\_(self,channels:int):
234super().\_\_init\_\_()
3×3ක සාධකයක් විසින් පහළ-නියැදි2 කිරීමට stride දිග සමග convolution2
236self.op=nn.Conv2d(channels,channels,3,stride=2,padding=1)
x හැඩය සහිත ආදාන විශේෂාංග සිතියමයි[batch_size, channels, height, width]238defforward(self,x:torch.Tensor):
කැටි ගැසිම යොදන්න
243returnself.op(x)
246classResBlock(nn.Module):
channels ආදාන නාලිකා ගණනd_t_emb කාලරාමු කාවැද්දීම් වල ප්රමාණයout_channels පිටතට ඇති නාලිකා ගණන වේ. `නාලිකාවලට පෙරනිමි.251def\_\_init\_\_(self,channels:int,d\_t\_emb:int,\*,out\_channels=None):
257super().\_\_init\_\_()
out_channels නිශ්චිතව දක්වා නැත
259ifout\_channelsisNone:260out\_channels=channels
පළමු සාමාන්යකරණය සහ කැටි ගැසිම
263self.in\_layers=nn.Sequential(264normalization(channels),265nn.SiLU(),266nn.Conv2d(channels,out\_channels,3,padding=1),267)
කාල පියවර කාවැද්දීම්
270self.emb\_layers=nn.Sequential(271nn.SiLU(),272nn.Linear(d\_t\_emb,out\_channels),273)
අවසාන කැටි ගැසුණු ස්ථරය
275self.out\_layers=nn.Sequential(276normalization(out\_channels),277nn.SiLU(),278nn.Dropout(0.),279nn.Conv2d(out\_channels,out\_channels,3,padding=1)280)
channels අවශේෂ සම්බන්ධතාවය සඳහා ස්තරයout_channels සිතියම්ගත කිරීම
283ifout\_channels==channels:284self.skip\_connection=nn.Identity()285else:286self.skip\_connection=nn.Conv2d(channels,out\_channels,1)
x හැඩය සහිත ආදාන විශේෂාංග සිතියමයි[batch_size, channels, height, width]
t_emb හැඩයේ කාල පියවර කාවැද්දීම් වේ[batch_size, d_t_emb]
288defforward(self,x:torch.Tensor,t\_emb:torch.Tensor):
මූලික කැටි ගැසිම
294h=self.in\_layers(x)
කාල පියවර කාවැද්දීම්
296t\_emb=self.emb\_layers(t\_emb).type(h.dtype)
කාල පියවර කාවැද්දීම් එකතු කරන්න
298h=h+t\_emb[:,:,None,None]
අවසාන කැටි ගැසිම
300h=self.out\_layers(h)
මඟ හැරීමේ සම්බන්ධතාවය එක් කරන්න
302returnself.skip\_connection(x)+h
305classGroupNorm32(nn.GroupNorm):
310defforward(self,x):311returnsuper().forward(x.float()).type(x.dtype)
මෙය උපකාරක ශ්රිතයක් වන අතර ස්ථාවර කණ්ඩායම් සංඛ්යාවක් ඇත..
314defnormalization(channels):
320returnGroupNorm32(32,channels)
සයිනොසොයිඩල් කාල පියවර කාවැද්දීම් පරීක්ෂා කරන්න
323def\_test\_time\_embeddings():
327importmatplotlib.pyplotasplt328329plt.figure(figsize=(15,5))330m=UNetModel(in\_channels=1,out\_channels=1,channels=320,n\_res\_blocks=1,attention\_levels=[],331channel\_multipliers=[],332n\_heads=1,tf\_layers=1,d\_cond=1)333te=m.time\_step\_embedding(torch.arange(0,1000))334plt.plot(np.arange(1000),te[:,[50,100,190,260]].numpy())335plt.legend(["dim %d"%pforpin[50,100,190,260]])336plt.title("Time embeddings")337plt.show()
341if\_\_name\_\_=='\_\_main\_\_':342\_test\_time\_embeddings()