Back to Annotated Deep Learning Paper Implementations

ස්ථාවර විසරණය සඳහා යූ-නෙට්

docs/si/diffusion/stable_diffusion/model/unet.html

latest16.0 KB
Original Source

homediffusionstable_diffusionmodel

View code on Github

#

ස්ථාවර විසරණය සඳහා යූ-නෙට්

මෙය ලබා දෙන යූ-නෙට් ක්රියාත්මක කරයිϵcond​(xt​,c)

අපි ආදර්ශ අර්ථ දැක්වීම තබා ඇති අතර කොම්විස්/ස්ථාවර විසරණ සිට නොවෙනස්ව නම් කිරීම අපට මුරපොලවල් කෙලින්ම පැටවිය හැකි වන පරිදි.

18importmath19fromtypingimportList2021importnumpyasnp22importtorch23importtorch.nnasnn24importtorch.nn.functionalasF2526fromlabml\_nn.diffusion.stable\_diffusion.model.unet\_attentionimportSpatialTransformer

#

යූ-නෙට් ආකෘතිය

29classUNetModel(nn.Module):

#

  • in_channels ආදාන විශේෂාංග සිතියමේ නාලිකා ගණන වේ
  • out_channels ප්රතිදාන විශේෂාංග සිතියමේ නාලිකා ගණන වේ
  • channels ආකෘතිය සඳහා මූලික නාලිකා ගණන වේ
  • n_res_blocks එක් එක් මට්ටමේ අවශේෂ කුට්ටි ගණන
  • attention_levels අවධානය යොමු කළ යුතු මට්ටම් වේ
  • channel_multipliers එක් එක් මට්ටම් සඳහා නාලිකා ගණන සඳහා බහුකාර්ය සාධක වේ
  • n_heads ට්රාන්ස්ෆෝමර්වල අවධානය යොමු කිරීමේ හිස් සංඛ්යාව
34def\_\_init\_\_(35self,\*,36in\_channels:int,37out\_channels:int,38channels:int,39n\_res\_blocks:int,40attention\_levels:List[int],41channel\_multipliers:List[int],42n\_heads:int,43tf\_layers:int=1,44d\_cond:int=768):

#

54super().\_\_init\_\_()55self.channels=channels

#

මට්ටම් ගණන

58levels=len(channel\_multipliers)

#

ප්රමාණ කාල කාවැද්දීම්

60d\_time\_emb=channels\*461self.time\_embed=nn.Sequential(62nn.Linear(channels,d\_time\_emb),63nn.SiLU(),64nn.Linear(d\_time\_emb,d\_time\_emb),65)

#

U-Net ආදාන අඩක්

68self.input\_blocks=nn.ModuleList()

#

ආදානය සිතියම්3×3 ගත කරන මූලික ව්යාවච්ඡාවchannels . විවිධ මොඩියුලවල විවිධ ඉදිරි ක්රියාකාරී අත්සන් ඇති බැවින් කුට්ටිTimestepEmbedSequential මොඩියුලයේ ඔතා ඇත; නිදසුනක් ලෙස, කැටි ගැසීමේදී විශේෂාංග සිතියම පමණක් පිළිගන්නා අතර අවශේෂ කොටස් විශේෂාංග සිතියම සහ වේලාව කාවැද්දීම පිළිගනී. TimestepEmbedSequential ඒ අනුව ඔවුන් අමතයි.

75self.input\_blocks.append(TimestepEmbedSequential(76nn.Conv2d(in\_channels,channels,3,padding=1)))

#

යූ-නෙට් හි ආදාන භාගයේ එක් එක් බ්ලොක් එකේ නාලිකා ගණන

78input\_block\_channels=[channels]

#

එක් එක් මට්ටමේ නාලිකා ගණන

80channels\_list=[channels\*mforminchannel\_multipliers]

#

මට්ටම් සකස් කරන්න

82foriinrange(levels):

