Back to Annotated Deep Learning Paper Implementations

හයිපර්නෙට්වර්ක්ස්- හයිපර්එල්එස්එම්

docs/si/hypernetworks/hyper_lstm.html

latest15.0 KB
Original Source

homehypernetworks

View code on Github

#

හයිපර්නෙට්වර්ක්ස්- හයිපර්එල්එස්එම්

අපි PyTorchභාවිතා විවරණ සමග, කඩදාසි HyperNetworksහඳුන්වා HyperlSTM ක්රියාත්මක කර ඇත. ඩේවිඩ් හාගේ මෙම බ්ලොග් සටහන හයිපර්නෙට්වර්ක්ස් පිළිබඳ හොඳ පැහැදිලි කිරීමක් ලබා දෙයි.

ෂේක්ස්පියර්දත්ත කට්ටලය පිළිබඳ පෙළ පුරෝකථනය කිරීම සඳහා හයිපර්එල්එස්ටීඑම් පුහුණු කරන අත්හදා බැලීමක් අපට තිබේ. මෙන්න කේතය සඳහා සබැඳිය: experiment.py

හයිපර්නෙට්වර්ක්ස්විශාල ජාලයක බර උත්පාදනය කිරීම සඳහා කුඩා ජාලයක් භාවිතා කරයි. ප්රභේද දෙකක් තිබේ: ස්ථිතික අධි-ජාල සහ ගතික අධි-ජාල. ස්ථිතික හයිපර්නෙට්වර්ක්ස් යනු සංවහන ජාලයක බර (කර්නල්) ජනනය කරන කුඩා ජාල ඇත. ඩයිනමික් හයිපර්නෙට්වර්ක්ස් එක් එක් පියවර සඳහා පුනරාවර්තන ස්නායුක ජාලයක පරාමිතීන් ජනනය කරයි. මෙය අවසාන වශයෙන් ක්රියාත්මක කිරීමයි.

ගතිකඅධි-ජාල

RNNහි එක් එක් පියවර සඳහා පරාමිතීන් නියතව පවතී. ගතික හයිපර්නෙට්වර්ක්ස් එක් එක් පියවර සඳහා විවිධ පරාමිතීන් ජනනය කරයි. හයිපර්එල්එස්ටීඑම් හි LSTM හි ව්යුහය ඇති නමුත් එක් එක් පියවරේ පරාමිතීන් කුඩා LSTM ජාලයකින් වෙනස් වේ.

මූලිකස්වරූපයෙන්, ඩයිනමික් හයිපර්නෙට්වර්ක් කුඩා පුනරාවර්තන ජාලයක් ඇති අතර එය විශාල පුනරාවර්තන ජාලයේ එක් එක් පරාමිති ටෙන්සරයට අනුරූප විශේෂාංග දෛශිකයක් ජනනය කරයි. විශාල ජාලයට යම් පරාමිතියක් Wh​ ඇතැයි කියමු කුඩා ජාලය විශේෂාංග දෛශිකයක් ජනනය කරන zh​ අතර රේඛීය පරිවර්තනයක් Wh​ ලෙස ගතිකව ගණනය කරමු zh​. උදාහරණයක් Wh​=⟨Whz​,zh​⟩ ලෙස 3-d ටෙන්සර් පරාමිතියක් ⟨.⟩ වන අතර එය ආතතිය-දෛශික ගුණ කිරීම වේ. Whz​ zh​ සාමාන්යයෙන් කුඩා පුනරාවර්තන ජාලයේ නිමැවුමේ රේඛීය පරිවර්තනයකි.

පරිගණකකරණයවෙනුවට බර පරිමාණය

විශාලපුනරාවර්තන ජාලයන් විශාල ගතිකව ගණනය කරන ලද පරාමිතීන් ඇත. මෙම ලක්ෂණය දෛශික රේඛීය පරිවර්තනය භාවිතා ගණනය කර zඇත. මෙම පරිවර්තනයට ඊටත් වඩා විශාල බර ආතතයක් අවශ්ය වේ. එනම්, හැඩය Wh​ ඇති විට Nh​×Nh​, Whz​ වනු ඇත Nh​×Nh​×Nz​.

