Deep Q Network (DQN) Model

import torch
from torch import nn

Dueling Network ⚔️ Model for Q Values

We use a dueling network to calculate Q-values. The intuition behind the dueling architecture is that in most states the choice of action doesn't matter much, while in some states it is significant. A dueling network represents this well because it estimates the value of the state separately from the advantage of each action.

$$
\begin{aligned}
Q^\pi(s,a) &= V^\pi(s) + A^\pi(s,a) \\
\mathbb{E}_{a \sim \pi(s)}\big[A^\pi(s,a)\big] &= 0
\end{aligned}
$$
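As a quick numeric check of the second identity, here is a minimal sketch with made-up Q-values and a made-up policy over four actions (both are illustrative assumptions, not values produced by the model). Taking V as the policy-weighted average of Q makes the policy-weighted advantage vanish.

```python
# Illustrative numbers only: Q-values for one state and a policy over 4 actions.
q_values = [1.0, 2.0, 0.5, 3.0]
policy = [0.1, 0.2, 0.3, 0.4]  # action probabilities, summing to 1

# V^pi(s) = E_{a ~ pi(s)}[Q^pi(s, a)]
state_value = sum(p * q for p, q in zip(policy, q_values))

# A^pi(s, a) = Q^pi(s, a) - V^pi(s)
advantages = [q - state_value for q in q_values]

# E_{a ~ pi(s)}[A^pi(s, a)] is zero up to floating-point error
expected_advantage = sum(p * a for p, a in zip(policy, advantages))
print(state_value, expected_advantage)
```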

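So we create two networks, one for V and one for A, and combine them to get Q:

$$Q(s,a) = V(s) + \Big(A(s,a) - \frac{1}{|\mathcal{A}|}\sum_{a' \in \mathcal{A}} A(s,a')\Big)$$

Subtracting the mean advantage stands in for the zero-mean constraint above (a uniform average over the action set $\mathcal{A}$ replaces the expectation under $\pi$), so V and A cannot shift an arbitrary constant between each other. We share the initial layers of the V and A networks.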
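A minimal plain-Python sketch of this aggregation for a single state follows. The function name `dueling_q` and the head outputs are hypothetical. The sketch checks two properties: averaging Q over actions gives back V, and adding the same constant to every advantage leaves Q unchanged.

```python
def dueling_q(state_value, advantages):
    """Q(s, a) = V(s) + (A(s, a) - mean over a' of A(s, a'))."""
    mean_advantage = sum(advantages) / len(advantages)
    return [state_value + (a - mean_advantage) for a in advantages]


v = 1.5                      # hypothetical V(s)
adv = [0.2, -0.4, 1.0, 0.6]  # hypothetical A(s, a) for 4 actions
q = dueling_q(v, adv)

# The centered advantages sum to zero, so the mean of Q over actions is V(s).
print(sum(q) / len(q))

# A constant offset on all advantages is removed by the centering.
q_shifted = dueling_q(v, [a + 10.0 for a in adv])
print(max(abs(x - y) for x, y in zip(q, q_shifted)))
```

The greedy action is unaffected by the centering: the action with the largest advantage (index 2 here) also has the largest Q.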

class Model(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(

The first convolution layer takes an 84×84 frame and produces a 20×20 frame

            nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4),
            nn.ReLU(),

The second convolution layer takes a 20×20 frame and produces a 9×9 frame

            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
            nn.ReLU(),

The third convolution layer takes a 9×9 frame and produces a 7×7 frame

            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
            nn.ReLU(),
        )
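The frame sizes stated for the three convolution layers follow from the output-size formula for an unpadded convolution, floor((n - k) / s) + 1 for input size n, kernel size k and stride s. This is what `nn.Conv2d` computes with its default `padding=0` and `dilation=1`. The short sketch below reproduces 84 → 20 → 9 → 7 and the 7×7×64 input size of the linear layer. The helper name `conv_out` is hypothetical.

```python
def conv_out(size, kernel_size, stride):
    # Output size along one spatial dimension of an unpadded, undilated convolution
    return (size - kernel_size) // stride + 1


size = 84
for kernel_size, stride in [(8, 4), (4, 2), (3, 1)]:
    size = conv_out(size, kernel_size, stride)
    print(size)  # prints 20, then 9, then 7

flattened = size * size * 64  # 64 output channels from the third layer
print(flattened)  # 3136
```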

A fully connected layer takes the flattened output of the third convolution layer (7×7×64 values) and outputs 512 features

        self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512)
        self.activation = nn.ReLU()

This head gives the state value V

        self.state_value = nn.Sequential(
            nn.Linear(in_features=512, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=1),
        )

This head gives the advantage A, with one output for each of the 4 actions

        self.action_value = nn.Sequential(
            nn.Linear(in_features=512, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=4),
        )
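Counting parameters by hand for the layers defined in `__init__` shows where the model's capacity sits. Below is a plain-Python sketch with hypothetical helper names. Each convolution contributes `in * out * k * k` weights plus `out` biases, and each linear layer contributes `in * out` weights plus `out` biases. The total should agree with `sum(p.numel() for p in Model().parameters())`, because `nn.Conv2d` and `nn.Linear` include a bias by default.

```python
def conv_params(in_ch, out_ch, kernel_size):
    # Weights plus one bias per output channel
    return in_ch * out_ch * kernel_size * kernel_size + out_ch


def linear_params(in_f, out_f):
    # Weight matrix plus one bias per output feature
    return in_f * out_f + out_f


counts = {
    "conv": conv_params(4, 32, 8) + conv_params(32, 64, 4) + conv_params(64, 64, 3),
    "lin": linear_params(7 * 7 * 64, 512),
    "state_value": linear_params(512, 256) + linear_params(256, 1),
    "action_value": linear_params(512, 256) + linear_params(256, 4),
}
total = sum(counts.values())
for name, n in counts.items():
    print(f"{name}: {n} ({n / total:.1%})")
print("total:", total)
```

With these shapes the fully connected layer `lin` accounts for over 80% of the 1,948,069 parameters. The three convolution layers together hold fewer than 80,000.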

    def forward(self, obs: torch.Tensor):

Convolution

        h = self.conv(obs)

Reshape for linear layers

        h = h.reshape((-1, 7 * 7 * 64))

Linear layer

        h = self.activation(self.lin(h))

A

        action_value = self.action_value(h)

V

        state_value = self.state_value(h)

$$A(s,a) - \frac{1}{|\mathcal{A}|}\sum_{a' \in \mathcal{A}} A(s,a')$$

        action_score_centered = action_value - action_value.mean(dim=-1, keepdim=True)

$$Q(s,a) = V(s) + \Big(A(s,a) - \frac{1}{|\mathcal{A}|}\sum_{a' \in \mathcal{A}} A(s,a')\Big)$$

        q = state_value + action_score_centered

        return q
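The last steps of `forward` operate on a whole batch. `action_value` has shape (batch, 4) and `state_value` has shape (batch, 1). `mean(dim=-1, keepdim=True)` keeps a (batch, 1) column, so each row is centered by its own mean and each row's V(s) is broadcast across that row's actions. The pure-Python sketch below applies the same arithmetic to nested lists. The helper `dueling_aggregate` and the sample numbers are hypothetical.

```python
def dueling_aggregate(state_value, action_value):
    # state_value: rows of [V(s)], shape (batch, 1)
    # action_value: rows of advantages, shape (batch, n_actions)
    q = []
    for (v,), adv in zip(state_value, action_value):
        mean_adv = sum(adv) / len(adv)  # per-row mean, like keepdim=True
        q.append([v + (a - mean_adv) for a in adv])  # V(s) broadcast over actions
    return q


state_value = [[0.5], [-1.0]]
action_value = [[1.0, 2.0, 3.0, 4.0], [0.0, 0.0, 4.0, 0.0]]
q = dueling_aggregate(state_value, action_value)
print(q)  # [[-1.0, 0.0, 1.0, 2.0], [-2.0, -2.0, 2.0, -2.0]]
```

Without `keepdim=True` the mean would have shape (batch,), and PyTorch would broadcast it along the action dimension instead of the batch dimension. That raises a shape error unless the batch size is 1 or equal to the number of actions. With a batch of exactly 4, it silently subtracts row j's mean from column j.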

labml.ai