Back to Annotated Deep Learning Paper Implementations

ස්ථාවර විසරණය සඳහා ස්වයංක්රීය ආකේතකය

docs/si/diffusion/stable_diffusion/model/autoencoder.html

latest20.6 KB
Original Source

homediffusionstable_diffusionmodel

View code on Github

#

ස්ථාවර විසරණය සඳහා ස්වයංක්රීය ආකේතකය

රූප අවකාශය සහ ගුප්ත අවකාශය අතර සිතියම් ගත කිරීම සඳහා භාවිතා කරන ස්වයංක්රීය එන්කෝඩර් ආකෘතිය මෙය ක්රියාත්මක කරයි.

අපි ආදර්ශ අර්ථ දැක්වීම තබා ඇති අතර කොම්විස්/ස්ථාවර විසරණ සිට නොවෙනස්ව නම් කිරීම අපට මුරපොලවල් කෙලින්ම පැටවිය හැකි වන පරිදි.

18fromtypingimportList1920importtorch21importtorch.nn.functionalasF22fromtorchimportnn

#

ඔටෝඑන්කෝඩරය

මෙය එන්කෝඩරය සහ විකේතක මොඩියුල වලින් සමන්විත වේ.

25classAutoencoder(nn.Module):

#

  • encoder ආකේතකය වේ

  • decoder විකේතකය වේ

  • emb_channels යනු ප්රමාණාත්මක කාවැද්දීමේ අවකාශයේ මානයන් ගණන

  • z_channels කාවැද්දීමේ අවකාශයේ නාලිකා ගණන වේ

32def\_\_init\_\_(self,encoder:'Encoder',decoder:'Decoder',emb\_channels:int,z\_channels:int):

#

39super().\_\_init\_\_()40self.encoder=encoder41self.decoder=decoder

#

අවකාශය කාවැද්දීමේ සිට ප්රමාණාත්මක කාවැද්දීමේ අවකාශ අවස්ථා දක්වා සිතියම් ගත කිරීම (මධ්යන්ය සහ ලොග් විචලතාව)

44self.quant\_conv=nn.Conv2d(2\*z\_channels,2\*emb\_channels,1)

#

ප්රමාණාත්මක කාවැද්දීමේ අවකාශයේ සිට නැවත කාවැද්දීම අවකාශය දක්වා සිතියමට සංකෝචනය

47self.post\_quant\_conv=nn.Conv2d(emb\_channels,z\_channels,1)

#

ගුප්ත නිරූපණයට රූප කේතනය කරන්න

  • img හැඩය සහිත රූප ටෙන්සරයයි[batch_size, img_channels, img_height, img_width]
49defencode(self,img:torch.Tensor)-\>'GaussianDistribution':

#

හැඩය සහිත කාවැද්දීම් ලබා ගන්න[batch_size, z_channels * 2, z_height, z_height]

56z=self.encoder(img)

#

ප්රමාණාත්මක කාවැද්දීමේ අවකාශයේ මොහොත ලබා ගන්න

58moments=self.quant\_conv(z)

#

බෙදා හැරීම ආපසු ලබා දෙන්න

60returnGaussianDistribution(moments)

#

ගුප්ත නිරූපණයෙන් රූප විකේතනය කරන්න

  • z හැඩය සහිත ගුප්ත නිරූපණයයි[batch_size, emb_channels, z_height, z_height]
62defdecode(self,z:torch.Tensor):

#

ප්රමාණාත්මක නිරූපණයෙන් අවකාශය කාවැද්දීම සඳහා සිතියම

69z=self.post\_quant\_conv(z)

#

හැඩයේ රූපය විකේතනය කරන්න[batch_size, channels, height, width]

71returnself.decoder(z)

#

එන්කෝඩර් මොඩියුලය

74classEncoder(nn.Module):