මෙයජය ගැනීම සඳහා, එකම ප්රමාණයේ අනුකෘතියක එක් එක් පේළිය ගතිකව පරිමාණය කිරීමෙන් පුනරාවර්තන ජාලයේ බර පරාමිතීන් ගණනය කරමු.

d(z)=Whz​zh​Wh​=⎝⎛​d0​(z)Whd0​​d1​(z)Whd1​​...dNh​​(z)WhdNh​​​​⎠⎞​​

Nh​×Nh​ පරාමිති අනුකෘතියක් කොහෙද Whd​ .

මූලද්රව්ය-wiseානවන්ත ගුණ කිරීම සඳහා අප ගණනය Wh​hd(z)⊙(Whd​h) කරන ⊙ විට අපට මෙය තවදුරටත් ප්රශස්තිකරණය කළ හැකිය.

73fromtypingimportOptional,Tuple7475importtorch76fromtorchimportnn7778fromlabml\_helpers.moduleimportModule79fromlabml\_nn.lstmimportLSTMCell

#

හයිපර්එල්එස්ටීඑම්සෛලය

HyperLSTMසඳහා කුඩා ජාලය සහ විශාල ජාලය දෙකම LSTM ව්යුහය ඇත. මෙය කඩදාසි වල උපග්රන්ථය A.2.2 හි අර්ථ දක්වා ඇත.

82classHyperLSTMCell(Module):

#

input_size යනු ආදානයේ ප්රමාණය xt​, hidden_size LSTM හි ප්රමාණය hyper_size වන අතර කුඩා LSTM හි බර වෙනස් කරන කුඩා LSTM වල ප්රමාණයයි විශාල පිටත LSTM. n_z යනු LSTM බර වෙනස් කිරීම සඳහා භාවිතා කරන විශේෂාංග දෛශිකවල ප්රමාණයයි.

ගණනය zh​i,f,g,oකිරීම සඳහා අපි කුඩා LSTM හි ප්රතිදානය zbi,f,g,o​ භාවිතා කරන zxi,f,g,o​ අතර රේඛීය පරිවර්තනයන් භාවිතා කරමු. අපි නැවත රේඛීය පරිවර්තනයන් භාවිතා කරමින් dhi,f,g,o​(zh​i,f,g,o)dxi,f,g,o​(zxi,f,g,o​), සහ dbi,f,g,o​(zbi,f,g,o​) මේවායින් ගණනය කරමු. මේවා ප්රධාන LSTM හි බර සහ පක්ෂග්රාහී ආතතීන් පේළි පරිමාණය කිරීමට භාවිතා කරයි.

📝වන ගණනය z හා අනුක්රමික රේඛීය d පරිවර්තනයන් දෙකක් මෙම තනි රේඛීය පරිවර්තනය බවට ඒකාබද්ධ කළ හැකි බැවින්. කෙසේ වෙතත් අපි මෙය වෙන වෙනම ක්රියාත්මක කර ඇති අතර එමඟින් එය කඩදාසි වල විස්තරය සමඟ ගැලපේ.

90def\_\_init\_\_(self,input\_size:int,hidden\_size:int,hyper\_size:int,n\_z:int):

#

108super().\_\_init\_\_()

#

හයිපර්එල්එස්ටීඑම්වෙත ආදානය යනු ආදානය x^t​=(ht−1​xt​​) xt​ කොතැනද යන්න සහ ht−1​ පෙර පියවරේදී පිටත LSTM හි ප්රතිදානය වේ. එබැවින් ආදාන ප්රමාණය වේ hidden_size + input_size .

HyperlSTMහි ප්රතිදානය h^t​ සහ c^t​.

121self.hyper=LSTMCell(hidden\_size+input\_size,hyper\_size,layer\_norm=True)

#

zh​i,f,g,o=linhi,f,g,o​(h^t​) 🤔 කඩදාසි තුළ එය ටයිපෝ බව zh​i,f,g,o=linhi,f,g,o​(h^t−1​) මට හැඟෙන පරිදි නියම කර ඇත.

127self.z\_h=nn.Linear(hyper\_size,4\*n\_z)

#