#

අවශේෂ කුට්ටි සහ අවධානය එක් කරන්න

84for\_inrange(n\_res\_blocks):

#

පෙර නාලිකා සංඛ්යාවේ සිට වර්තමාන මට්ටමේ නාලිකා ගණන දක්වා අවශේෂ බ්ලොක් සිතියම්

87layers=[ResBlock(channels,d\_time\_emb,out\_channels=channels\_list[i])]88channels=channels\_list[i]

#

ට්රාන්ස්ෆෝමර් එකතු කරන්න

90ifiinattention\_levels:91layers.append(SpatialTransformer(channels,n\_heads,tf\_layers,d\_cond))

#

යූ-නෙට් හි ආදාන භාගයට ඒවා එක් කර එහි ප්රතිදානයේ නාලිකා ගණන නිරීක්ෂණය කරන්න

94self.input\_blocks.append(TimestepEmbedSequential(\*layers))95input\_block\_channels.append(channels)

#

අවසාන වශයෙන් හැර අනෙක් සියලුම මට්ටම්වල පහළ නියැදිය

97ifi!=levels-1:98self.input\_blocks.append(TimestepEmbedSequential(DownSample(channels)))99input\_block\_channels.append(channels)

#

යූ-නෙට් මැද

102self.middle\_block=TimestepEmbedSequential(103ResBlock(channels,d\_time\_emb),104SpatialTransformer(channels,n\_heads,tf\_layers,d\_cond),105ResBlock(channels,d\_time\_emb),106)

#

යූ-නෙට් හි දෙවන භාගය

109self.output\_blocks=nn.ModuleList([])

#

ප්රතිලෝම අනුපිළිවෙලින් මට්ටම් සකස් කරන්න

111foriinreversed(range(levels)):

#

අවශේෂ කුට්ටි සහ අවධානය එක් කරන්න

113forjinrange(n\_res\_blocks+1):

#

පෙර නාලිකා සංඛ්යාවෙන් අවශේෂ බ්ලොක් සිතියම් සහ යූ-නෙට් හි ආදාන භාගයේ සිට වත්මන් මට්ටමේ නාලිකා ගණන දක්වා මඟ හැරීමේ සම්බන්ධතා.

117layers=[ResBlock(channels+input\_block\_channels.pop(),d\_time\_emb,out\_channels=channels\_list[i])]118channels=channels\_list[i]

#

ට්රාන්ස්ෆෝමර් එකතු කරන්න

120ifiinattention\_levels:121layers.append(SpatialTransformer(channels,n\_heads,tf\_layers,d\_cond))

#

අන්තිම අවශේෂ කොටස හැර අවසාන අවශේෂ කොටසින් පසු සෑම මට්ටමකම ඉහළට නියැදිය. අපි ආපසු හැරවීමට පුනරාවර්තනය කරන බව සලකන්න; i.e. අවසානi == 0 වේ.

125ifi!=0andj==n\_res\_blocks:126layers.append(UpSample(channels))

#

යූ-නෙට් හි ප්රතිදාන භාගයට එක් කරන්න

128self.output\_blocks.append(TimestepEmbedSequential(\*layers))

#

අවසාන සාමාන්යකරණය සහ3×3 කැටි කිරීම

131self.out=nn.Sequential(132normalization(channels),133nn.SiLU(),134nn.Conv2d(channels,out\_channels,3,padding=1),135)

#

සයිනොසොයිඩල් කාල පියවර කාවැද්දීම් සාදන්න

  • time_steps හැඩයේ කාල පියවර වේ[batch_size]
  • max_period කාවැද්දීම් වල අවම සංඛ්යාතය පාලනය කරයි.
137deftime\_step\_embedding(self,time\_steps:torch.Tensor,max\_period:int=10000):

#

2c​; නාලිකා අඩක් පාපය වන අතර අනෙක් භාගය කෝස් වේ,

145half=self.channels//2

