Back to Annotated Deep Learning Paper Implementations

ස්ථාවර විසරණය යූ-නෙට් සඳහා ට්රාන්ස්ෆෝමර්

docs/si/diffusion/stable_diffusion/model/unet_attention.html

latest15.6 KB
Original Source

homediffusionstable_diffusionmodel

View code on Github

#

ස්ථාවර විසරණය යූ-නෙට් සඳහා ට්රාන්ස්ෆෝමර්

මෙය ලබා දෙන යූ-නෙට් හි භාවිතා කරන ට්රාන්ස්ෆෝමර් මොඩියුලය ක්රියාත්මක කරයිϵcond​(xt​,c)

අපි ආදර්ශ අර්ථ දැක්වීම තබා ඇති අතර කොම්විස්/ස්ථාවර විසරණ සිට නොවෙනස්ව නම් කිරීම අපට මුරපොලවල් කෙලින්ම පැටවිය හැකි වන පරිදි.

19fromtypingimportOptional2021importtorch22importtorch.nn.functionalasF23fromtorchimportnn

#

අවකාශීය ට්රාන්ස්ෆෝමර්

26classSpatialTransformer(nn.Module):

#

  • channels විශේෂාංග සිතියමේ නාලිකා ගණන
  • n_heads අවධානය යොමු ප්රධානීන් සංඛ්යාව වේ
  • n_layers ට්රාන්ස්ෆෝමර් ස්ථර ගණන
  • d_cond යනු කොන්දේසි සහිත කාවැද්දිවල ප්රමාණයයි
31def\_\_init\_\_(self,channels:int,n\_heads:int,n\_layers:int,d\_cond:int):

#

38super().\_\_init\_\_()

#

ආරම්භක කණ්ඩායම් සාමාන්යකරණය

40self.norm=torch.nn.GroupNorm(num\_groups=32,num\_channels=channels,eps=1e-6,affine=True)

#

මූලික1×1 කැටි ගැසිම

42self.proj\_in=nn.Conv2d(channels,channels,kernel\_size=1,stride=1,padding=0)

#

ට්රාන්ස්ෆෝමර් ස්ථර

