```python
import torch
from torch import nn
```
We use a dueling network to calculate Q-values. The intuition behind the dueling architecture is that in most states the choice of action doesn't matter much, while in some states it is significant. A dueling network represents this well by estimating the state value and the per-action advantages separately.
$$
\begin{aligned}
Q^\pi(s,a) &= V^\pi(s) + A^\pi(s,a) \\
\mathop{\mathbb{E}}_{a \sim \pi(s)}\big[A^\pi(s,a)\big] &= 0
\end{aligned}
$$
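Taking the expectation of the first identity over $a \sim \pi(s)$ and applying the second shows that the state value is the policy-weighted average of the Q-values, so $A^\pi(s,a)$ measures how much better or worse action $a$ is than that average:

$$
\mathop{\mathbb{E}}_{a \sim \pi(s)}\big[Q^\pi(s,a)\big]
= V^\pi(s) + \mathop{\mathbb{E}}_{a \sim \pi(s)}\big[A^\pi(s,a)\big]
= V^\pi(s)
$$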
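So we create two networks, one for $V$ and one for $A$, and combine them to get $Q$:

$$
Q(s,a) = V(s) + \Big(A(s,a) - \frac{1}{\lvert \mathcal{A} \rvert} \sum_{a' \in \mathcal{A}} A(s,a')\Big)
$$

Subtracting the mean advantage makes the advantage term zero-mean over the action set $\mathcal{A}$, mirroring the zero-expectation constraint above with a uniform average in place of the expectation under $\pi$. The initial layers of the $V$ and $A$ networks are shared.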
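To see why the mean is subtracted, here is a minimal pure-Python sketch of the aggregation for a single state. The function name `dueling_q` and the numbers are hypothetical, chosen only for illustration. Without centering, adding a constant to every advantage and subtracting it from $V$ would give the same $Q$, so $V$ and $A$ would not be identifiable; with centering, a constant shift of the advantages has no effect on $Q$, and the mean of $Q$ over actions equals $V$.

```python
def dueling_q(v, advantages):
    """Combine a scalar state value with mean-centered per-action advantages."""
    mean_a = sum(advantages) / len(advantages)
    return [v + (a - mean_a) for a in advantages]


adv = [1.0, 3.0, 2.0, 0.0]  # hypothetical advantages for 4 actions
q = dueling_q(10.0, adv)    # hypothetical state value V(s) = 10.0
print(q)  # [9.5, 11.5, 10.5, 8.5]

# Shifting every advantage by the same constant leaves Q unchanged...
print(dueling_q(10.0, [a + 5.0 for a in adv]) == q)  # True

# ...and the mean of Q over actions recovers V.
print(sum(q) / len(q))  # 10.0
```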
```python
class Model(nn.Module):
```
```python
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
```
The first convolution layer takes an 84×84 frame with 4 channels and produces a 20×20 frame
```python
            nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4),
            nn.ReLU(),
```
The second convolution layer takes a 20×20 frame and produces a 9×9 frame
```python
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
            nn.ReLU(),
```
The third convolution layer takes a 9×9 frame and produces a 7×7 frame
```python
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
            nn.ReLU(),
        )
```
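The frame sizes quoted above follow from the standard output-size formula for an unpadded convolution with dilation 1, $\lfloor (n - k)/s \rfloor + 1$. The sketch below checks the 84 → 20 → 9 → 7 chain and the resulting flattened feature count; `conv_out` is a hypothetical helper, not part of the model.

```python
def conv_out(size, kernel_size, stride, padding=0):
    """Output width/height of a square Conv2d input (dilation 1)."""
    return (size + 2 * padding - kernel_size) // stride + 1


size = 84
for kernel_size, stride in [(8, 4), (4, 2), (3, 1)]:  # the three conv layers above
    size = conv_out(size, kernel_size, stride)
    print(size)  # 20, then 9, then 7

# 64 channels of 7x7 are flattened before the fully connected layer.
print(size * size * 64)  # 3136
```

Because the flattened size is hard-coded as `7 * 7 * 64` later in the model, inputs of any size other than 84×84 would not fit the fully connected layer.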
A fully connected layer takes the flattened output of the third convolution layer and outputs 512 features
```python
        self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512)
        self.activation = nn.ReLU()
```
This head gives the state value V
```python
        self.state_value = nn.Sequential(
            nn.Linear(in_features=512, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=1),
        )
```
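This head gives the advantage $A$ (named `action_value` in the code), with one output per action; the final layer has 4 outputs, so this model assumes 4 actions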
```python
        self.action_value = nn.Sequential(
            nn.Linear(in_features=512, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=4),
        )
```
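As a rough sense of where the capacity sits, the following sketch counts parameters (weights plus biases) layer by layer for the architecture defined above, using plain arithmetic. The helper names are hypothetical. The shared 3136 → 512 layer holds over 80% of the roughly 1.95 million parameters, while the two heads are comparatively small. With PyTorch available, `sum(p.numel() for p in Model().parameters())` should report the same total.

```python
def conv_params(c_in, c_out, k):
    return c_in * c_out * k * k + c_out  # k x k kernels plus one bias per output channel


def linear_params(n_in, n_out):
    return n_in * n_out + n_out  # weight matrix plus bias vector


counts = {
    "conv": conv_params(4, 32, 8) + conv_params(32, 64, 4) + conv_params(64, 64, 3),
    "lin": linear_params(7 * 7 * 64, 512),
    "state_value": linear_params(512, 256) + linear_params(256, 1),
    "action_value": linear_params(512, 256) + linear_params(256, 4),
}
total = sum(counts.values())

print(counts)
print(total)  # 1948069
print(round(counts["lin"] / total, 3))  # 0.824
```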
```python
    def forward(self, obs: torch.Tensor):
        # obs: a batch of observations with shape [batch_size, 4, 84, 84]
```
Convolution
```python
        h = self.conv(obs)
```
Reshape for linear layers
```python
        # -1 infers the batch dimension; each row holds the 7 * 7 * 64 conv features
        h = h.reshape((-1, 7 * 7 * 64))
```
Linear layer
```python
        h = self.activation(self.lin(h))
```
A
```python
        action_value = self.action_value(h)
```
V
```python
        state_value = self.state_value(h)
```
$$
A(s,a) - \frac{1}{\lvert \mathcal{A} \rvert} \sum_{a' \in \mathcal{A}} A(s,a')
$$
```python
        action_score_centered = action_value - action_value.mean(dim=-1, keepdim=True)
```
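The `keepdim=True` argument matters here. It keeps the mean with shape `[batch_size, 1]`, so broadcasting subtracts each state's own mean from its row of advantages. Without it, the mean has shape `[batch_size]` and broadcasts along the action dimension instead; when the batch size happens to equal the number of actions this runs without error but gives wrong results. The tensor below is a hypothetical batch of 4 states × 4 actions used only to show the difference.

```python
import torch

# Hypothetical advantages: 4 states (rows) x 4 actions (columns), batch size == number of actions.
a = torch.arange(16, dtype=torch.float32).reshape(4, 4)

right = a - a.mean(dim=-1, keepdim=True)  # mean has shape (4, 1): one mean per state
wrong = a - a.mean(dim=-1)                # mean has shape (4,): broadcast across actions instead

print(right.shape)                                         # torch.Size([4, 4])
print(torch.allclose(right.mean(dim=-1), torch.zeros(4)))  # True: every row is zero-mean
print(torch.allclose(wrong.mean(dim=-1), torch.zeros(4)))  # False
```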
$$
Q(s,a) = V(s) + \Big(A(s,a) - \frac{1}{\lvert \mathcal{A} \rvert} \sum_{a' \in \mathcal{A}} A(s,a')\Big)
$$
```python
        q = state_value + action_score_centered

        return q
```
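One consequence of this aggregation: within a single state, $V(s)$ and the mean advantage are the same for every action, so they shift all Q-values equally and the greedy action $\arg\max_a Q(s,a)$ is the same as $\arg\max_a A(s,a)$. The pure-Python sketch below checks this for one hypothetical state; the values and the `argmax` helper are illustrative only.

```python
def argmax(values):
    """Index of the largest value (first one wins on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


v = -4.0                      # hypothetical V(s)
adv = [0.5, -1.0, 2.5, 0.0]   # hypothetical A(s, a) for 4 actions
mean_a = sum(adv) / len(adv)
q = [v + (a - mean_a) for a in adv]

print(q)                       # [-4.0, -5.5, -2.0, -4.5]
print(argmax(q), argmax(adv))  # 2 2
```

The value head still matters for learning, since the targets depend on the magnitude of $Q$, but action selection alone only needs the ordering of the advantages.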