This experiment trains a Deep Q-Network (DQN) to play the Atari Breakout game on OpenAI Gym. It runs the game environments on multiple processes to sample efficiently.
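Each Worker (imported from labml_nn.rl.game below) runs one game instance in a child process and talks to the trainer over a pipe: the trainer sends ("reset", None), ("step", action) and ("close", None) messages to worker.child and receives observations (and, on steps, rewards, done flags and episode info) back. The following is only a minimal sketch of that pattern with a dummy stand-in for the real Game, which actually runs Breakout and produces the stacked 4×84×84 uint8 observations used here:

```python
from multiprocessing import Pipe, Process
from multiprocessing.connection import Connection

import numpy as np


class Game:
    """Dummy stand-in for labml_nn.rl.game.Game: returns random 4x84x84 observations."""
    def __init__(self, seed: int):
        self.rng = np.random.default_rng(seed)

    def reset(self):
        return self.rng.integers(0, 256, (4, 84, 84), dtype=np.uint8)

    def step(self, action):
        obs = self.rng.integers(0, 256, (4, 84, 84), dtype=np.uint8)
        return obs, 0.0, False, None   # (next_obs, reward, done, info)


def worker_process(remote: Connection, seed: int):
    """Child-process loop: serve commands sent over the pipe."""
    game = Game(seed)
    while True:
        cmd, data = remote.recv()
        if cmd == "step":
            remote.send(game.step(data))
        elif cmd == "reset":
            remote.send(game.reset())
        elif cmd == "close":
            remote.close()
            break


class Worker:
    """Runs one game instance in a separate process."""
    def __init__(self, seed: int):
        self.child, parent = Pipe()
        self.process = Process(target=worker_process, args=(parent, seed))
        self.process.start()
```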
```python
import numpy as np
import torch

from labml import tracker, experiment, logger, monit
from labml.internal.configs.dynamic_hyperparam import FloatDynamicHyperParam
from labml_nn.helpers.schedule import Piecewise
from labml_nn.rl.dqn import QFuncLoss
from labml_nn.rl.dqn.model import Model
from labml_nn.rl.dqn.replay_buffer import ReplayBuffer
from labml_nn.rl.game import Worker
```
Select device
```python
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
```
Scale observations from [0, 255] to [0, 1]
```python
def obs_to_torch(obs: np.ndarray) -> torch.Tensor:
    return torch.tensor(obs, dtype=torch.float32, device=device) / 255.
```
```python
class Trainer:
    def __init__(self, *,
                 updates: int, epochs: int,
                 n_workers: int, worker_steps: int, mini_batch_size: int,
                 update_target_model: int,
                 learning_rate: FloatDynamicHyperParam,
                 ):
```
number of workers
```python
        self.n_workers = n_workers
```
steps sampled on each update
```python
        self.worker_steps = worker_steps
```
number of training iterations
```python
        self.train_epochs = epochs
```
number of updates
```python
        self.updates = updates
```
size of mini batch for training
```python
        self.mini_batch_size = mini_batch_size
```
update target network every 250 updates
```python
        self.update_target_model = update_target_model
```
learning rate
```python
        self.learning_rate = learning_rate
```
exploration as a function of updates
```python
        self.exploration_coefficient = Piecewise(
            [
                (0, 1.0),
                (25_000, 0.1),
                (self.updates / 2, 0.01)
            ], outside_value=0.01)
```
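Piecewise (from labml_nn.helpers.schedule) is used here as a schedule over update numbers; judging by how it is configured, it presumably interpolates linearly between the given (update, value) breakpoints and returns outside_value beyond them. A minimal sketch of such a schedule, under that assumption:

```python
class PiecewiseLinear:
    """Linear interpolation between (step, value) breakpoints (illustrative sketch)."""
    def __init__(self, endpoints, outside_value):
        self.endpoints = endpoints
        self.outside_value = outside_value

    def __call__(self, step):
        # Interpolate within the pair of breakpoints that brackets `step`
        for (l_t, l_v), (r_t, r_v) in zip(self.endpoints[:-1], self.endpoints[1:]):
            if l_t <= step < r_t:
                alpha = (step - l_t) / (r_t - l_t)
                return l_v + alpha * (r_v - l_v)
        # Outside all breakpoints
        return self.outside_value


# With the exploration schedule above:
#   update 0      -> 1.0
#   update 12,500 -> 0.55 (halfway between 1.0 and 0.1)
#   update 25,000 -> 0.1, then decaying to 0.01 by updates / 2
```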
β for replay buffer as a function of updates
```python
        self.prioritized_replay_beta = Piecewise(
            [
                (0, 0.4),
                (self.updates, 1)
            ], outside_value=1)
```
Replay buffer with α=0.6. Capacity of the replay buffer must be a power of 2.
```python
        self.replay_buffer = ReplayBuffer(2 ** 14, 0.6)
```
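The power-of-two capacity requirement is typical of prioritized replay buffers that keep the priorities in binary segment trees (a sum tree for proportional sampling and a min tree for the importance-sampling weights): with a power-of-two number of leaves, the tree fits a flat array of size 2 × capacity. This is an illustration of the sum-tree idea, not the actual ReplayBuffer implementation:

```python
import numpy as np

capacity = 2 ** 14             # number of leaves; must be a power of two
tree = np.zeros(2 * capacity)  # internal nodes at [1, capacity), leaves at [capacity, 2 * capacity)


def set_priority(i: int, p_alpha: float):
    """Set the priority of leaf i and update the sums on the path to the root."""
    idx = capacity + i
    tree[idx] = p_alpha
    idx //= 2
    while idx >= 1:
        tree[idx] = tree[2 * idx] + tree[2 * idx + 1]
        idx //= 2


def find(prefix_sum: float) -> int:
    """Find the leaf where the running sum of priorities reaches `prefix_sum`,
    i.e. sample an index with probability proportional to its priority."""
    idx = 1
    while idx < capacity:
        if tree[2 * idx] >= prefix_sum:
            idx = 2 * idx                 # go left
        else:
            prefix_sum -= tree[2 * idx]
            idx = 2 * idx + 1             # go right
    return idx - capacity


# To sample: find(np.random.uniform(0, tree[1]))  # tree[1] holds the total priority
```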
Model for sampling and training
```python
        self.model = Model().to(device)
```
target model to get $Q(s'; \theta_i^-)$
```python
        self.target_model = Model().to(device)
```
create workers
```python
        self.workers = [Worker(47 + i) for i in range(self.n_workers)]
```
initialize tensors for observations
```python
        self.obs = np.zeros((self.n_workers, 4, 84, 84), dtype=np.uint8)
```
reset the workers
```python
        for worker in self.workers:
            worker.child.send(("reset", None))
```
get the initial observations
```python
        for i, worker in enumerate(self.workers):
            self.obs[i] = worker.child.recv()
```
loss function with discount factor γ = 0.99
```python
        self.loss_func = QFuncLoss(0.99)
```
optimizer (the initial learning rate is overwritten on each training step by the dynamic learning_rate hyper-parameter)
```python
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=2.5e-4)
```
When sampling actions we use an ϵ-greedy strategy, where we take a greedy action with probability 1−ϵ and take a random action with probability ϵ. We refer to ϵ as exploration_coefficient.
```python
    def _sample_action(self, q_value: torch.Tensor, exploration_coefficient: float):
```
Sampling doesn't need gradients
```python
        with torch.no_grad():
```
Sample the action with highest Q-value. This is the greedy action.
```python
            greedy_action = torch.argmax(q_value, dim=-1)
```
Uniformly sample an action
```python
            random_action = torch.randint(q_value.shape[-1], greedy_action.shape, device=q_value.device)
```
Whether to choose the greedy action or the random action
```python
            is_choose_rand = torch.rand(greedy_action.shape, device=q_value.device) < exploration_coefficient
```
Pick the action based on is_choose_rand
```python
            return torch.where(is_choose_rand, random_action, greedy_action).cpu().numpy()
```
```python
    def sample(self, exploration_coefficient: float):
```
This doesn't need gradients
```python
        with torch.no_grad():
```
Sample worker_steps
```python
            for t in range(self.worker_steps):
```
Get Q-values for the current observation
```python
                q_value = self.model(obs_to_torch(self.obs))
```
Sample actions
```python
                actions = self._sample_action(q_value, exploration_coefficient)
```
Run sampled actions on each worker
```python
                for w, worker in enumerate(self.workers):
                    worker.child.send(("step", actions[w]))
```
Collect information from each worker
```python
                for w, worker in enumerate(self.workers):
```
Get results after executing the actions
```python
                    next_obs, reward, done, info = worker.child.recv()
```
Add transition to replay buffer
```python
                    self.replay_buffer.add(self.obs[w], actions[w], reward, next_obs, done)
```
Update episode information. Episode info is available when an episode finishes; it includes the total reward and the length of the episode. Look at Game to see how it works.
```python
                    if info:
                        tracker.add('reward', info['reward'])
                        tracker.add('length', info['length'])
```
update current observation
```python
                    self.obs[w] = next_obs
```
```python
    def train(self, beta: float):
        for _ in range(self.train_epochs):
```
Sample from the prioritized replay buffer
```python
            samples = self.replay_buffer.sample(self.mini_batch_size, beta)
```
Get the predicted Q-value
```python
            q_value = self.model(obs_to_torch(samples['obs']))
```
Get the Q-values of the next state for Double Q-learning. Gradients shouldn't propagate for these
```python
            with torch.no_grad():
```
Get $Q(s'; \theta_i)$
```python
                double_q_value = self.model(obs_to_torch(samples['next_obs']))
```
Get $Q(s'; \theta_i^-)$
```python
                target_q_value = self.target_model(obs_to_torch(samples['next_obs']))
```
Compute Temporal Difference (TD) errors, $\delta$, and the loss, $\mathcal{L}(\theta)$.
```python
            td_errors, loss = self.loss_func(q_value,
                                             q_value.new_tensor(samples['action']),
                                             double_q_value, target_q_value,
                                             q_value.new_tensor(samples['done']),
                                             q_value.new_tensor(samples['reward']),
                                             q_value.new_tensor(samples['weights']))
```
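For reference, these inputs combine into the standard Double DQN objective with prioritized-replay importance weights; for a sampled transition $(s, a, r, s', d)$ with discount $\gamma = 0.99$ and per-sample weight $w$, the computation is roughly

$$a^\star = \arg\max_{a'} Q(s', a'; \theta_i), \qquad y = r + \gamma \, (1 - d) \, Q\big(s', a^\star; \theta_i^-\big),$$
$$\delta = Q(s, a; \theta_i) - y, \qquad \mathcal{L}(\theta_i) = \sum_{\text{samples}} w \, \mathcal{H}(\delta),$$

where $\mathcal{H}$ is a per-sample loss (commonly the Huber loss); see QFuncLoss for the exact definition.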
Calculate priorities for the replay buffer, $p_i = |\delta_i| + \epsilon$
```python
            new_priorities = np.abs(td_errors.cpu().numpy()) + 1e-6
```
Update replay buffer priorities
```python
            self.replay_buffer.update_priorities(samples['indexes'], new_priorities)
```
Set learning rate
```python
            for pg in self.optimizer.param_groups:
                pg['lr'] = self.learning_rate()
```
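Writing the learning rate into the optimizer on every iteration is what makes it tunable while the run is in progress: a FloatDynamicHyperParam is callable and returns its current value, which, being a dynamic hyper-parameter, can presumably be adjusted externally (e.g. from the labml app) without restarting training. A small illustration of that interface, using the same arguments as the configuration below:

```python
# Illustration only: default value 1e-4, with the second argument presumably the allowed range
lr = FloatDynamicHyperParam(1e-4, (0, 1e-3))
print(lr())  # calling it returns the current value: 1e-4 until it is changed externally
```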
Zero out the previously calculated gradients
```python
            self.optimizer.zero_grad()
```
Calculate gradients
```python
            loss.backward()
```
Clip gradients
```python
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)
```
Update parameters based on gradients
```python
            self.optimizer.step()

    def run_training_loop(self):
```
Last 100 episode information
```python
        tracker.set_queue('reward', 100, True)
        tracker.set_queue('length', 100, True)
```
Copy to target network initially
```python
        self.target_model.load_state_dict(self.model.state_dict())

        for update in monit.loop(self.updates):
```
ϵ, the exploration coefficient
```python
            exploration = self.exploration_coefficient(update)
            tracker.add('exploration', exploration)
```
β for prioritized replay
```python
            beta = self.prioritized_replay_beta(update)
            tracker.add('beta', beta)
```
Sample with current policy
```python
            self.sample(exploration)
```
Start training after the buffer is full
```python
            if self.replay_buffer.is_full():
```
Train the model
```python
                self.train(beta)
```
Periodically update target network
```python
            if update % self.update_target_model == 0:
                self.target_model.load_state_dict(self.model.state_dict())
```
Save tracked indicators.
```python
            tracker.save()
```
Add a new line to the screen periodically
```python
            if (update + 1) % 1_000 == 0:
                logger.log()
```
Stop the workers
```python
    def destroy(self):
        for worker in self.workers:
            worker.child.send(("close", None))
```
```python
def main():
```
Create the experiment
```python
    experiment.create(name='dqn')
```
Configurations
```python
    configs = {
```
Number of updates
```python
        'updates': 1_000_000,
```
Number of epochs to train the model with sampled data.
```python
        'epochs': 8,
```
Number of worker processes
```python
        'n_workers': 8,
```
Number of steps to run on each process for a single update
```python
        'worker_steps': 4,
```
Mini batch size
```python
        'mini_batch_size': 32,
```
Target model updating interval
```python
        'update_target_model': 250,
```
Learning rate.
```python
        'learning_rate': FloatDynamicHyperParam(1e-4, (0, 1e-3)),
    }
```
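With these defaults, each update collects n_workers × worker_steps = 8 × 4 = 32 new transitions and trains on epochs × mini_batch_size = 8 × 32 = 256 prioritized samples, so each transition is replayed roughly 8 times on average. The 2¹⁴ = 16,384-slot replay buffer therefore holds about the last 512 updates' worth of experience, and training only begins once it has filled.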
Set the configurations
```python
    experiment.configs(configs)
```
Initialize the trainer
```python
    m = Trainer(**configs)
```
Run and monitor the experiment
```python
    with experiment.start():
        m.run_training_loop()
```
Stop the workers
```python
    m.destroy()


if __name__ == "__main__":
    main()
```