Deep Q Networks (DQN)


[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/rl/dqn/__init__.py)


This is a PyTorch implementation of the paper Playing Atari with Deep Reinforcement Learning, along with Dueling Network, Prioritized Replay and Double Q Network.

Here is the experiment and model implementation.

```python
from typing import Tuple

import torch
from torch import nn

from labml import tracker
from labml_nn.rl.dqn.replay_buffer import ReplayBuffer
```


Train the model

We want to find the optimal action-value function:

$$\begin{aligned}
Q^*(s, a) &= \max_\pi \mathbb{E} \Big[ r_t + \gamma r_{t+1} + \gamma^2 r_{t+2} + \cdots \;\Big|\; s_t = s, a_t = a, \pi \Big] \\
          &= \mathbb{E}_{s' \sim \mathcal{E}} \Big[ r + \gamma \max_{a'} Q^*(s', a') \;\Big|\; s, a \Big]
\end{aligned}$$

Target network 🎯

In order to improve stability we use experience replay, sampling uniformly at random from previous transitions $U(D)$. We also use a Q network with a separate set of parameters $\theta_i^-$ to calculate the target; $\theta_i^-$ is updated periodically. This is according to the paper Human Level Control Through Deep Reinforcement Learning.

So the loss function is,

$$\mathcal{L}_i(\theta_i) = \mathbb{E}_{(s, a, r, s') \sim U(D)} \bigg[ \Big( r + \gamma \max_{a'} Q(s', a'; \theta_i^-) - Q(s, a; \theta_i) \Big)^2 \bigg]$$
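The target parameters $\theta_i^-$ are refreshed by copying the online parameters every so often. A minimal sketch of that periodic copy, assuming a placeholder architecture and an illustrative update interval (neither is taken from this implementation):

```python
import copy

from torch import nn

# Hypothetical online network Q(s, a; θ); the architecture is only a placeholder
online_q = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 2))

# Target network Q(s, a; θ⁻) starts out as a copy of the online network
target_q = copy.deepcopy(online_q)
target_update_interval = 1_000  # assumed number of training steps between copies

for step in range(10_000):
    # ... sample a batch from the replay buffer and take a gradient step on
    #     online_q, computing the targets with target_q ...
    if step % target_update_interval == 0:
        # Periodically copy θ into θ⁻ so the targets stay fixed in between copies
        target_q.load_state_dict(online_q.state_dict())
```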

Double Q-Learning

The max operator in the above calculation uses the same network both to select the best action and to evaluate its value. That is,

$$\max_{a'} Q(s', a'; \theta) = Q\big(s', \operatorname*{argmax}_{a'} Q(s', a'; \theta); \theta\big)$$

We use double Q-learning, where the $\operatorname{argmax}$ is taken from $\theta_i$ and the value is taken from $\theta_i^-$.

And the loss function becomes,

$$\mathcal{L}_i(\theta_i) = \mathbb{E}_{(s, a, r, s') \sim U(D)} \bigg[ \Big( r + \gamma Q\big(s', \operatorname*{argmax}_{a'} Q(s', a'; \theta_i); \theta_i^-\big) - Q(s, a; \theta_i) \Big)^2 \bigg]$$

```python
class QFuncLoss(nn.Module):
```


```python
    def __init__(self, gamma: float):
        super().__init__()
        self.gamma = gamma
        self.huber_loss = nn.SmoothL1Loss(reduction='none')
```


  • `q` - $Q(s; \theta_i)$
  • `action` - $a$
  • `double_q` - $Q(s'; \theta_i)$
  • `target_q` - $Q(s'; \theta_i^-)$
  • `done` - whether the game ended after taking the action
  • `reward` - $r$
  • `weights` - weights of the samples from prioritized experience replay
```python
    def forward(self, q: torch.Tensor, action: torch.Tensor, double_q: torch.Tensor,
                target_q: torch.Tensor, done: torch.Tensor, reward: torch.Tensor,
                weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
```


$Q(s, a; \theta_i)$

```python
        q_sampled_action = q.gather(-1, action.to(torch.long).unsqueeze(-1)).squeeze(-1)
        tracker.add('q_sampled_action', q_sampled_action)
```


Gradients shouldn't propagate through the target $r + \gamma Q\big(s', \operatorname*{argmax}_{a'} Q(s', a'; \theta_i); \theta_i^-\big)$

```python
        with torch.no_grad():
```


Get the best action at state $s'$: $\operatorname*{argmax}_{a'} Q(s', a'; \theta_i)$

```python
            best_next_action = torch.argmax(double_q, -1)
```


Get the Q value from the target network for the best action at state $s'$: $Q\big(s', \operatorname*{argmax}_{a'} Q(s', a'; \theta_i); \theta_i^-\big)$

```python
            best_next_q_value = target_q.gather(-1, best_next_action.unsqueeze(-1)).squeeze(-1)
```


Calculate the desired Q value. We multiply by (1 - done) to zero out the next state Q values if the game ended.

$r + \gamma Q\big(s', \operatorname*{argmax}_{a'} Q(s', a'; \theta_i); \theta_i^-\big)$

```python
            q_update = reward + self.gamma * best_next_q_value * (1 - done)
            tracker.add('q_update', q_update)
```


The temporal difference error $\delta$ is used to weigh samples in the replay buffer

```python
            td_error = q_sampled_action - q_update
            tracker.add('td_error', td_error)
```
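For reference, in the proportional variant of prioritized experience replay (Schaul et al., 2015) that this TD error feeds into, a transition's priority, sampling probability and importance-sampling weight are typically

$$p_i = |\delta_i| + \epsilon, \qquad P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}, \qquad w_i = \bigg( \frac{1}{N} \cdot \frac{1}{P(i)} \bigg)^{\beta}$$

with the $w_i$ usually normalized by their maximum; these are the weights passed to this loss.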


We use the Huber loss instead of the mean squared error loss because it is less sensitive to outliers.

```python
        losses = self.huber_loss(q_sampled_action, q_update)
```
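For reference, `nn.SmoothL1Loss` with its default $\beta = 1$ computes, element-wise,

$$\ell(x, y) = \begin{cases} \tfrac{1}{2}(x - y)^2 & \text{if } |x - y| < 1 \\ |x - y| - \tfrac{1}{2} & \text{otherwise} \end{cases}$$

so large errors grow linearly rather than quadratically.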


Get the weighted mean of the losses

```python
        loss = torch.mean(weights * losses)
        tracker.add('loss', loss)

        return td_error, loss
```
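A minimal usage sketch, assuming two small placeholder networks and a random batch (nothing here except `QFuncLoss` itself comes from this implementation); it shows which network produces each input and where the returned `td_error` would go:

```python
import copy

import torch
from torch import nn

batch_size, n_obs, n_actions = 32, 4, 2

# Hypothetical online and target networks; the architecture is only a placeholder
online_net = nn.Sequential(nn.Linear(n_obs, 64), nn.ReLU(), nn.Linear(64, n_actions))
target_net = copy.deepcopy(online_net)

loss_func = QFuncLoss(gamma=0.99)

# Random placeholder batch, standing in for samples from the replay buffer
obs = torch.randn(batch_size, n_obs)
next_obs = torch.randn(batch_size, n_obs)
action = torch.randint(0, n_actions, (batch_size,))
reward = torch.randn(batch_size)
done = torch.zeros(batch_size)    # 1.0 where the episode ended
weights = torch.ones(batch_size)  # importance-sampling weights from prioritized replay

q = online_net(obs)                  # Q(s; θ_i)
with torch.no_grad():
    double_q = online_net(next_obs)  # Q(s'; θ_i), used only to pick argmax a'
    target_q = target_net(next_obs)  # Q(s'; θ_i⁻), used to evaluate that action

td_error, loss = loss_func(q, action, double_q, target_q, done, reward, weights)
loss.backward()

# |td_error| would typically be fed back to the prioritized replay buffer
# to update the priorities of the sampled transitions.
```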