#

  • channels පළමු සංවහන ස්ථරයේ නාලිකා ගණන වේ

  • channel_multipliers පසුකාලීන බ්ලොක් වල නාලිකා සංඛ්යාව සඳහා බහුකාර්ය සාධක වේ

  • n_resnet_blocks එක් එක් විභේදනයේ රෙස්නෙට් ස්ථර ගණන වේ

  • in_channels යනු රූපයේ ඇති නාලිකා ගණන

  • z_channels කාවැද්දීමේ අවකාශයේ නාලිකා ගණන වේ

79def\_\_init\_\_(self,\*,channels:int,channel\_multipliers:List[int],n\_resnet\_blocks:int,80in\_channels:int,z\_channels:int):

#

89super().\_\_init\_\_()

#

විවිධ විභේදන වල කුට්ටි ගණන. එක් එක් ඉහළ මට්ටමේ කොටස අවසානයේ විභේදනය අඩකින් යුක්ත වේ

93n\_resolutions=len(channel\_multipliers)

#

රූපය සිතියම්3×3 ගත කරන මූලික කැටි ගැසුණු ස්ථරයchannels

96self.conv\_in=nn.Conv2d(in\_channels,channels,3,stride=1,padding=1)

#

එක් එක් ඉහළ මට්ටමේ බ්ලොක් එකේ නාලිකා ගණන

99channels\_list=[m\*channelsformin[1]+channel\_multipliers]

#

ඉහළ මට්ටමේ කුට්ටි ලැයිස්තුව

102self.down=nn.ModuleList()

#

ඉහළ මට්ටමේ කුට්ටි සාදන්න

104foriinrange(n\_resolutions):

#

සෑම ඉහළ මට්ටමේ බ්ලොක් එකක්ම බහු රෙස්නෙට් බ්ලොක් සහ පහළ-නියැදීම් වලින් සමන්විත වේ

106resnet\_blocks=nn.ModuleList()

#

රෙස්නෙට් බ්ලොක් එකතු කරන්න

108for\_inrange(n\_resnet\_blocks):109resnet\_blocks.append(ResnetBlock(channels,channels\_list[i+1]))110channels=channels\_list[i+1]

#

ඉහළ මට්ටමේ බ්ලොක්

112down=nn.Module()113down.block=resnet\_blocks

#

අන්තිම හැර එක් එක් ඉහළ මට්ටමේ කොටස අවසානයේ පහළ-නියැදීම

115ifi!=n\_resolutions-1:116down.downsample=DownSample(channels)117else:118down.downsample=nn.Identity()

#

120self.down.append(down)

#

අවධානය යොමු කරන අවසාන රෙස්නෙට් බ්ලොක්

123self.mid=nn.Module()124self.mid.block\_1=ResnetBlock(channels,channels)125self.mid.attn\_1=AttnBlock(channels)126self.mid.block\_2=ResnetBlock(channels,channels)

#

3×3කැටි ගැසීමකින් අවකාශය කාවැද්දීම සඳහා සිතියම

129self.norm\_out=normalization(channels)130self.conv\_out=nn.Conv2d(channels,2\*z\_channels,3,stride=1,padding=1)

#

  • img හැඩය සහිත රූප ටෙන්සරයයි[batch_size, img_channels, img_height, img_width]
132defforward(self,img:torch.Tensor):

#

ආරම්භක කැටි ගැස්මchannels සමඟ සිතියම

138x=self.conv\_in(img)

#

ඉහළ මට්ටමේ කුට්ටි

141fordowninself.down:

#

රෙස්නෙට් බ්ලොක්

143forblockindown.block:144x=block(x)

#

පහළ-නියැදීම්

146x=down.downsample(x)

#

අවධානය යොමු කරන අවසාන රෙස්නෙට් බ්ලොක්

149x=self.mid.block\_1(x)150x=self.mid.attn\_1(x)151x=self.mid.block\_2(x)

#

අවකාශය කාවැද්දීම සඳහා සාමාන්යකරණය කර සිතියම් ගත කරන්න

