# Deep Q Networks (DQN)
[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/rl/dqn/__init__.py)
This is a PyTorch implementation of the paper [Playing Atari with Deep Reinforcement Learning](https://arxiv.org/abs/1312.5602), along with Dueling Network, Prioritized Replay, and Double Q Network.
Here is the experiment and model implementation.
```python
from typing import Tuple

import torch
from torch import nn

from labml import tracker
from labml_nn.rl.dqn.replay_buffer import ReplayBuffer
```
We want to find the optimal action-value function:
$$\begin{aligned}
Q^*(s,a) &= \max_\pi \mathbb{E} \left[ r_t + \gamma r_{t+1} + \gamma^2 r_{t+2} + \cdots \mid s_t = s, a_t = a, \pi \right] \\
&= \mathbb{E}_{s' \sim \mathcal{E}} \left[ r + \gamma \max_{a'} Q^*(s', a') \,\Big|\, s, a \right]
\end{aligned}$$
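For instance, with a made-up reward sequence, the return inside the first expectation is just a discounted sum:

```python
# Hypothetical rewards r_t, r_{t+1}, ... observed along one trajectory
rewards = [1.0, 0.0, 2.0, 1.0]
gamma = 0.99

# r_t + gamma * r_{t+1} + gamma^2 * r_{t+2} + ...
discounted_return = sum(gamma ** i * r for i, r in enumerate(rewards))
print(discounted_return)  # ~3.9305
```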
To improve stability we use experience replay, which samples randomly from previous experience $U(D)$. We also use a Q network with a separate set of parameters $\theta_i^-$ to calculate the target; $\theta_i^-$ is updated periodically. This is according to the paper [Human-level control through deep reinforcement learning](https://www.nature.com/articles/nature14236).
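A minimal sketch of the usual periodic update of $\theta_i^-$; the network shape and update interval here are illustrative assumptions, not taken from this implementation:

```python
import copy

from torch import nn

q_net = nn.Linear(4, 2)            # online network, parameters theta_i
target_net = copy.deepcopy(q_net)  # target network, parameters theta_i^-

for step in range(10_000):
    # ... sample a mini-batch from the replay buffer and take a gradient step on q_net ...

    # Periodically copy the online parameters into the target network
    if step % 1_000 == 0:
        target_net.load_state_dict(q_net.state_dict())
```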
So the loss function is

$$\mathcal{L}_i(\theta_i) = \mathbb{E}_{(s,a,r,s') \sim U(D)} \left[ \left( r + \gamma \max_{a'} Q(s', a'; \theta_i^-) - Q(s, a; \theta_i) \right)^2 \right]$$
The max operator in the above calculation uses the same network both to select the best action and to evaluate its value. That is,

$$\max_{a'} Q(s', a'; \theta) = Q\left(s', \operatorname*{argmax}_{a'} Q(s', a'; \theta); \theta\right)$$

We use [double Q-learning](https://arxiv.org/abs/1509.06461), where the $\operatorname{argmax}$ is taken from $\theta_i$ and the value is taken from $\theta_i^-$.
And the loss function becomes,
$$\mathcal{L}_i(\theta_i) = \mathbb{E}_{(s,a,r,s') \sim U(D)} \left[ \left( r + \gamma Q\left(s', \operatorname*{argmax}_{a'} Q(s', a'; \theta_i); \theta_i^-\right) - Q(s, a; \theta_i) \right)^2 \right]$$
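A toy illustration of this decoupling, with arbitrary Q values: the online network selects the action, the target network evaluates it.

```python
import torch

# Q values for three actions at state s' (made-up numbers)
double_q = torch.tensor([2.0, 5.0, 1.0])  # online network Q(s', a'; theta_i)
target_q = torch.tensor([2.5, 4.0, 6.0])  # target network Q(s', a'; theta_i^-)

best_next_action = torch.argmax(double_q)    # tensor(1), chosen by the online network
double_q_value = target_q[best_next_action]  # tensor(4.), evaluated by the target network

# Plain Q-learning would instead use the target network for both steps:
plain_max = target_q.max()  # tensor(6.), which tends to overestimate
```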
```python
class QFuncLoss(nn.Module):
```
`gamma` is the discount factor $\gamma$.

```python
    def __init__(self, gamma: float):
        super().__init__()
        # Discount factor
        self.gamma = gamma
        # Huber loss with no reduction, so we can apply per-sample weights later
        self.huber_loss = nn.SmoothL1Loss(reduction='none')
```
Arguments to `forward`:

* `q` - $Q(s; \theta_i)$
* `action` - $a$
* `double_q` - $Q(s'; \theta_i)$
* `target_q` - $Q(s'; \theta_i^-)$
* `done` - whether the game ended after taking the action
* `reward` - $r$
* `weights` - weights of the samples from prioritized experience replay

```python
    def forward(self, q: torch.Tensor, action: torch.Tensor, double_q: torch.Tensor,
                target_q: torch.Tensor, done: torch.Tensor, reward: torch.Tensor,
                weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
```
$Q(s, a; \theta_i)$
```python
        q_sampled_action = q.gather(-1, action.to(torch.long).unsqueeze(-1)).squeeze(-1)
        tracker.add('q_sampled_action', q_sampled_action)
```
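To see what the `gather` call does, here is a small example with made-up values for a batch of two samples:

```python
import torch

q = torch.tensor([[1.0, 2.0, 3.0],   # Q values for sample 0
                  [4.0, 5.0, 6.0]])  # Q values for sample 1
action = torch.tensor([2, 0])        # actions taken in the two samples

# Pick the Q value of the action taken in each sample
q_sampled_action = q.gather(-1, action.to(torch.long).unsqueeze(-1)).squeeze(-1)
print(q_sampled_action)  # tensor([3., 4.])
```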
Gradients shouldn't propagate through the target $r + \gamma Q\left(s', \operatorname*{argmax}_{a'} Q(s', a'; \theta_i); \theta_i^-\right)$
```python
        with torch.no_grad():
```
Get the best action at state $s'$: $\operatorname*{argmax}_{a'} Q(s', a'; \theta_i)$
```python
            best_next_action = torch.argmax(double_q, -1)
```
Get the Q value from the target network for the best action at state $s'$: $Q\left(s', \operatorname*{argmax}_{a'} Q(s', a'; \theta_i); \theta_i^-\right)$
```python
            best_next_q_value = target_q.gather(-1, best_next_action.unsqueeze(-1)).squeeze(-1)
```
Calculate the desired Q value. We multiply by `(1 - done)` to zero out the next state Q values if the game ended.

$$r + \gamma Q\left(s', \operatorname*{argmax}_{a'} Q(s', a'; \theta_i); \theta_i^-\right)$$
```python
            q_update = reward + self.gamma * best_next_q_value * (1 - done)
            tracker.add('q_update', q_update)
```
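For example, with made-up values, the `(1 - done)` factor keeps only the reward for terminal transitions:

```python
import torch

gamma = 0.99
reward = torch.tensor([1.0, 1.0])
best_next_q_value = torch.tensor([5.0, 5.0])
done = torch.tensor([0.0, 1.0])  # the second transition ended the episode

q_update = reward + gamma * best_next_q_value * (1 - done)
print(q_update)  # tensor([5.9500, 1.0000])
```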
The temporal difference error $\delta$ is used to weight samples in the replay buffer:
```python
            td_error = q_sampled_action - q_update
            tracker.add('td_error', td_error)
```
We take the Huber loss instead of the mean squared error loss because it is less sensitive to outliers:
```python
        losses = self.huber_loss(q_sampled_action, q_update)
```
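A quick comparison on an outlier, with illustrative numbers, showing how the Huber (smooth L1) loss grows linearly rather than quadratically for large errors:

```python
import torch
from torch import nn

pred = torch.tensor([0.0, 0.0])
target = torch.tensor([0.5, 10.0])  # the second sample is an outlier

mse = nn.MSELoss(reduction='none')(pred, target)
huber = nn.SmoothL1Loss(reduction='none')(pred, target)

print(mse)    # tensor([  0.2500, 100.0000])
print(huber)  # tensor([0.1250, 9.5000])
```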
Get weighted means
```python
        loss = torch.mean(weights * losses)
        tracker.add('loss', loss)

        return td_error, loss
```
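A minimal usage sketch with random tensors; the batch size, number of actions, and values are arbitrary assumptions (this also assumes `labml`'s tracker accepts values outside a configured experiment):

```python
import torch

batch_size, n_actions = 32, 4
loss_fn = QFuncLoss(gamma=0.99)

q = torch.randn(batch_size, n_actions, requires_grad=True)  # Q(s; theta_i)
double_q = torch.randn(batch_size, n_actions)               # Q(s'; theta_i)
target_q = torch.randn(batch_size, n_actions)               # Q(s'; theta_i^-)
action = torch.randint(0, n_actions, (batch_size,))
done = torch.zeros(batch_size)
reward = torch.randn(batch_size)
weights = torch.ones(batch_size)  # uniform prioritized-replay weights

td_error, loss = loss_fn(q, action, double_q, target_q, done, reward, weights)
loss.backward()  # gradients flow only through Q(s, a; theta_i)
```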