
#

DQN Experiment with Atari Breakout

This experiment trains a Deep Q Network (DQN) to play the Atari Breakout game on OpenAI Gym. It runs the game environments in multiple worker processes to sample efficiently.
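The Worker class (imported below from labml_nn.rl.game) runs each game environment in a separate process and talks to the trainer over a pipe. A minimal sketch of that pattern, assuming a hypothetical make_env(seed) helper that builds the wrapped Breakout environment (this is an illustration, not the actual Worker implementation):

import multiprocessing


def worker_process(remote, seed: int):
    # Each process owns its own environment instance
    env = make_env(seed)  # hypothetical helper that builds the wrapped Breakout environment
    while True:
        cmd, data = remote.recv()
        if cmd == "reset":
            remote.send(env.reset())
        elif cmd == "step":
            # env.step returns (next_obs, reward, done, info)
            remote.send(env.step(data))
        elif cmd == "close":
            remote.close()
            break


class Worker:
    def __init__(self, seed: int):
        # `child` is the end the trainer uses, mirroring worker.child.send(...) below
        self.child, parent = multiprocessing.Pipe()
        self.process = multiprocessing.Process(target=worker_process, args=(parent, seed))
        self.process.start()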

import numpy as np
import torch

from labml import tracker, experiment, logger, monit
from labml.internal.configs.dynamic_hyperparam import FloatDynamicHyperParam
from labml_nn.helpers.schedule import Piecewise
from labml_nn.rl.dqn import QFuncLoss
from labml_nn.rl.dqn.model import Model
from labml_nn.rl.dqn.replay_buffer import ReplayBuffer
from labml_nn.rl.game import Worker

#

Select device

if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

#

Scale observations from [0, 255] to [0, 1]

def obs_to_torch(obs: np.ndarray) -> torch.Tensor:

#

    return torch.tensor(obs, dtype=torch.float32, device=device) / 255.

#

Trainer

class Trainer:

#

    def __init__(self, *,
                 updates: int, epochs: int,
                 n_workers: int, worker_steps: int, mini_batch_size: int,
                 update_target_model: int,
                 learning_rate: FloatDynamicHyperParam,
                 ):

#

number of workers

        self.n_workers = n_workers

#

steps sampled on each update

        self.worker_steps = worker_steps

#

number of training epochs per update

        self.train_epochs = epochs

#

number of updates

        self.updates = updates

#

size of mini batch for training

        self.mini_batch_size = mini_batch_size

#

update the target network every 250 updates

        self.update_target_model = update_target_model

#

learning rate

        self.learning_rate = learning_rate

#

exploration as a function of updates

        self.exploration_coefficient = Piecewise(
            [
                (0, 1.0),
                (25_000, 0.1),
                (self.updates / 2, 0.01)
            ], outside_value=0.01)
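Assuming Piecewise interpolates linearly between the given (update, value) pairs and returns outside_value beyond them (the way such schedules typically work), a rough standalone sketch of the schedule, for illustration only:

def piecewise(x, points, outside_value):
    # points: sorted list of (x, value) pairs
    for (x0, y0), (x1, y1) in zip(points[:-1], points[1:]):
        if x0 <= x < x1:
            # Linear interpolation between the two surrounding points
            return y0 + (y1 - y0) * (x - x0) / (x1 - x0)
    return outside_value

# e.g. with updates = 1,000,000, the exploration coefficient at update 12,500 is
# piecewise(12_500, [(0, 1.0), (25_000, 0.1), (500_000, 0.01)], 0.01)  # ≈ 0.55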

#

β for replay buffer as a function of updates

        self.prioritized_replay_beta = Piecewise(
            [
                (0, 0.4),
                (self.updates, 1)
            ], outside_value=1)

#

Replay buffer with α=0.6. Capacity of the replay buffer must be a power of 2.

        self.replay_buffer = ReplayBuffer(2 ** 14, 0.6)
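With α=0.6, transition i is sampled with probability proportional to its priority, and β (annealed above) scales the importance-sampling weights used in the loss, as in prioritized experience replay:

$$P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}, \qquad w_i = \left(\frac{1}{N}\cdot\frac{1}{P(i)}\right)^{\beta}$$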

#

Model for sampling and training

        self.model = Model().to(device)

#

Target model to get $Q(s'; \theta_i^-)$

        self.target_model = Model().to(device)

#

create workers

        self.workers = [Worker(47 + i) for i in range(self.n_workers)]

#

initialize tensors for observations

        self.obs = np.zeros((self.n_workers, 4, 84, 84), dtype=np.uint8)

#

reset the workers

        for worker in self.workers:
            worker.child.send(("reset", None))

#

get the initial observations

        for i, worker in enumerate(self.workers):
            self.obs[i] = worker.child.recv()

#

loss function

        self.loss_func = QFuncLoss(0.99)

#

optimizer

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=2.5e-4)

#

ϵ-greedy Sampling

When sampling actions we use an ϵ-greedy strategy, where we take the greedy action with probability 1−ϵ and a random action with probability ϵ. We refer to ϵ as exploration_coefficient.
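That is, the sampled action is

$$a = \begin{cases} \underset{a'}{\arg\max}\; Q(s, a'; \theta) & \text{with probability } 1-\epsilon \\ \text{a uniformly random action} & \text{with probability } \epsilon \end{cases}$$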

    def _sample_action(self, q_value: torch.Tensor, exploration_coefficient: float):

#

Sampling doesn't need gradients

        with torch.no_grad():

#

Sample the action with highest Q-value. This is the greedy action.

            greedy_action = torch.argmax(q_value, dim=-1)

#

Uniformly sample an action

            random_action = torch.randint(q_value.shape[-1], greedy_action.shape, device=q_value.device)

#

Whether to choose the greedy action or the random action

            is_choose_rand = torch.rand(greedy_action.shape, device=q_value.device) < exploration_coefficient

#

Pick the action based on is_choose_rand

            return torch.where(is_choose_rand, random_action, greedy_action).cpu().numpy()

#

Sample data

    def sample(self, exploration_coefficient: float):

#

This doesn't need gradients

        with torch.no_grad():

#

Sample worker_steps

            for t in range(self.worker_steps):

#

Get Q-values for the current observation

                q_value = self.model(obs_to_torch(self.obs))

#

Sample actions

                actions = self._sample_action(q_value, exploration_coefficient)

#

Run sampled actions on each worker

                for w, worker in enumerate(self.workers):
                    worker.child.send(("step", actions[w]))

#

Collect information from each worker

                for w, worker in enumerate(self.workers):

#

Get results after executing the actions

                    next_obs, reward, done, info = worker.child.recv()

#

Add transition to replay buffer

                    self.replay_buffer.add(self.obs[w], actions[w], reward, next_obs, done)

#

Update episode information. Episode info is available once an episode finishes; it includes the total reward and the length of the episode. Look at Game to see how it works.

                    if info:
                        tracker.add('reward', info['reward'])
                        tracker.add('length', info['length'])

#

update current observation

                    self.obs[w] = next_obs

#

Train the model

    def train(self, beta: float):

#

        for _ in range(self.train_epochs):

#

Sample from priority replay buffer

            samples = self.replay_buffer.sample(self.mini_batch_size, beta)

#

Get the predicted Q-value

            q_value = self.model(obs_to_torch(samples['obs']))

#

Get the Q-values of the next state for double Q-learning. Gradients shouldn't propagate through these.

            with torch.no_grad():

#

Get $Q(s'; \theta_i)$

                double_q_value = self.model(obs_to_torch(samples['next_obs']))

#

Get $Q(s'; \theta_i^-)$

                target_q_value = self.target_model(obs_to_torch(samples['next_obs']))

#

Compute Temporal Difference (TD) errors, δ, and the loss, L(θ).

            td_errors, loss = self.loss_func(q_value,
                                             q_value.new_tensor(samples['action']),
                                             double_q_value, target_q_value,
                                             q_value.new_tensor(samples['done']),
                                             q_value.new_tensor(samples['reward']),
                                             q_value.new_tensor(samples['weights']))
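Roughly, QFuncLoss forms the double Q-learning target by choosing the next action with the online network and evaluating it with the target network,

$$y_i = r_i + \gamma\,(1 - d_i)\, Q\!\Big(s'_i,\ \underset{a'}{\arg\max}\, Q(s'_i, a'; \theta_i);\ \theta_i^-\Big),$$

and the loss is the Huber loss between $Q(s_i, a_i; \theta_i)$ and $y_i$, weighted by the importance-sampling weights; the per-sample TD errors $\delta_i$ are returned so we can update the priorities below.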

#

Calculate priorities for the replay buffer, $p_i = |\delta_i| + \epsilon$

            new_priorities = np.abs(td_errors.cpu().numpy()) + 1e-6

#

Update replay buffer priorities

            self.replay_buffer.update_priorities(samples['indexes'], new_priorities)

#

Set learning rate

            for pg in self.optimizer.param_groups:
                pg['lr'] = self.learning_rate()

#

Zero out the previously calculated gradients

            self.optimizer.zero_grad()

#

Calculate gradients

            loss.backward()

#

Clip gradients

            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)

#

Update parameters based on gradients

            self.optimizer.step()

#

Run training loop

    def run_training_loop(self):

#

Track statistics over the last 100 episodes

        tracker.set_queue('reward', 100, True)
        tracker.set_queue('length', 100, True)

#

Copy to target network initially

        self.target_model.load_state_dict(self.model.state_dict())

        for update in monit.loop(self.updates):

#

ϵ, the exploration coefficient

            exploration = self.exploration_coefficient(update)
            tracker.add('exploration', exploration)

#

β for prioritized replay

            beta = self.prioritized_replay_beta(update)
            tracker.add('beta', beta)

#

Sample with current policy

            self.sample(exploration)

#

Start training after the buffer is full

            if self.replay_buffer.is_full():

#

Train the model

                self.train(beta)

#

Periodically update target network

                if update % self.update_target_model == 0:
                    self.target_model.load_state_dict(self.model.state_dict())

#

Save tracked indicators.

            tracker.save()

#

Add a new line to the screen periodically

            if (update + 1) % 1_000 == 0:
                logger.log()

#

Destroy

Stop the workers

    def destroy(self):

#

        for worker in self.workers:
            worker.child.send(("close", None))

#

def main():

#

Create the experiment

    experiment.create(name='dqn')

#

Configurations

    configs = {

#

Number of updates

        'updates': 1_000_000,

#

Number of epochs to train the model with sampled data.

        'epochs': 8,

#

Number of worker processes

        'n_workers': 8,

#

Number of steps to run on each process for a single update

        'worker_steps': 4,

#

Mini batch size

        'mini_batch_size': 32,

#

Target model updating interval

        'update_target_model': 250,

#

Learning rate.

        'learning_rate': FloatDynamicHyperParam(1e-4, (0, 1e-3)),
    }

#

Configurations

    experiment.configs(configs)

#

Initialize the trainer

    m = Trainer(**configs)

#

Run and monitor the experiment

    with experiment.start():
        m.run_training_loop()

#

Stop the workers

    m.destroy()

#

Run it

if __name__ == "__main__":
    main()