154x=self.norm\_out(x)155x=swish(x)156x=self.conv\_out(x)

#

159returnx

#

විකේතක මොඩියුලය

162classDecoder(nn.Module):

#

  • channels අවසාන සංවහන ස්ථරයේ නාලිකා ගණන වේ

  • channel_multipliers පෙර බ්ලොක් වල නාලිකා ගණන සඳහා බහුකාර්ය සාධක, ප්රතිලෝම අනුපිළිවෙල

  • n_resnet_blocks එක් එක් විභේදනයේ රෙස්නෙට් ස්ථර ගණන වේ

  • out_channels යනු රූපයේ ඇති නාලිකා ගණන

  • z_channels කාවැද්දීමේ අවකාශයේ නාලිකා ගණන වේ

167def\_\_init\_\_(self,\*,channels:int,channel\_multipliers:List[int],n\_resnet\_blocks:int,168out\_channels:int,z\_channels:int):

#

177super().\_\_init\_\_()

#

විවිධ විභේදන වල කුට්ටි ගණන. එක් එක් ඉහළ මට්ටමේ කොටස අවසානයේ විභේදනය අඩකින් යුක්ත වේ

181num\_resolutions=len(channel\_multipliers)

#

ප්රතිලෝම අනුපිළිවෙලෙහි එක් එක් ඉහළ මට්ටමේ බ්ලොක් එකේ නාලිකා ගණන

184channels\_list=[m\*channelsforminchannel\_multipliers]

#

ඉහළ මට්ටමේ බ්ලොක් එකේ නාලිකා ගණන

187channels=channels\_list[-1]

#

කාවැද්දීමේ අවකාශය සිතියම් ගත කරන මූලික3×3 කැටි ගැස්වීමේ ස්ථරයchannels

190self.conv\_in=nn.Conv2d(z\_channels,channels,3,stride=1,padding=1)

#

අවධානය සහිත රෙස්නෙට් බ්ලොක්

193self.mid=nn.Module()194self.mid.block\_1=ResnetBlock(channels,channels)195self.mid.attn\_1=AttnBlock(channels)196self.mid.block\_2=ResnetBlock(channels,channels)

#

ඉහළ මට්ටමේ කුට්ටි ලැයිස්තුව

199self.up=nn.ModuleList()

#

ඉහළ මට්ටමේ කුට්ටි සාදන්න

201foriinreversed(range(num\_resolutions)):

#

සෑම ඉහළ මට්ටමේ බ්ලොක් එකක්ම බහු රෙස්නෙට් බ්ලොක් සහ ඉහළ නියැදීම් වලින් සමන්විත වේ

203resnet\_blocks=nn.ModuleList()

#

රෙස්නෙට් බ්ලොක් එකතු කරන්න

205for\_inrange(n\_resnet\_blocks+1):206resnet\_blocks.append(ResnetBlock(channels,channels\_list[i]))207channels=channels\_list[i]

#

ඉහළ මට්ටමේ බ්ලොක්

209up=nn.Module()210up.block=resnet\_blocks

#

පළමුවැන්න හැර එක් එක් ඉහළ මට්ටමේ කොටස අවසානයේ ඉහළට නියැදීම

212ifi!=0:213up.upsample=UpSample(channels)214else:215up.upsample=nn.Identity()

#

මුරපොලට අනුකූල වීමට සූදානම් වන්න

217self.up.insert(0,up)

#

3×3සංකෝචනය සමඟ රූප අවකාශයට සිතියම

220self.norm\_out=normalization(channels)221self.conv\_out=nn.Conv2d(channels,out\_channels,3,stride=1,padding=1)

#

  • z හැඩය සහිත කාවැද්දීම tensor වේ[batch_size, z_channels, z_height, z_height]
223defforward(self,z:torch.Tensor):

#

ආරම්භක කැටි ගැස්මchannels සමඟ සිතියම

229h=self.conv\_in(z)