zxi,f,g,o​=linxi,f,g,o​(h^t​)

129self.z\_x=nn.Linear(hyper\_size,4\*n\_z)

#

zbi,f,g,o​=linbi,f,g,o​(h^t​)

131self.z\_b=nn.Linear(hyper\_size,4\*n\_z,bias=False)

#

dhi,f,g,o​(zh​i,f,g,o)=lindhi,f,g,o​(zh​i,f,g,o)

134d\_h=[nn.Linear(n\_z,hidden\_size,bias=False)for\_inrange(4)]135self.d\_h=nn.ModuleList(d\_h)

#

dxi,f,g,o​(zxi,f,g,o​)=lindxi,f,g,o​(zxi,f,g,o​)

137d\_x=[nn.Linear(n\_z,hidden\_size,bias=False)for\_inrange(4)]138self.d\_x=nn.ModuleList(d\_x)

#

dbi,f,g,o​(zbi,f,g,o​)=lindbi,f,g,o​(zbi,f,g,o​)

140d\_b=[nn.Linear(n\_z,hidden\_size)for\_inrange(4)]141self.d\_b=nn.ModuleList(d\_b)

#

බරමැට්ට්රිස් Whi,f,g,o​

144self.w\_h=nn.ParameterList([nn.Parameter(torch.zeros(hidden\_size,hidden\_size))for\_inrange(4)])

#

බරමැට්ට්රිස් Wxi,f,g,o​

146self.w\_x=nn.ParameterList([nn.Parameter(torch.zeros(hidden\_size,input\_size))for\_inrange(4)])

#

ස්ථරයසාමාන්යකරණය

149self.layer\_norm=nn.ModuleList([nn.LayerNorm(hidden\_size)for\_inrange(4)])150self.layer\_norm\_c=nn.LayerNorm(hidden\_size)

#

152defforward(self,x:torch.Tensor,153h:torch.Tensor,c:torch.Tensor,154h\_hat:torch.Tensor,c\_hat:torch.Tensor):

#

x^t​=(ht−1​xt​​)

161x\_hat=torch.cat((h,x),dim=-1)

#

h^t​,c^t​=lstm(x^t​,h^t−1​,c^t−1​)

163h\_hat,c\_hat=self.hyper(x\_hat,h\_hat,c\_hat)

#

zh​i,f,g,o=linhi,f,g,o​(h^t​)

166z\_h=self.z\_h(h\_hat).chunk(4,dim=-1)

#

zxi,f,g,o​=linxi,f,g,o​(h^t​)

168z\_x=self.z\_x(h\_hat).chunk(4,dim=-1)

#

zbi,f,g,o​=linbi,f,g,o​(h^t​)

170z\_b=self.z\_b(h\_hat).chunk(4,dim=-1)

#

අපිගණනය කරමු if, g සහ ලූපයක් o තුළ

173ifgo=[]174foriinrange(4):

#

dhi,f,g,o​(zh​i,f,g,o)=lindhi,f,g,o​(zh​i,f,g,o)

176d\_h=self.d\_h[i](z\_h[i])

#

dxi,f,g,o​(zxi,f,g,o​)=lindxi,f,g,o​(zxi,f,g,o​)

178d\_x=self.d\_x[i](z\_x[i])

# i,f,g,o=LN(++​dhi,f,g,o​(zh​)⊙(Whi,f,g,o​ht−1​)dxi,f,g,o​(zx​)⊙(Whi,f,g,o​xt​)dbi,f,g,o​(zb​))​

185y=d\_h\*torch.einsum('ij,bj-\>bi',self.w\_h[i],h)+\186d\_x\*torch.einsum('ij,bj-\>bi',self.w\_x[i],x)+\187self.d\_b[i](z\_b[i])188189ifgo.append(self.layer\_norm[i](y))

#

it​,ft​,gt​,ot​

192i,f,g,o=ifgo

#

ct​=σ(ft​)⊙ct−1​+σ(it​)⊙tanh(gt​)

195c\_next=torch.sigmoid(f)\*c+torch.sigmoid(i)\*torch.tanh(g)

#

ht​=σ(ot​)⊙tanh(LN(ct​))

