Back to Annotated Deep Learning Paper Implementations

ගැඹුරුQ ජාල (DQN)

docs/si/rl/dqn/index.html

latest6.2 KB
Original Source

homerldqn

[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/rl/dqn/ init.py)

#

ගැඹුරුQ ජාල (DQN)

මෙය PyTorch ක්රියාත්මක කිරීමකි Atari කඩදාසි සෙල්ලම් කිරීම ගැඹුරු ශක්තිමත් කිරීමේ ඉගෙනීම සහ ඩුවලිං ජාලය , ප්රමුඛතා නැවත ධාවනය සහ ද්විත්ව Q ජාලය සමඟ.

මෙන්න අත්හදා බැලීම සහ ආදර්ශ ක්රියාත්මක කිරීම.

25fromtypingimportTuple2627importtorch28fromtorchimportnn2930fromlabmlimporttracker31fromlabml\_helpers.moduleimportModule32fromlabml\_nn.rl.dqn.replay\_bufferimportReplayBuffer

#

ආකෘතිය පුහුණු කරන්න

ප්රශස්ත ක්රියාකාරී අගය ශ්රිතය සොයා ගැනීමට අපට අවශ්යය.

Q∗(s,a)Q∗(s,a)​=πmax​E[rt​+γrt+1​+γ2rt+2​+...∣st​=s,at​=a,π]=Es′∼ε​[r+γa′max​Q∗(s′,a′)∣s,a]​

ඉලක්ක ජාලය 🎯

ස්ථාවරත්වය වැඩි දියුණු කිරීම සඳහා අපි පෙර අත්දැකීම් වලින් අහඹු ලෙස නියැදිය අත්දැකීම් නැවත ධාවනය භාවිතා කරමුU(D). ඉලක්කය ගණනය කිරීම සඳහා වෙනම පරාමිතීන්θi−​ සමූහයක් සහිත Q ජාලයක් ද අපි භාවිතා කරමු. θi−​වරින් වර යාවත්කාලීන වේ. මෙය ගැඹුරු ශක්තිමත් කිරීමේ ඉගෙනීම තුළින් කඩදාසි මානව මට්ටම් පාලනයට අනුව ය.

එබැවින් පාඩු ශ්රිතය වන්නේ,Li​(θi​)=E(s,a,r,s′)∼U(D)​[(r+γa′max​Q(s′,a′;θi−​)−Q(s,a;θi​))2]

Qද්විත්ව-ඉගෙනුම්

ඉහත ගණනය කිරීමේ උපරිම ක්රියාකරු හොඳම ක්රියාව තෝරා ගැනීම සහ වටිනාකම ඇගයීම සඳහා එකම ජාලයක් භාවිතා කරයි. එනම්,a′max​Q(s′,a′;θ)=Q(s′,argmaxa′​Q(s′,a′;θ);θ) අපි ද්විත්ව Q- ඉගෙනීම භාවිතා කරමු,argmax එය ලබා ගන්නේ කොතැනින්දθi​ සහ වටිනාකම ලබාθi−​ ගනී.

පාඩු ශ්රිතය බවට පත්වේ,

Li​(θi​)=E(s,a,r,s′)∼U(D)​[(−​r+γQ(s′,argmaxa′​Q(s′,a′;θi​);θi−​)Q(s,a;θi​))2]​

35classQFuncLoss(Module):

#

103def\_\_init\_\_(self,gamma:float):104super().\_\_init\_\_()105self.gamma=gamma106self.huber\_loss=nn.SmoothL1Loss(reduction='none')

#

  • q - Q(s;θi​)
  • action - a
  • double_q - Q(s′;θi​)
  • target_q - Q(s′;θi−​)
  • done - පියවර ගැනීමෙන් පසු ක්රීඩාව අවසන් වූවාද යන්න
  • reward - r
  • weights - ප්රමුඛතාවය පළපුරුදු නැවත ධාවනය සිට සාම්පල බර
108defforward(self,q:torch.Tensor,action:torch.Tensor,double\_q:torch.Tensor,109target\_q:torch.Tensor,done:torch.Tensor,reward:torch.Tensor,110weights:torch.Tensor)-\>Tuple[torch.Tensor,torch.Tensor]:

#

Q(s,a;θi​)

122q\_sampled\_action=q.gather(-1,action.to(torch.long).unsqueeze(-1)).squeeze(-1)123tracker.add('q\_sampled\_action',q\_sampled\_action)

#

අනුක්රමිකඅනුක්රමික ප්රචාරය නොකළ යුතුය r+γQ(s′,argmaxa′​Q(s′,a′;θi​);θi−​)

131withtorch.no\_grad():

#

රාජ්යයෙන්හොඳම ක්රියාව ලබා ගන්න s′ argmaxa′​Q(s′,a′;θi​)

135best\_next\_action=torch.argmax(double\_q,-1)

#

රාජ්යහොඳම ක්රියාව සඳහා ඉලක්ක ජාලයෙන් q අගය ලබා ගන්න s′ Q(s′,argmaxa′​Q(s′,a′;θi​);θi−​)

141best\_next\_q\_value=target\_q.gather(-1,best\_next\_action.unsqueeze(-1)).squeeze(-1)

#

අපේක්ෂිතQ අගය ගණනය කරන්න. ක්රීඩාව අවසන් වූයේ නම් ඊළඟ රාජ්ය Q අගයන් ශුන්ය (1 - done) කිරීමට අපි ගුණ කරමු.

r+γQ(s′,argmaxa′​Q(s′,a′;θi​);θi−​)

152q\_update=reward+self.gamma\*best\_next\_q\_value\*(1-done)153tracker.add('q\_update',q\_update)

#

නැවතධාවනය කිරීමේ බෆරයේ සාම්පල කිරා මැන බැලීමට තාවකාලික වෙනස දෝෂය δ භාවිතා කරයි

156td\_error=q\_sampled\_action-q\_update157tracker.add('td\_error',td\_error)

#

එයoutliers අඩු සංවේදී නිසා අපි ඒ වෙනුවට මධ්යන්ය වර්ග දෝෂයක් අහිමි Huber අහිමි ගත

161losses=self.huber\_loss(q\_sampled\_action,q\_update)

#

බරතැබූ ක්රම ලබා ගන්න

163loss=torch.mean(weights\*losses)164tracker.add('loss',loss)165166returntd\_error,loss

Trending Research Paperslabml.ai