docs/si/diffusion/stable_diffusion/model/autoencoder.html
homediffusionstable_diffusionmodel
රූප අවකාශය සහ ගුප්ත අවකාශය අතර සිතියම් ගත කිරීම සඳහා භාවිතා කරන ස්වයංක්රීය එන්කෝඩර් ආකෘතිය මෙය ක්රියාත්මක කරයි.
අපි ආදර්ශ අර්ථ දැක්වීම තබා ඇති අතර කොම්විස්/ස්ථාවර විසරණ සිට නොවෙනස්ව නම් කිරීම අපට මුරපොලවල් කෙලින්ම පැටවිය හැකි වන පරිදි.
18fromtypingimportList1920importtorch21importtorch.nn.functionalasF22fromtorchimportnn
මෙය එන්කෝඩරය සහ විකේතක මොඩියුල වලින් සමන්විත වේ.
25classAutoencoder(nn.Module):
encoder ආකේතකය වේ
decoder විකේතකය වේ
emb_channels යනු ප්රමාණාත්මක කාවැද්දීමේ අවකාශයේ මානයන් ගණන
z_channels කාවැද්දීමේ අවකාශයේ නාලිකා ගණන වේ
32def\_\_init\_\_(self,encoder:'Encoder',decoder:'Decoder',emb\_channels:int,z\_channels:int):
39super().\_\_init\_\_()40self.encoder=encoder41self.decoder=decoder
අවකාශය කාවැද්දීමේ සිට ප්රමාණාත්මක කාවැද්දීමේ අවකාශ අවස්ථා දක්වා සිතියම් ගත කිරීම (මධ්යන්ය සහ ලොග් විචලතාව)
44self.quant\_conv=nn.Conv2d(2\*z\_channels,2\*emb\_channels,1)
ප්රමාණාත්මක කාවැද්දීමේ අවකාශයේ සිට නැවත කාවැද්දීම අවකාශය දක්වා සිතියමට සංකෝචනය
47self.post\_quant\_conv=nn.Conv2d(emb\_channels,z\_channels,1)
img හැඩය සහිත රූප ටෙන්සරයයි[batch_size, img_channels, img_height, img_width]49defencode(self,img:torch.Tensor)-\>'GaussianDistribution':
හැඩය සහිත කාවැද්දීම් ලබා ගන්න[batch_size, z_channels * 2, z_height, z_height]
56z=self.encoder(img)
ප්රමාණාත්මක කාවැද්දීමේ අවකාශයේ මොහොත ලබා ගන්න
58moments=self.quant\_conv(z)
බෙදා හැරීම ආපසු ලබා දෙන්න
60returnGaussianDistribution(moments)
z හැඩය සහිත ගුප්ත නිරූපණයයි[batch_size, emb_channels, z_height, z_height]62defdecode(self,z:torch.Tensor):
ප්රමාණාත්මක නිරූපණයෙන් අවකාශය කාවැද්දීම සඳහා සිතියම
69z=self.post\_quant\_conv(z)
හැඩයේ රූපය විකේතනය කරන්න[batch_size, channels, height, width]
71returnself.decoder(z)
74classEncoder(nn.Module):
channels පළමු සංවහන ස්ථරයේ නාලිකා ගණන වේ
channel_multipliers පසුකාලීන බ්ලොක් වල නාලිකා සංඛ්යාව සඳහා බහුකාර්ය සාධක වේ
n_resnet_blocks එක් එක් විභේදනයේ රෙස්නෙට් ස්ථර ගණන වේ
in_channels යනු රූපයේ ඇති නාලිකා ගණන
z_channels කාවැද්දීමේ අවකාශයේ නාලිකා ගණන වේ
79def\_\_init\_\_(self,\*,channels:int,channel\_multipliers:List[int],n\_resnet\_blocks:int,80in\_channels:int,z\_channels:int):
89super().\_\_init\_\_()
විවිධ විභේදන වල කුට්ටි ගණන. එක් එක් ඉහළ මට්ටමේ කොටස අවසානයේ විභේදනය අඩකින් යුක්ත වේ
93n\_resolutions=len(channel\_multipliers)
රූපය සිතියම්3×3 ගත කරන මූලික කැටි ගැසුණු ස්ථරයchannels
96self.conv\_in=nn.Conv2d(in\_channels,channels,3,stride=1,padding=1)
එක් එක් ඉහළ මට්ටමේ බ්ලොක් එකේ නාලිකා ගණන
99channels\_list=[m\*channelsformin[1]+channel\_multipliers]
ඉහළ මට්ටමේ කුට්ටි ලැයිස්තුව
102self.down=nn.ModuleList()
ඉහළ මට්ටමේ කුට්ටි සාදන්න
104foriinrange(n\_resolutions):
සෑම ඉහළ මට්ටමේ බ්ලොක් එකක්ම බහු රෙස්නෙට් බ්ලොක් සහ පහළ-නියැදීම් වලින් සමන්විත වේ
106resnet\_blocks=nn.ModuleList()
රෙස්නෙට් බ්ලොක් එකතු කරන්න
108for\_inrange(n\_resnet\_blocks):109resnet\_blocks.append(ResnetBlock(channels,channels\_list[i+1]))110channels=channels\_list[i+1]
ඉහළ මට්ටමේ බ්ලොක්
112down=nn.Module()113down.block=resnet\_blocks
අන්තිම හැර එක් එක් ඉහළ මට්ටමේ කොටස අවසානයේ පහළ-නියැදීම
115ifi!=n\_resolutions-1:116down.downsample=DownSample(channels)117else:118down.downsample=nn.Identity()
120self.down.append(down)
අවධානය යොමු කරන අවසාන රෙස්නෙට් බ්ලොක්
123self.mid=nn.Module()124self.mid.block\_1=ResnetBlock(channels,channels)125self.mid.attn\_1=AttnBlock(channels)126self.mid.block\_2=ResnetBlock(channels,channels)
3×3කැටි ගැසීමකින් අවකාශය කාවැද්දීම සඳහා සිතියම
129self.norm\_out=normalization(channels)130self.conv\_out=nn.Conv2d(channels,2\*z\_channels,3,stride=1,padding=1)
img හැඩය සහිත රූප ටෙන්සරයයි[batch_size, img_channels, img_height, img_width]132defforward(self,img:torch.Tensor):
ආරම්භක කැටි ගැස්මchannels සමඟ සිතියම
138x=self.conv\_in(img)
ඉහළ මට්ටමේ කුට්ටි
141fordowninself.down:
රෙස්නෙට් බ්ලොක්
143forblockindown.block:144x=block(x)
පහළ-නියැදීම්
146x=down.downsample(x)
අවධානය යොමු කරන අවසාන රෙස්නෙට් බ්ලොක්
149x=self.mid.block\_1(x)150x=self.mid.attn\_1(x)151x=self.mid.block\_2(x)
අවකාශය කාවැද්දීම සඳහා සාමාන්යකරණය කර සිතියම් ගත කරන්න
154x=self.norm\_out(x)155x=swish(x)156x=self.conv\_out(x)
159returnx
162classDecoder(nn.Module):
channels අවසාන සංවහන ස්ථරයේ නාලිකා ගණන වේ
channel_multipliers පෙර බ්ලොක් වල නාලිකා ගණන සඳහා බහුකාර්ය සාධක, ප්රතිලෝම අනුපිළිවෙල
n_resnet_blocks එක් එක් විභේදනයේ රෙස්නෙට් ස්ථර ගණන වේ
out_channels යනු රූපයේ ඇති නාලිකා ගණන
z_channels කාවැද්දීමේ අවකාශයේ නාලිකා ගණන වේ
167def\_\_init\_\_(self,\*,channels:int,channel\_multipliers:List[int],n\_resnet\_blocks:int,168out\_channels:int,z\_channels:int):
177super().\_\_init\_\_()
විවිධ විභේදන වල කුට්ටි ගණන. එක් එක් ඉහළ මට්ටමේ කොටස අවසානයේ විභේදනය අඩකින් යුක්ත වේ
181num\_resolutions=len(channel\_multipliers)
ප්රතිලෝම අනුපිළිවෙලෙහි එක් එක් ඉහළ මට්ටමේ බ්ලොක් එකේ නාලිකා ගණන
184channels\_list=[m\*channelsforminchannel\_multipliers]
ඉහළ මට්ටමේ බ්ලොක් එකේ නාලිකා ගණන
187channels=channels\_list[-1]
කාවැද්දීමේ අවකාශය සිතියම් ගත කරන මූලික3×3 කැටි ගැස්වීමේ ස්ථරයchannels
190self.conv\_in=nn.Conv2d(z\_channels,channels,3,stride=1,padding=1)
අවධානය සහිත රෙස්නෙට් බ්ලොක්
193self.mid=nn.Module()194self.mid.block\_1=ResnetBlock(channels,channels)195self.mid.attn\_1=AttnBlock(channels)196self.mid.block\_2=ResnetBlock(channels,channels)
ඉහළ මට්ටමේ කුට්ටි ලැයිස්තුව
199self.up=nn.ModuleList()
ඉහළ මට්ටමේ කුට්ටි සාදන්න
201foriinreversed(range(num\_resolutions)):
සෑම ඉහළ මට්ටමේ බ්ලොක් එකක්ම බහු රෙස්නෙට් බ්ලොක් සහ ඉහළ නියැදීම් වලින් සමන්විත වේ
203resnet\_blocks=nn.ModuleList()
රෙස්නෙට් බ්ලොක් එකතු කරන්න
205for\_inrange(n\_resnet\_blocks+1):206resnet\_blocks.append(ResnetBlock(channels,channels\_list[i]))207channels=channels\_list[i]
ඉහළ මට්ටමේ බ්ලොක්
209up=nn.Module()210up.block=resnet\_blocks
පළමුවැන්න හැර එක් එක් ඉහළ මට්ටමේ කොටස අවසානයේ ඉහළට නියැදීම
212ifi!=0:213up.upsample=UpSample(channels)214else:215up.upsample=nn.Identity()
මුරපොලට අනුකූල වීමට සූදානම් වන්න
217self.up.insert(0,up)
3×3සංකෝචනය සමඟ රූප අවකාශයට සිතියම
220self.norm\_out=normalization(channels)221self.conv\_out=nn.Conv2d(channels,out\_channels,3,stride=1,padding=1)
z හැඩය සහිත කාවැද්දීම tensor වේ[batch_size, z_channels, z_height, z_height]223defforward(self,z:torch.Tensor):
ආරම්භක කැටි ගැස්මchannels සමඟ සිතියම
229h=self.conv\_in(z)
අවධානය සහිත රෙස්නෙට් බ්ලොක්
232h=self.mid.block\_1(h)233h=self.mid.attn\_1(h)234h=self.mid.block\_2(h)
ඉහළ මට්ටමේ කුට්ටි
237forupinreversed(self.up):
රෙස්නෙට් බ්ලොක්
239forblockinup.block:240h=block(h)
ඉහළට නියැදීම
242h=up.upsample(h)
රූප අවකාශයට සාමාන්යකරණය කර සිතියම් ගත කරන්න
245h=self.norm\_out(h)246h=swish(h)247img=self.conv\_out(h)
250returnimg
253classGaussianDistribution:
parameters හැඩයේ කාවැද්දීම පිළිබඳ විචලනයන්ගේ මාධ්යයන් සහ ලොග් වේ[batch_size, z_channels * 2, z_height, z_height]258def\_\_init\_\_(self,parameters:torch.Tensor):
භේදය මධ්යන්යය සහ විචලතාව ලඝු-සටහන
264self.mean,log\_var=torch.chunk(parameters,2,dim=1)
විචල්යයන්ගේ ලොග් දැමීම
266self.log\_var=torch.clamp(log\_var,-30.0,20.0)
සම්මත අපගමනය ගණනය කරන්න
268self.std=torch.exp(0.5\*self.log\_var)
270defsample(self):
බෙදාහැරීමෙන් නියැදිය
272returnself.mean+self.std\*torch.randn\_like(self.std)
275classAttnBlock(nn.Module):
channels යනු නාලිකා ගණන280def\_\_init\_\_(self,channels:int):
284super().\_\_init\_\_()
කණ්ඩායම් සාමාන්යකරණය
286self.norm=normalization(channels)
විමසුම්, යතුර සහ අගය සිතියම්
288self.q=nn.Conv2d(channels,channels,1)289self.k=nn.Conv2d(channels,channels,1)290self.v=nn.Conv2d(channels,channels,1)
අවසාන1×1 කැටි ගැසුණු ස්ථරය
292self.proj\_out=nn.Conv2d(channels,channels,1)
අවධානය පරිමාණ සාධකය
294self.scale=channels\*\*-0.5
x හැඩයේ ආතතිකය වේ[batch_size, channels, height, width]296defforward(self,x:torch.Tensor):
සාමාන්ය කරන්නx
301x\_norm=self.norm(x)
විමසුම, යතුර සහ දෛශික කාවැද්දීම් ලබා ගන්න
303q=self.q(x\_norm)304k=self.k(x\_norm)305v=self.v(x\_norm)
විමසුම, ප්රධාන සහ දෛශික කාවැද්දීම්[batch_size, channels, height, width] වෙත නැවත සකස් කරන්න[batch_size, channels, height * width]
310b,c,h,w=q.shape311q=q.view(b,c,h\*w)312k=k.view(b,c,h\*w)313v=v.view(b,c,h\*w)
ගණනය කරන්නseqsoftmax(dkeyQK⊤)
316attn=torch.einsum('bci,bcj-\>bij',q,k)\*self.scale317attn=F.softmax(attn,dim=2)
ගණනය කරන්නseqsoftmax(dkeyQK⊤)V
320out=torch.einsum('bij,bcj-\>bci',attn,v)
නැවත සකස් කරන්න[batch_size, channels, height, width]
323out=out.view(b,c,h,w)
අවසාන1×1 කැටි ගැසුණු ස්ථරය
325out=self.proj\_out(out)
අවශේෂ සම්බන්ධතාවය එක් කරන්න
328returnx+out
331classUpSample(nn.Module):
channels යනු නාලිකා ගණන335def\_\_init\_\_(self,channels:int):
339super().\_\_init\_\_()
3×3කැටි ගැසීමේ සිතියම්කරණය
341self.conv=nn.Conv2d(channels,channels,3,padding=1)
x හැඩය සහිත ආදාන විශේෂාංග සිතියමයි[batch_size, channels, height, width]343defforward(self,x:torch.Tensor):
සාධකයක් අනුව ඉහළ නියැදිය2
348x=F.interpolate(x,scale\_factor=2.0,mode="nearest")
කැටි ගැසිම යොදන්න
350returnself.conv(x)
353classDownSample(nn.Module):
channels යනු නාලිකා ගණන357def\_\_init\_\_(self,channels:int):
361super().\_\_init\_\_()
3×3ක සාධකයක් විසින් පහළ-නියැදි2 කිරීමට stride දිග සමග convolution2
363self.conv=nn.Conv2d(channels,channels,3,stride=2,padding=0)
x හැඩය සහිත ආදාන විශේෂාංග සිතියමයි[batch_size, channels, height, width]365defforward(self,x:torch.Tensor):
පෑඩින් එකතු කරන්න
370x=F.pad(x,(0,1,0,1),mode="constant",value=0)
කැටි ගැසිම යොදන්න
372returnself.conv(x)
375classResnetBlock(nn.Module):
in_channels යනු ආදානයේ නාලිකා ගණනout_channels නිමැවුමේ නාලිකා ගණන වේ379def\_\_init\_\_(self,in\_channels:int,out\_channels:int):
384super().\_\_init\_\_()
පළමු සාමාන්යකරණය සහ කැටි ගැසුණු ස්ථරය
386self.norm1=normalization(in\_channels)387self.conv1=nn.Conv2d(in\_channels,out\_channels,3,stride=1,padding=1)
දෙවන සාමාන්යකරණය සහ කැටි ගැසුණු ස්ථරය
389self.norm2=normalization(out\_channels)390self.conv2=nn.Conv2d(out\_channels,out\_channels,3,stride=1,padding=1)
in_channels අවශේෂ සම්බන්ධතාවය සඳහා ස්තරයout_channels සිතියම්ගත කිරීම
392ifin\_channels!=out\_channels:393self.nin\_shortcut=nn.Conv2d(in\_channels,out\_channels,1,stride=1,padding=0)394else:395self.nin\_shortcut=nn.Identity()
x හැඩය සහිත ආදාන විශේෂාංග සිතියමයි[batch_size, channels, height, width]397defforward(self,x:torch.Tensor):
402h=x
පළමු සාමාන්යකරණය සහ කැටි ගැසුණු ස්ථරය
405h=self.norm1(h)406h=swish(h)407h=self.conv1(h)
දෙවන සාමාන්යකරණය සහ කැටි ගැසුණු ස්ථරය
410h=self.norm2(h)411h=swish(h)412h=self.conv2(h)
සිතියම සහ අවශේෂ එකතු කරන්න
415returnself.nin\_shortcut(x)+h
x⋅σ(x)
418defswish(x:torch.Tensor):
424returnx\*torch.sigmoid(x)
මෙය උපකාරක ශ්රිතයක් වන අතර ස්ථාවර කණ්ඩායම් ගණන සහeps .
427defnormalization(channels:int):
433returnnn.GroupNorm(num\_groups=32,num\_channels=channels,eps=1e-6)