198h\_next=torch.sigmoid(o)\*torch.tanh(self.layer\_norm\_c(c\_next))199200returnh\_next,c\_next,h\_hat,c\_hat

#

හයිපර්එල්එස්ටීඑම්මොඩියුලය

203classHyperLSTM(Module):

#

HyperlSTMජාලයක් සාදන්න. n_layers

208def\_\_init\_\_(self,input\_size:int,hidden\_size:int,hyper\_size:int,n\_z:int,n\_layers:int):

#

213super().\_\_init\_\_()

#

තත්වයආරම්භ කිරීම සඳහා ප්රමාණ ගබඩා කරන්න

216self.n\_layers=n\_layers217self.hidden\_size=hidden\_size218self.hyper\_size=hyper\_size

#

එක්එක් ස්ථරය සඳහා සෛල සාදන්න. පළමු ස්ථරයට පමණක් ආදානය කෙලින්ම ලැබෙන බව සලකන්න. සෙසු ස්ථර පහත ස්ථරයෙන් ආදානය ලබා ගනී

222self.cells=nn.ModuleList([HyperLSTMCell(input\_size,hidden\_size,hyper\_size,n\_z)]+223[HyperLSTMCell(hidden\_size,hidden\_size,hyper\_size,n\_z)for\_in224range(n\_layers-1)])

#

  • x හැඩය [n_steps, batch_size, input_size] සහ
  • state ක tuple වේ h,c,h^,c^. h,c හැඩය [batch_size, hidden_size] සහ හැඩය h^,c^ ඇති [batch_size, hyper_size] .
226defforward(self,x:torch.Tensor,227state:Optional[Tuple[torch.Tensor,torch.Tensor,torch.Tensor,torch.Tensor]]=None):

#

234n\_steps,batch\_size=x.shape[:2]

#

නම්ශුන්ය සමඟ රාජ්යය ආරම්භ කරන්න None

237ifstateisNone:238h=[x.new\_zeros(batch\_size,self.hidden\_size)for\_inrange(self.n\_layers)]239c=[x.new\_zeros(batch\_size,self.hidden\_size)for\_inrange(self.n\_layers)]240h\_hat=[x.new\_zeros(batch\_size,self.hyper\_size)for\_inrange(self.n\_layers)]241c\_hat=[x.new\_zeros(batch\_size,self.hyper\_size)for\_inrange(self.n\_layers)]

#

243else:244(h,c,h\_hat,c\_hat)=state

#

එක්එක් ස්ථරයේ තත්වයන් ලබා ගැනීම සඳහා ආතතීන් ආපසු හරවන්න

📝ඔබට ටෙන්සර් සමඟ වැඩ කළ හැකි නමුත් මෙය නිදොස් කිරීමට පහසුය

248h,c=list(torch.unbind(h)),list(torch.unbind(c))249h\_hat,c\_hat=list(torch.unbind(h\_hat)),list(torch.unbind(c\_hat))

#

එක්එක් පියවරේදී අවසාන ස්ථරයේ ප්රතිදානයන් එකතු කරන්න

252out=[]253fortinrange(n\_steps):

#

පළමුස්ථරයට ආදානය යනු ආදානය ම වේ

255inp=x[t]

#

ස්ථරහරහා ලූප්

257forlayerinrange(self.n\_layers):

#

ස්ථරයේතත්වය ලබා ගන්න

259h[layer],c[layer],h\_hat[layer],c\_hat[layer]=\260self.cells[layer](inp,h[layer],c[layer],h\_hat[layer],c\_hat[layer])

#

ඊළඟස්ථරයට ආදානය මෙම ස්ථරයේ තත්වයයි

262inp=h[layer]

#

අවසාන h ස්ථරයේ ප්රතිදානය එකතු කරන්න

264out.append(h[-1])

#

ප්රතිදානයන්සහ ප්රාන්ත ගොඩගසන්න

267out=torch.stack(out)268h=torch.stack(h)269c=torch.stack(c)270h\_hat=torch.stack(h\_hat)271c\_hat=torch.stack(c\_hat)

#

274returnout,(h,c,h\_hat,c\_hat)

Trending Research Paperslabml.ai