#

අවධානය සහිත රෙස්නෙට් බ්ලොක්

232h=self.mid.block\_1(h)233h=self.mid.attn\_1(h)234h=self.mid.block\_2(h)

#

ඉහළ මට්ටමේ කුට්ටි

237forupinreversed(self.up):

#

රෙස්නෙට් බ්ලොක්

239forblockinup.block:240h=block(h)

#

ඉහළට නියැදීම

242h=up.upsample(h)

#

රූප අවකාශයට සාමාන්යකරණය කර සිතියම් ගත කරන්න

245h=self.norm\_out(h)246h=swish(h)247img=self.conv\_out(h)

#

250returnimg

#

ගවුසියානු බෙදාහැරීම්

253classGaussianDistribution:

#

  • parameters හැඩයේ කාවැද්දීම පිළිබඳ විචලනයන්ගේ මාධ්යයන් සහ ලොග් වේ[batch_size, z_channels * 2, z_height, z_height]
258def\_\_init\_\_(self,parameters:torch.Tensor):

#

භේදය මධ්යන්යය සහ විචලතාව ලඝු-සටහන

264self.mean,log\_var=torch.chunk(parameters,2,dim=1)

#

විචල්යයන්ගේ ලොග් දැමීම

266self.log\_var=torch.clamp(log\_var,-30.0,20.0)

#

සම්මත අපගමනය ගණනය කරන්න

268self.std=torch.exp(0.5\*self.log\_var)

#

270defsample(self):

#

බෙදාහැරීමෙන් නියැදිය

272returnself.mean+self.std\*torch.randn\_like(self.std)

#

අවධානය වාරණ

275classAttnBlock(nn.Module):

#

  • channels යනු නාලිකා ගණන
280def\_\_init\_\_(self,channels:int):

#

284super().\_\_init\_\_()

#

කණ්ඩායම් සාමාන්යකරණය

286self.norm=normalization(channels)

#

විමසුම්, යතුර සහ අගය සිතියම්

288self.q=nn.Conv2d(channels,channels,1)289self.k=nn.Conv2d(channels,channels,1)290self.v=nn.Conv2d(channels,channels,1)

#

අවසාන1×1 කැටි ගැසුණු ස්ථරය

292self.proj\_out=nn.Conv2d(channels,channels,1)

#

අවධානය පරිමාණ සාධකය

294self.scale=channels\*\*-0.5

#

  • x හැඩයේ ආතතිකය වේ[batch_size, channels, height, width]
296defforward(self,x:torch.Tensor):

#

සාමාන්‍ය කරන්නx

301x\_norm=self.norm(x)

#

විමසුම, යතුර සහ දෛශික කාවැද්දීම් ලබා ගන්න

303q=self.q(x\_norm)304k=self.k(x\_norm)305v=self.v(x\_norm)

#

විමසුම, ප්රධාන සහ දෛශික කාවැද්දීම්[batch_size, channels, height, width] වෙත නැවත සකස් කරන්න[batch_size, channels, height * width]

310b,c,h,w=q.shape311q=q.view(b,c,h\*w)312k=k.view(b,c,h\*w)313v=v.view(b,c,h\*w)

#

ගණනය කරන්නseqsoftmax​(dkey​​QK⊤​)

316attn=torch.einsum('bci,bcj-\>bij',q,k)\*self.scale317attn=F.softmax(attn,dim=2)

#

ගණනය කරන්නseqsoftmax​(dkey​​QK⊤​)V

320out=torch.einsum('bij,bcj-\>bci',attn,v)

#

නැවත සකස් කරන්න[batch_size, channels, height, width]

323out=out.view(b,c,h,w)

#

අවසාන1×1 කැටි ගැසුණු ස්ථරය

325out=self.proj\_out(out)

#

අවශේෂ සම්බන්ධතාවය එක් කරන්න

328returnx+out

#

දක්වා-නියැදීම් ස්ථරය

