docs/si/diffusion/stable_diffusion/model/unet_attention.html
homediffusionstable_diffusionmodel
මෙය ලබා දෙන යූ-නෙට් හි භාවිතා කරන ට්රාන්ස්ෆෝමර් මොඩියුලය ක්රියාත්මක කරයිϵcond(xt,c)
අපි ආදර්ශ අර්ථ දැක්වීම තබා ඇති අතර කොම්විස්/ස්ථාවර විසරණ සිට නොවෙනස්ව නම් කිරීම අපට මුරපොලවල් කෙලින්ම පැටවිය හැකි වන පරිදි.
19fromtypingimportOptional2021importtorch22importtorch.nn.functionalasF23fromtorchimportnn
26classSpatialTransformer(nn.Module):
channels විශේෂාංග සිතියමේ නාලිකා ගණනn_heads අවධානය යොමු ප්රධානීන් සංඛ්යාව වේn_layers ට්රාන්ස්ෆෝමර් ස්ථර ගණනd_cond යනු කොන්දේසි සහිත කාවැද්දිවල ප්රමාණයයි31def\_\_init\_\_(self,channels:int,n\_heads:int,n\_layers:int,d\_cond:int):
38super().\_\_init\_\_()
ආරම්භක කණ්ඩායම් සාමාන්යකරණය
40self.norm=torch.nn.GroupNorm(num\_groups=32,num\_channels=channels,eps=1e-6,affine=True)
මූලික1×1 කැටි ගැසිම
42self.proj\_in=nn.Conv2d(channels,channels,kernel\_size=1,stride=1,padding=0)
ට්රාන්ස්ෆෝමර් ස්ථර
45self.transformer\_blocks=nn.ModuleList(46[BasicTransformerBlock(channels,n\_heads,channels//n\_heads,d\_cond=d\_cond)for\_inrange(n\_layers)]47)
අවසාන1×1 කැටි ගැසිම
50self.proj\_out=nn.Conv2d(channels,channels,kernel\_size=1,stride=1,padding=0)
x හැඩයේ විශේෂාංග සිතියමයි[batch_size, channels, height, width]
cond හැඩයේ කොන්දේසි සහිත කාවැද්දීම් වේ[batch_size, n_cond, d_cond]
52defforward(self,x:torch.Tensor,cond:torch.Tensor):
හැඩය ලබා ගන්න[batch_size, channels, height, width]
58b,c,h,w=x.shape
අවශේෂ සම්බන්ධතාවය සඳහා
60x\_in=x
සාමාන්ය කරන්න
62x=self.norm(x)
මූලික1×1 කැටි ගැසිම
64x=self.proj\_in(x)
සිට සම්ප්රේෂණය කර නැවත හැඩගස්වා[batch_size, channels, height, width] ගන්න[batch_size, height * width, channels]
67x=x.permute(0,2,3,1).view(b,h\*w,c)
ට්රාන්ස්ෆෝමර් ස්ථර යොදන්න
69forblockinself.transformer\_blocks:70x=block(x,cond)
නැවත හැඩගස්වා සිට සම්ප්රේෂණය[batch_size, height * width, channels] කරන්න[batch_size, channels, height, width]
73x=x.view(b,h,w,c).permute(0,3,1,2)
අවසාන1×1 කැටි ගැසිම
75x=self.proj\_out(x)
අවශේෂ එකතු කරන්න
77returnx+x\_in
80classBasicTransformerBlock(nn.Module):
d_model ආදාන කාවැද්දීමේ ප්රමාණයයිn_heads අවධානය යොමු ප්රධානීන් සංඛ්යාව වේd_head අවධානය යොමු හිසෙහි ප්රමාණයයිd_cond යනු කොන්දේසි සහිත කාවැද්දීම් වල ප්රමාණයයි85def\_\_init\_\_(self,d\_model:int,n\_heads:int,d\_head:int,d\_cond:int):
92super().\_\_init\_\_()
ස්වයං අවධානය ස්ථරය හා පෙර-සම්මත ස්ථරය
94self.attn1=CrossAttention(d\_model,d\_model,n\_heads,d\_head)95self.norm1=nn.LayerNorm(d\_model)
හරස් අවධානය ස්ථරය සහ පෙර-සම්මත ස්ථරය
97self.attn2=CrossAttention(d\_model,d\_cond,n\_heads,d\_head)98self.norm2=nn.LayerNorm(d\_model)
Feed-ඉදිරි ජාලය සහ පෙර-සම්මත ස්ථරය
100self.ff=FeedForward(d\_model)101self.norm3=nn.LayerNorm(d\_model)
x හැඩයේ ආදාන කාවැද්දීම් වේ[batch_size, height * width, d_model]
cond හැඩයේ කොන්දේසි සහිත කාවැද්දීම් වේ[batch_size, n_cond, d_cond]
103defforward(self,x:torch.Tensor,cond:torch.Tensor):
ස්වයං අවධානය
109x=self.attn1(self.norm1(x))+x
කන්ඩිෂනේෂන් සමඟ හරස් අවධානය
111x=self.attn2(self.norm2(x),cond=cond)+x
Feed-ඉදිරි ජාලය
113x=self.ff(self.norm3(x))+x
115returnx
කොන්දේසි සහිත කාවැද්දීම් නිශ්චිතව දක්වා නොමැති විට මෙය ස්වයං අවධානයට යොමු වේ.
118classCrossAttention(nn.Module):
125use\_flash\_attention:bool=False
d_model ආදාන කාවැද්දීමේ ප්රමාණයයිn_heads අවධානය යොමු ප්රධානීන් සංඛ්යාව වේd_head අවධානය යොමු හිසෙහි ප්රමාණයයිd_cond යනු කොන්දේසි සහිත කාවැද්දීම් වල ප්රමාණයයිis_inplace මතකය ඉතිරි කර ගැනීම සඳහා අවධානය softmax ගණනය inplace ඉටු කිරීමට යන්න නියම127def\_\_init\_\_(self,d\_model:int,d\_cond:int,n\_heads:int,d\_head:int,is\_inplace:bool=True):
136super().\_\_init\_\_()137138self.is\_inplace=is\_inplace139self.n\_heads=n\_heads140self.d\_head=d\_head
අවධානය පරිමාණ සාධකය
143self.scale=d\_head\*\*-0.5
විමසුම්, යතුර සහ අගය සිතියම්
146d\_attn=d\_head\*n\_heads147self.to\_q=nn.Linear(d\_model,d\_attn,bias=False)148self.to\_k=nn.Linear(d\_cond,d\_attn,bias=False)149self.to\_v=nn.Linear(d\_cond,d\_attn,bias=False)
අවසාන රේඛීය ස්ථරය
152self.to\_out=nn.Sequential(nn.Linear(d\_attn,d\_model))
සැකසුම ෆ්ලෑෂ් අවධානය. ෆ්ලෑෂ් අවධානය භාවිතා කරනු ලබන්නේ එය ස්ථාපනය කර ඇත්නම් සහCrossAttention.use_flash_attention එය සකසා ඇත්නම් පමණිTrue .
157try:
ක්ලෝනකරණය කිරීමෙන් ඔබට ෆ්ලෑෂ් අවධානය ස්ථාපනය කළ හැකිය Github repo, https://github.com/HazyResearch/flash-attention ඉන්පසු ධාවනයpython setup.py install
161fromflash\_attn.flash\_attentionimportFlashAttention162self.flash=FlashAttention()
පරිමාණ තිත් නිෂ්පාදන අවධානය සඳහා පරිමාණය සකසන්න.
164self.flash.softmax\_scale=self.scale
එය ස්ථාපනය කර නොමැතිNone නම් සකසන්න
166exceptImportError:167self.flash=None
x හැඩයේ ආදාන කාවැද්දීම් වේ[batch_size, height * width, d_model]
cond හැඩයේ කොන්දේසි සහිත කාවැද්දීම් වේ[batch_size, n_cond, d_cond]
169defforward(self,x:torch.Tensor,cond:Optional[torch.Tensor]=None):
None අපි ස්වයං අවධානය යොමුcond කරන්නේ නම්
176has\_cond=condisnotNone177ifnothas\_cond:178cond=x
විමසුම, යතුර සහ අගය දෛශික ලබා ගන්න
181q=self.to\_q(x)182k=self.to\_k(cond)183v=self.to\_v(cond)
ෆ්ලෑෂ් අවධානය ලබා ගත හැකි නම් සහ හිස ප්රමාණය අඩු හෝ සමාන නම් භාවිතා කරන්න128
186ifCrossAttention.use\_flash\_attentionandself.flashisnotNoneandnothas\_condandself.d\_head\<=128:187returnself.flash\_attention(q,k,v)
එසේ නොමැති නම්, සාමාන්ය අවධානයට වැටීම
189else:190returnself.normal\_attention(q,k,v)
q හිස් බෙදීමට පෙර විමසුම් දෛශික, හැඩයෙන්[batch_size, seq, d_attn]k හිස් බෙදීමට පෙර විමසුම් දෛශික, හැඩයෙන්[batch_size, seq, d_attn]v හිස් බෙදීමට පෙර විමසුම් දෛශික, හැඩයෙන්[batch_size, seq, d_attn]192defflash\_attention(self,q:torch.Tensor,k:torch.Tensor,v:torch.Tensor):
අනුක්රමික අක්ෂය ඔස්සේ කණ්ඩායම් ප්රමාණය සහ මූලද්රව්ය ගණන ලබා ගන්න (width * height )
202batch\_size,seq\_len,\_=q.shape
ෆ්ලෑෂ් අවධානය සඳහාv දෛශිකq``k , හැඩයේ තනි ටෙන්සරයක් ලබා ගැනීමට[batch_size, seq_len, 3, n_heads * d_head]
206qkv=torch.stack((q,k,v),dim=2)
හිස් බෙදන්න
208qkv=qkv.view(batch\_size,seq\_len,3,self.n\_heads,self.d\_head)
ෆ්ලෑෂ් අවධානය හිස් ප්රමාණ සඳහා ක්රියා කරයි32``128 ,64 සහ, එබැවින් මෙම ප්රමාණයට සරිලන පරිදි හිස් පෑඩ් කළ යුතුය.
212ifself.d\_head\<=32:213pad=32-self.d\_head214elifself.d\_head\<=64:215pad=64-self.d\_head216elifself.d\_head\<=128:217pad=128-self.d\_head218else:219raiseValueError(f'Head size ${self.d\_head} too large for Flash Attention')
හිස් පෑඩ් කරන්න
222ifpad:223qkv=torch.cat((qkv,qkv.new\_zeros(batch\_size,seq\_len,3,self.n\_heads,pad)),dim=-1)
අවධානය ගණනය කරන්නseqsoftmax(dkeyQK⊤)V මෙය හැඩයේ ආතතිකයක් ලබා දෙයි[batch_size, seq_len, n_heads, d_padded]
228out,\_=self.flash(qkv)
අමතර හිස ප්රමාණය කපා
230out=out[:,:,:,:self.d\_head]
නැවත හැඩගස්වන්න[batch_size, seq_len, n_heads * d_head]
232out=out.reshape(batch\_size,seq\_len,self.n\_heads\*self.d\_head)
රේඛීය ස්ථරයක්[batch_size, height * width, d_model] සමඟ සිතියම
235returnself.to\_out(out)
q හිස් බෙදීමට පෙර විමසුම් දෛශික, හැඩයෙන්[batch_size, seq, d_attn]k හිස් බෙදීමට පෙර විමසුම් දෛශික, හැඩයෙන්[batch_size, seq, d_attn]v හිස් බෙදීමට පෙර විමසුම් දෛශික, හැඩයෙන්[batch_size, seq, d_attn]237defnormal\_attention(self,q:torch.Tensor,k:torch.Tensor,v:torch.Tensor):
හැඩයේ හිස් වලට බෙදන්න[batch_size, seq_len, n_heads, d_head]
247q=q.view(\*q.shape[:2],self.n\_heads,-1)248k=k.view(\*k.shape[:2],self.n\_heads,-1)249v=v.view(\*v.shape[:2],self.n\_heads,-1)
අවධානය ගණනය කරන්නdkeyQK⊤
252attn=torch.einsum('bihd,bjhd-\>bhij',q,k)\*self.scale
සොෆ්ට්මැක්ස් ගණනය කරන්නseqsoftmax(dkeyQK⊤)
256ifself.is\_inplace:257half=attn.shape[0]//2258attn[half:]=attn[half:].softmax(dim=-1)259attn[:half]=attn[:half].softmax(dim=-1)260else:261attn=attn.softmax(dim=-1)
අවධානය ප්රතිදානය ගණනයseqsoftmax(dkeyQK⊤)V
265out=torch.einsum('bhij,bjhd-\>bihd',attn,v)
නැවත හැඩගස්වන්න[batch_size, height * width, n_heads * d_head]
267out=out.reshape(\*out.shape[:2],-1)
රේඛීය ස්ථරයක්[batch_size, height * width, d_model] සමඟ සිතියම
269returnself.to\_out(out)
272classFeedForward(nn.Module):
d_model ආදාන කාවැද්දීමේ ප්රමාණයයිd_mult සැඟවුණු ස්ථර ප්රමාණය සඳහා බහුකාර්ය සාධකයකි277def\_\_init\_\_(self,d\_model:int,d\_mult:int=4):
282super().\_\_init\_\_()283self.net=nn.Sequential(284GeGLU(d\_model,d\_model\*d\_mult),285nn.Dropout(0.),286nn.Linear(d\_model\*d\_mult,d\_model)287)
289defforward(self,x:torch.Tensor):290returnself.net(x)
GeGLU(x)=(xW+b)∗GELU(xV+c)
293classGeGLU(nn.Module):
300def\_\_init\_\_(self,d\_in:int,d\_out:int):301super().\_\_init\_\_()
ඒකාබද්ධ රේඛීය ප්රක්ෂේපණxW+b සහxV+c
303self.proj=nn.Linear(d\_in,d\_out\*2)
305defforward(self,x:torch.Tensor):
ලබාxW+b ගන්නxV+c
307x,gate=self.proj(x).chunk(2,dim=-1)
GeGLU(x)=(xW+b)∗GELU(xV+c)
309returnx\*F.gelu(gate)