45self.transformer\_blocks=nn.ModuleList(46[BasicTransformerBlock(channels,n\_heads,channels//n\_heads,d\_cond=d\_cond)for\_inrange(n\_layers)]47)

#

අවසාන1×1 කැටි ගැසිම

50self.proj\_out=nn.Conv2d(channels,channels,kernel\_size=1,stride=1,padding=0)

#

  • x හැඩයේ විශේෂාංග සිතියමයි[batch_size, channels, height, width]

  • cond හැඩයේ කොන්දේසි සහිත කාවැද්දීම් වේ[batch_size, n_cond, d_cond]

52defforward(self,x:torch.Tensor,cond:torch.Tensor):

#

හැඩය ලබා ගන්න[batch_size, channels, height, width]

58b,c,h,w=x.shape

#

අවශේෂ සම්බන්ධතාවය සඳහා

60x\_in=x

#

සාමාන්‍ය කරන්න

62x=self.norm(x)

#

මූලික1×1 කැටි ගැසිම

64x=self.proj\_in(x)

#

සිට සම්ප්රේෂණය කර නැවත හැඩගස්වා[batch_size, channels, height, width] ගන්න[batch_size, height * width, channels]

67x=x.permute(0,2,3,1).view(b,h\*w,c)

#

ට්රාන්ස්ෆෝමර් ස්ථර යොදන්න

69forblockinself.transformer\_blocks:70x=block(x,cond)

#

නැවත හැඩගස්වා සිට සම්ප්රේෂණය[batch_size, height * width, channels] කරන්න[batch_size, channels, height, width]

73x=x.view(b,h,w,c).permute(0,3,1,2)

#

අවසාන1×1 කැටි ගැසිම

75x=self.proj\_out(x)

#

අවශේෂ එකතු කරන්න

77returnx+x\_in

#

ට්රාන්ස්ෆෝමර් ස්ථරය

80classBasicTransformerBlock(nn.Module):

#

  • d_model ආදාන කාවැද්දීමේ ප්රමාණයයි
  • n_heads අවධානය යොමු ප්රධානීන් සංඛ්යාව වේ
  • d_head අවධානය යොමු හිසෙහි ප්රමාණයයි
  • d_cond යනු කොන්දේසි සහිත කාවැද්දීම් වල ප්රමාණයයි
85def\_\_init\_\_(self,d\_model:int,n\_heads:int,d\_head:int,d\_cond:int):

#

92super().\_\_init\_\_()

#

ස්වයං අවධානය ස්ථරය හා පෙර-සම්මත ස්ථරය

94self.attn1=CrossAttention(d\_model,d\_model,n\_heads,d\_head)95self.norm1=nn.LayerNorm(d\_model)

#

හරස් අවධානය ස්ථරය සහ පෙර-සම්මත ස්ථරය

97self.attn2=CrossAttention(d\_model,d\_cond,n\_heads,d\_head)98self.norm2=nn.LayerNorm(d\_model)

#

Feed-ඉදිරි ජාලය සහ පෙර-සම්මත ස්ථරය

100self.ff=FeedForward(d\_model)101self.norm3=nn.LayerNorm(d\_model)

#

  • x හැඩයේ ආදාන කාවැද්දීම් වේ[batch_size, height * width, d_model]

  • cond හැඩයේ කොන්දේසි සහිත කාවැද්දීම් වේ[batch_size, n_cond, d_cond]

103defforward(self,x:torch.Tensor,cond:torch.Tensor):

#

ස්වයං අවධානය

109x=self.attn1(self.norm1(x))+x

#

කන්ඩිෂනේෂන් සමඟ හරස් අවධානය

111x=self.attn2(self.norm2(x),cond=cond)+x

#

Feed-ඉදිරි ජාලය

113x=self.ff(self.norm3(x))+x

#

115returnx

#

හරස් අවධානය ස්ථරය

කොන්දේසි සහිත කාවැද්දීම් නිශ්චිතව දක්වා නොමැති විට මෙය ස්වයං අවධානයට යොමු වේ.

118classCrossAttention(nn.Module):

#

125use\_flash\_attention:bool=False

#

  • d_model ආදාන කාවැද්දීමේ ප්රමාණයයි
  • n_heads අවධානය යොමු ප්රධානීන් සංඛ්යාව වේ
  • d_head අවධානය යොමු හිසෙහි ප්රමාණයයි
  • d_cond යනු කොන්දේසි සහිත කාවැද්දීම් වල ප්රමාණයයි
  • is_inplace මතකය ඉතිරි කර ගැනීම සඳහා අවධානය softmax ගණනය inplace ඉටු කිරීමට යන්න නියම
127def\_\_init\_\_(self,d\_model:int,d\_cond:int,n\_heads:int,d\_head:int,is\_inplace:bool=True):

#

136super().\_\_init\_\_()137138self.is\_inplace=is\_inplace139self.n\_heads=n\_heads140self.d\_head=d\_head

#

අවධානය පරිමාණ සාධකය

143self.scale=d\_head\*\*-0.5

#

විමසුම්, යතුර සහ අගය සිතියම්

146d\_attn=d\_head\*n\_heads147self.to\_q=nn.Linear(d\_model,d\_attn,bias=False)148self.to\_k=nn.Linear(d\_cond,d\_attn,bias=False)149self.to\_v=nn.Linear(d\_cond,d\_attn,bias=False)

#

අවසාන රේඛීය ස්ථරය

152self.to\_out=nn.Sequential(nn.Linear(d\_attn,d\_model))

#

සැකසුම ෆ්ලෑෂ් අවධානය. ෆ්ලෑෂ් අවධානය භාවිතා කරනු ලබන්නේ එය ස්ථාපනය කර ඇත්නම් සහCrossAttention.use_flash_attention එය සකසා ඇත්නම් පමණිTrue .

157try:

#

ක්ලෝනකරණය කිරීමෙන් ඔබට ෆ්ලෑෂ් අවධානය ස්ථාපනය කළ හැකිය Github repo, https://github.com/HazyResearch/flash-attention ඉන්පසු ධාවනයpython setup.py install

161fromflash\_attn.flash\_attentionimportFlashAttention162self.flash=FlashAttention()

#

පරිමාණ තිත් නිෂ්පාදන අවධානය සඳහා පරිමාණය සකසන්න.

164self.flash.softmax\_scale=self.scale

#

එය ස්ථාපනය කර නොමැතිNone නම් සකසන්න

166exceptImportError:167self.flash=None

#

  • x හැඩයේ ආදාන කාවැද්දීම් වේ[batch_size, height * width, d_model]

  • cond හැඩයේ කොන්දේසි සහිත කාවැද්දීම් වේ[batch_size, n_cond, d_cond]

169defforward(self,x:torch.Tensor,cond:Optional[torch.Tensor]=None):

#

None අපි ස්වයං අවධානය යොමුcond කරන්නේ නම්

176has\_cond=condisnotNone177ifnothas\_cond:178cond=x

#

විමසුම, යතුර සහ අගය දෛශික ලබා ගන්න

181q=self.to\_q(x)182k=self.to\_k(cond)183v=self.to\_v(cond)

#

ෆ්ලෑෂ් අවධානය ලබා ගත හැකි නම් සහ හිස ප්රමාණය අඩු හෝ සමාන නම් භාවිතා කරන්න128

186ifCrossAttention.use\_flash\_attentionandself.flashisnotNoneandnothas\_condandself.d\_head\<=128:187returnself.flash\_attention(q,k,v)

#

එසේ නොමැති නම්, සාමාන්ය අවධානයට වැටීම

189else:190returnself.normal\_attention(q,k,v)

#

ෆ්ලෑෂ් අවධානය

  • q හිස් බෙදීමට පෙර විමසුම් දෛශික, හැඩයෙන්[batch_size, seq, d_attn]
  • k හිස් බෙදීමට පෙර විමසුම් දෛශික, හැඩයෙන්[batch_size, seq, d_attn]
  • v හිස් බෙදීමට පෙර විමසුම් දෛශික, හැඩයෙන්[batch_size, seq, d_attn]
192defflash\_attention(self,q:torch.Tensor,k:torch.Tensor,v:torch.Tensor):

#

අනුක්රමික අක්ෂය ඔස්සේ කණ්ඩායම් ප්රමාණය සහ මූලද්රව්ය ගණන ලබා ගන්න (width * height )

202batch\_size,seq\_len,\_=q.shape

#

ෆ්ලෑෂ් අවධානය සඳහාv දෛශිකq``k , හැඩයේ තනි ටෙන්සරයක් ලබා ගැනීමට[batch_size, seq_len, 3, n_heads * d_head]

206qkv=torch.stack((q,k,v),dim=2)

#

හිස් බෙදන්න

208qkv=qkv.view(batch\_size,seq\_len,3,self.n\_heads,self.d\_head)

#

ෆ්ලෑෂ් අවධානය හිස් ප්රමාණ සඳහා ක්රියා කරයි32``128 ,64 සහ, එබැවින් මෙම ප්රමාණයට සරිලන පරිදි හිස් පෑඩ් කළ යුතුය.

212ifself.d\_head\<=32:213pad=32-self.d\_head214elifself.d\_head\<=64:215pad=64-self.d\_head216elifself.d\_head\<=128:217pad=128-self.d\_head218else:219raiseValueError(f'Head size ${self.d\_head} too large for Flash Attention')

#

හිස් පෑඩ් කරන්න

222ifpad:223qkv=torch.cat((qkv,qkv.new\_zeros(batch\_size,seq\_len,3,self.n\_heads,pad)),dim=-1)

#

අවධානය ගණනය කරන්නseqsoftmax​(dkey​​QK⊤​)V මෙය හැඩයේ ආතතිකයක් ලබා දෙයි[batch_size, seq_len, n_heads, d_padded]

228out,\_=self.flash(qkv)

#

අමතර හිස ප්රමාණය කපා

230out=out[:,:,:,:self.d\_head]

#

නැවත හැඩගස්වන්න[batch_size, seq_len, n_heads * d_head]

232out=out.reshape(batch\_size,seq\_len,self.n\_heads\*self.d\_head)

#

රේඛීය ස්ථරයක්[batch_size, height * width, d_model] සමඟ සිතියම

235returnself.to\_out(out)

#

සාමාන්ය අවධානය

  • q හිස් බෙදීමට පෙර විමසුම් දෛශික, හැඩයෙන්[batch_size, seq, d_attn]
  • k හිස් බෙදීමට පෙර විමසුම් දෛශික, හැඩයෙන්[batch_size, seq, d_attn]
  • v හිස් බෙදීමට පෙර විමසුම් දෛශික, හැඩයෙන්[batch_size, seq, d_attn]
237defnormal\_attention(self,q:torch.Tensor,k:torch.Tensor,v:torch.Tensor):

#

හැඩයේ හිස් වලට බෙදන්න[batch_size, seq_len, n_heads, d_head]

247q=q.view(\*q.shape[:2],self.n\_heads,-1)248k=k.view(\*k.shape[:2],self.n\_heads,-1)249v=v.view(\*v.shape[:2],self.n\_heads,-1)

#

අවධානය ගණනය කරන්නdkey​​QK⊤​

252attn=torch.einsum('bihd,bjhd-\>bhij',q,k)\*self.scale

#

සොෆ්ට්මැක්ස් ගණනය කරන්නseqsoftmax​(dkey​​QK⊤​)

256ifself.is\_inplace:257half=attn.shape[0]//2258attn[half:]=attn[half:].softmax(dim=-1)259attn[:half]=attn[:half].softmax(dim=-1)260else:261attn=attn.softmax(dim=-1)

#

අවධානය ප්රතිදානය ගණනයseqsoftmax​(dkey​​QK⊤​)V

265out=torch.einsum('bhij,bjhd-\>bihd',attn,v)

#

නැවත හැඩගස්වන්න[batch_size, height * width, n_heads * d_head]

267out=out.reshape(\*out.shape[:2],-1)

#

රේඛීය ස්ථරයක්[batch_size, height * width, d_model] සමඟ සිතියම

269returnself.to\_out(out)

#

Feed-ඉදිරි ජාලය

272classFeedForward(nn.Module):

#

  • d_model ආදාන කාවැද්දීමේ ප්රමාණයයි
  • d_mult සැඟවුණු ස්ථර ප්රමාණය සඳහා බහුකාර්ය සාධකයකි
277def\_\_init\_\_(self,d\_model:int,d\_mult:int=4):

#

282super().\_\_init\_\_()283self.net=nn.Sequential(284GeGLU(d\_model,d\_model\*d\_mult),285nn.Dropout(0.),286nn.Linear(d\_model\*d\_mult,d\_model)287)

#

289defforward(self,x:torch.Tensor):290returnself.net(x)

#

Glu සක්රිය කිරීම

GeGLU(x)=(xW+b)∗GELU(xV+c)

293classGeGLU(nn.Module):

#

300def\_\_init\_\_(self,d\_in:int,d\_out:int):301super().\_\_init\_\_()

#

ඒකාබද්ධ රේඛීය ප්රක්ෂේපණxW+b සහxV+c

303self.proj=nn.Linear(d\_in,d\_out\*2)

#

305defforward(self,x:torch.Tensor):

#

ලබාxW+b ගන්නxV+c

307x,gate=self.proj(x).chunk(2,dim=-1)

#

GeGLU(x)=(xW+b)∗GELU(xV+c)

309returnx\*F.gelu(gate)

Trending Research Paperslabml.ai