331classUpSample(nn.Module):

#

  • channels යනු නාලිකා ගණන
335def\_\_init\_\_(self,channels:int):

#

339super().\_\_init\_\_()

#

3×3කැටි ගැසීමේ සිතියම්කරණය

341self.conv=nn.Conv2d(channels,channels,3,padding=1)

#

  • x හැඩය සහිත ආදාන විශේෂාංග සිතියමයි[batch_size, channels, height, width]
343defforward(self,x:torch.Tensor):

#

සාධකයක් අනුව ඉහළ නියැදිය2

348x=F.interpolate(x,scale\_factor=2.0,mode="nearest")

#

කැටි ගැසිම යොදන්න

350returnself.conv(x)

#

පහළ-නියැදි ස්ථරය

353classDownSample(nn.Module):

#

  • channels යනු නාලිකා ගණන
357def\_\_init\_\_(self,channels:int):

#

361super().\_\_init\_\_()

#

3×3ක සාධකයක් විසින් පහළ-නියැදි2 කිරීමට stride දිග සමග convolution2

363self.conv=nn.Conv2d(channels,channels,3,stride=2,padding=0)

#

  • x හැඩය සහිත ආදාන විශේෂාංග සිතියමයි[batch_size, channels, height, width]
365defforward(self,x:torch.Tensor):

#

පෑඩින් එකතු කරන්න

370x=F.pad(x,(0,1,0,1),mode="constant",value=0)

#

කැටි ගැසිම යොදන්න

372returnself.conv(x)

#

රෙස්නෙට් බ්ලොක්

375classResnetBlock(nn.Module):

#

  • in_channels යනු ආදානයේ නාලිකා ගණන
  • out_channels නිමැවුමේ නාලිකා ගණන වේ
379def\_\_init\_\_(self,in\_channels:int,out\_channels:int):

#

384super().\_\_init\_\_()

#

පළමු සාමාන්යකරණය සහ කැටි ගැසුණු ස්ථරය

386self.norm1=normalization(in\_channels)387self.conv1=nn.Conv2d(in\_channels,out\_channels,3,stride=1,padding=1)

#

දෙවන සාමාන්යකරණය සහ කැටි ගැසුණු ස්ථරය

389self.norm2=normalization(out\_channels)390self.conv2=nn.Conv2d(out\_channels,out\_channels,3,stride=1,padding=1)

#

in_channels අවශේෂ සම්බන්ධතාවය සඳහා ස්තරයout_channels සිතියම්ගත කිරීම

392ifin\_channels!=out\_channels:393self.nin\_shortcut=nn.Conv2d(in\_channels,out\_channels,1,stride=1,padding=0)394else:395self.nin\_shortcut=nn.Identity()

#

  • x හැඩය සහිත ආදාන විශේෂාංග සිතියමයි[batch_size, channels, height, width]
397defforward(self,x:torch.Tensor):

#

402h=x

#

පළමු සාමාන්යකරණය සහ කැටි ගැසුණු ස්ථරය

405h=self.norm1(h)406h=swish(h)407h=self.conv1(h)

#

දෙවන සාමාන්යකරණය සහ කැටි ගැසුණු ස්ථරය

410h=self.norm2(h)411h=swish(h)412h=self.conv2(h)

#

සිතියම සහ අවශේෂ එකතු කරන්න

415returnself.nin\_shortcut(x)+h

#

ස්විෂ් සක්රිය කිරීම

x⋅σ(x)

418defswish(x:torch.Tensor):

#

424returnx\*torch.sigmoid(x)

#

කණ්ඩායම් සාමාන්යකරණය

මෙය උපකාරක ශ්රිතයක් වන අතර ස්ථාවර කණ්ඩායම් ගණන සහeps .

427defnormalization(channels:int):

#

433returnnn.GroupNorm(num\_groups=32,num\_channels=channels,eps=1e-6)

Trending Research Paperslabml.ai