#

10000c2i​1​

147frequencies=torch.exp(148-math.log(max\_period)\*torch.arange(start=0,end=half,dtype=torch.float32)/half149).to(device=time\_steps.device)

#

10000c2i​t​

151args=time\_steps[:,None].float()\*frequencies[None]

#

cos(10000c2i​t​)සහsin(10000c2i​t​)

153returntorch.cat([torch.cos(args),torch.sin(args)],dim=-1)

#

  • x හැඩයේ ආදාන විශේෂාංග සිතියමයි[batch_size, channels, width, height]
  • time_steps හැඩයේ කාල පියවර වේ[batch_size]
  • cond හැඩයේ කන්ඩිෂනේෂන්[batch_size, n_cond, d_cond]
155defforward(self,x:torch.Tensor,time\_steps:torch.Tensor,cond:torch.Tensor):

#

මඟ හැරීමේ සම්බන්ධතා සඳහා ආදාන අර්ධ ප්රතිදානයන් ගබඩා කිරීම

162x\_input\_block=[]

#

කාලය පියවර කාවැද්දීම් ලබා ගන්න

165t\_emb=self.time\_step\_embedding(time\_steps)166t\_emb=self.time\_embed(t\_emb)

#

U-Net ආදාන අඩක්

169formoduleinself.input\_blocks:170x=module(x,t\_emb,cond)171x\_input\_block.append(x)

#

යූ-නෙට් මැද

173x=self.middle\_block(x,t\_emb,cond)

#

U-Net ප්රතිදාන අඩක්

175formoduleinself.output\_blocks:176x=torch.cat([x,x\_input\_block.pop()],dim=1)177x=module(x,t\_emb,cond)

#

අවසාන සාමාන්යකරණය සහ3×3 කැටි කිරීම

180returnself.out(x)

#

විවිධ යෙදවුම් සහිත මොඩියුල සඳහා අනුක්රමික කොටස

මෙම අනුක්රමික මොඩියුලයට විවිධ මොඩියුලයන් උරා බොනnn.Conv``SpatialTransformer අතර ගැලපෙන අත්සන් සමඟ ඒවා අමතන්නResBlock

183classTimestepEmbedSequential(nn.Sequential):

#

191defforward(self,x,t\_emb,cond=None):192forlayerinself:193ifisinstance(layer,ResBlock):194x=layer(x,t\_emb)195elifisinstance(layer,SpatialTransformer):196x=layer(x,cond)197else:198x=layer(x)199returnx

#

දක්වා-නියැදීම් ස්ථරය

202classUpSample(nn.Module):

#

  • channels යනු නාලිකා ගණන
207def\_\_init\_\_(self,channels:int):

#

211super().\_\_init\_\_()

#

3×3කැටි ගැසීමේ සිතියම්කරණය

213self.conv=nn.Conv2d(channels,channels,3,padding=1)

#

  • x හැඩය සහිත ආදාන විශේෂාංග සිතියමයි[batch_size, channels, height, width]
215defforward(self,x:torch.Tensor):

#

සාධකයක් අනුව ඉහළ නියැදිය2

220x=F.interpolate(x,scale\_factor=2,mode="nearest")

#

කැටි ගැසිම යොදන්න

222returnself.conv(x)

#

පහළ-නියැදි ස්ථරය

225classDownSample(nn.Module):

#

  • channels යනු නාලිකා ගණන
230def\_\_init\_\_(self,channels:int):

#

234super().\_\_init\_\_()

#

3×3ක සාධකයක් විසින් පහළ-නියැදි2 කිරීමට stride දිග සමග convolution2

236self.op=nn.Conv2d(channels,channels,3,stride=2,padding=1)

#

  • x හැඩය සහිත ආදාන විශේෂාංග සිතියමයි[batch_size, channels, height, width]
238defforward(self,x:torch.Tensor):

#

කැටි ගැසිම යොදන්න

243returnself.op(x)

#

රෙස්නෙට් බ්ලොක්

246classResBlock(nn.Module):

#

  • channels ආදාන නාලිකා ගණන
  • d_t_emb කාලරාමු කාවැද්දීම් වල ප්රමාණය
  • out_channels පිටතට ඇති නාලිකා ගණන වේ. `නාලිකාවලට පෙරනිමි.
251def\_\_init\_\_(self,channels:int,d\_t\_emb:int,\*,out\_channels=None):

#

257super().\_\_init\_\_()

#

out_channels නිශ්චිතව දක්වා නැත

259ifout\_channelsisNone:260out\_channels=channels

#

පළමු සාමාන්යකරණය සහ කැටි ගැසිම

263self.in\_layers=nn.Sequential(264normalization(channels),265nn.SiLU(),266nn.Conv2d(channels,out\_channels,3,padding=1),267)

#

කාල පියවර කාවැද්දීම්

270self.emb\_layers=nn.Sequential(271nn.SiLU(),272nn.Linear(d\_t\_emb,out\_channels),273)

#

අවසාන කැටි ගැසුණු ස්ථරය

275self.out\_layers=nn.Sequential(276normalization(out\_channels),277nn.SiLU(),278nn.Dropout(0.),279nn.Conv2d(out\_channels,out\_channels,3,padding=1)280)

#

channels අවශේෂ සම්බන්ධතාවය සඳහා ස්තරයout_channels සිතියම්ගත කිරීම

283ifout\_channels==channels:284self.skip\_connection=nn.Identity()285else:286self.skip\_connection=nn.Conv2d(channels,out\_channels,1)

#

  • x හැඩය සහිත ආදාන විශේෂාංග සිතියමයි[batch_size, channels, height, width]

  • t_emb හැඩයේ කාල පියවර කාවැද්දීම් වේ[batch_size, d_t_emb]

288defforward(self,x:torch.Tensor,t\_emb:torch.Tensor):

#

මූලික කැටි ගැසිම

294h=self.in\_layers(x)

#

කාල පියවර කාවැද්දීම්

296t\_emb=self.emb\_layers(t\_emb).type(h.dtype)

#

කාල පියවර කාවැද්දීම් එකතු කරන්න

298h=h+t\_emb[:,:,None,None]

#

අවසාන කැටි ගැසිම

300h=self.out\_layers(h)

#

මඟ හැරීමේ සම්බන්ධතාවය එක් කරන්න

302returnself.skip\_connection(x)+h

#

float32 වාත්තු සමග කණ්ඩායම් සාමාන්යකරණය

305classGroupNorm32(nn.GroupNorm):

#

310defforward(self,x):311returnsuper().forward(x.float()).type(x.dtype)

#

කණ්ඩායම් සාමාන්යකරණය

මෙය උපකාරක ශ්රිතයක් වන අතර ස්ථාවර කණ්ඩායම් සංඛ්යාවක් ඇත..

314defnormalization(channels):

#

320returnGroupNorm32(32,channels)

#

සයිනොසොයිඩල් කාල පියවර කාවැද්දීම් පරීක්ෂා කරන්න

323def\_test\_time\_embeddings():

#

327importmatplotlib.pyplotasplt328329plt.figure(figsize=(15,5))330m=UNetModel(in\_channels=1,out\_channels=1,channels=320,n\_res\_blocks=1,attention\_levels=[],331channel\_multipliers=[],332n\_heads=1,tf\_layers=1,d\_cond=1)333te=m.time\_step\_embedding(torch.arange(0,1000))334plt.plot(np.arange(1000),te[:,[50,100,190,260]].numpy())335plt.legend(["dim %d"%pforpin[50,100,190,260]])336plt.title("Time embeddings")337plt.show()

#

341if\_\_name\_\_=='\_\_main\_\_':342\_test\_time\_embeddings()

Trending Research Paperslabml.ai