Prioritized Experience Replay Buffer

This implements the paper Prioritized Experience Replay, using a binary segment tree.

    import random

    import numpy as np

Buffer for Prioritized Experience Replay

Prioritized experience replay samples important transitions more frequently. The transitions are prioritized by the temporal difference error (TD error), $\delta$.

We sample transition $i$ with probability $$P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}$$ where $\alpha$ is a hyperparameter that determines how much prioritization is used, with $\alpha = 0$ corresponding to the uniform case. $p_i$ is the priority.

We use proportional prioritization $p_i = |\delta_i| + \epsilon$ where $\delta_i$ is the temporal difference error for transition $i$.

We correct the bias introduced by prioritized replay using importance-sampling (IS) weights $$w_i = \bigg(\frac{1}{N} \frac{1}{P(i)}\bigg)^\beta$$ in the loss function. This fully compensates when $\beta = 1$. We normalize weights by $\frac{1}{\max_i w_i}$ for stability. Unbiased updates matter most near convergence at the end of training. Therefore we increase $\beta$ towards the end of training.
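
The two formulas above can be checked numerically. The following is a minimal NumPy sketch, not part of the original implementation: the helper names `sampling_probabilities` and `importance_weights`, the example TD errors, and the values of `alpha`, `beta`, and `eps` are all assumptions chosen for illustration.

```python
import numpy as np


def sampling_probabilities(priorities, alpha):
    """P(i) = p_i^alpha / sum_k p_k^alpha (hypothetical helper)."""
    scaled = np.asarray(priorities, dtype=np.float64) ** alpha
    return scaled / scaled.sum()


def importance_weights(probs, beta):
    """w_i = (1/N * 1/P(i))^beta, divided by max_i w_i (hypothetical helper)."""
    n = len(probs)
    weights = (n * probs) ** (-beta)
    return weights / weights.max()


# Illustrative TD errors; eps keeps every priority strictly positive
td_errors = np.array([0.5, -2.0, 0.1, 1.0])
eps = 1e-6
priorities = np.abs(td_errors) + eps

probs = sampling_probabilities(priorities, alpha=0.6)
weights = importance_weights(probs, beta=0.4)

# With alpha = 0 every transition is equally likely
uniform = sampling_probabilities(priorities, alpha=0.0)
```

The transition with the largest $|\delta_i|$ gets the highest sampling probability and the smallest weight, while the least likely transition gets a normalized weight of exactly 1.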

Binary Segment Tree

We use a binary segment tree to efficiently calculate $\sum_{k=1}^{i} p_k^\alpha$, the cumulative (unnormalized) probability, which is needed to sample. We also use a binary segment tree to find $\min_i p_i^\alpha$, which is needed for $\frac{1}{\max_i w_i}$. We can also use a min-heap for this. A binary segment tree lets us calculate these in $\mathcal{O}(\log n)$ time, which is far more efficient than the naive $\mathcal{O}(n)$ approach.

This is how a binary segment tree works for sum; it is similar for minimum. Let $x_k$ be the list of $N$ values we want to represent. Let $b_{i,j}$ be the $j$-th node of the $i$-th row in the binary tree, with rows counted from $i = 1$ at the root and nodes counted from $j = 0$ within each row. That is, the two children of node $b_{i,j}$ are $b_{i+1,2j}$ and $b_{i+1,2j+1}$.

The leaf nodes on row $D = \left\lceil 1 + \log_2 N \right\rceil$ will have the values of $x$. Every node keeps the sum of its two child nodes. That is, the root node keeps the sum of the entire array of values. The left and right children of the root node keep the sum of the first half of the array and the sum of the second half of the array, respectively. And so on...

$$b_{i,j} = \sum_{k = j \cdot 2^{D-i} + 1}^{(j+1) \cdot 2^{D-i}} x_k$$

The number of nodes in row $i$ is $$N_i = 2^{i-1}$$ This is one more than the total number of nodes in all rows above $i$. So we can use a single array $a$ to store the tree, where $$b_{i,j} \rightarrow a_{N_i + j}$$

Then the child nodes of $a_i$ are $a_{2i}$ and $a_{2i+1}$. That is, $$a_i = a_{2i} + a_{2i+1}$$

This way of maintaining binary trees is very easy to program. Note that we are indexing starting from 1.

We use the same structure to compute the minimum.
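
The array layout described above can be sketched in a few lines. This is an illustrative, standalone version rather than the class's own code: the function names `build_sum_tree` and `update` are hypothetical, and the sketch assumes the number of values is a power of 2.

```python
def build_sum_tree(values):
    """Store a sum tree in one flat list, indexed from 1 (index 0 is unused).

    Assumes len(values) is a power of 2.
    """
    n = len(values)
    tree = [0.0] * (2 * n)
    # Leaves occupy positions n .. 2n - 1
    tree[n:] = values
    # Each internal node is the sum of its two children
    for i in range(n - 1, 0, -1):
        tree[i] = tree[2 * i] + tree[2 * i + 1]
    return tree


def update(tree, idx, value):
    """Set leaf `idx` (0-based) and refresh all of its ancestors."""
    n = len(tree) // 2
    i = idx + n
    tree[i] = value
    while i >= 2:
        i //= 2
        tree[i] = tree[2 * i] + tree[2 * i + 1]


tree = build_sum_tree([1.0, 2.0, 3.0, 4.0])
root_before = tree[1]  # sum of all four values

update(tree, 2, 10.0)  # change the third value from 3.0 to 10.0
root_after = tree[1]
```

An update touches only the leaf and its ancestors, so its cost grows with the height of the tree rather than with the number of values.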

    class ReplayBuffer:

Initialize

        def __init__(self, capacity, alpha):

We use a power of 2 for capacity because it simplifies the code and debugging

            self.capacity = capacity

$\alpha$

            self.alpha = alpha

Maintain binary segment trees to take the sum and find the minimum over a range

            self.priority_sum = [0 for _ in range(2 * self.capacity)]
            self.priority_min = [float('inf') for _ in range(2 * self.capacity)]

Current max priority, $p$, to be assigned to new transitions

            self.max_priority = 1.

Arrays for buffer (`np.bool_` is used because the `np.bool` alias is not available in every NumPy release)

            self.data = {
                'obs': np.zeros(shape=(capacity, 4, 84, 84), dtype=np.uint8),
                'action': np.zeros(shape=capacity, dtype=np.int32),
                'reward': np.zeros(shape=capacity, dtype=np.float32),
                'next_obs': np.zeros(shape=(capacity, 4, 84, 84), dtype=np.uint8),
                'done': np.zeros(shape=capacity, dtype=np.bool_)
            }

We use cyclic buffers to store data, and `next_idx` keeps the index of the next empty slot

            self.next_idx = 0

Size of the buffer

            self.size = 0

Add sample to queue

        def add(self, obs, action, reward, next_obs, done):

Get next available slot

            idx = self.next_idx

Store in the queue

            self.data['obs'][idx] = obs
            self.data['action'][idx] = action
            self.data['reward'][idx] = reward
            self.data['next_obs'][idx] = next_obs
            self.data['done'][idx] = done

Increment next available slot

            self.next_idx = (idx + 1) % self.capacity

Calculate the size

            self.size = min(self.capacity, self.size + 1)

$p_i^\alpha$, new samples get `max_priority`

            priority_alpha = self.max_priority ** self.alpha

Update the two segment trees for sum and minimum

            self._set_priority_min(idx, priority_alpha)
            self._set_priority_sum(idx, priority_alpha)

Set priority in binary segment tree for minimum

        def _set_priority_min(self, idx, priority_alpha):

Leaf of the binary tree

            idx += self.capacity
            self.priority_min[idx] = priority_alpha

Update tree, by traversing along ancestors. Continue until the root of the tree.

            while idx >= 2:

Get the index of the parent node

                idx //= 2

Value of the parent node is the minimum of its two children

                self.priority_min[idx] = min(self.priority_min[2 * idx], self.priority_min[2 * idx + 1])

Set priority in binary segment tree for sum

        def _set_priority_sum(self, idx, priority):

Leaf of the binary tree

            idx += self.capacity

Set the priority at the leaf

            self.priority_sum[idx] = priority

Update tree, by traversing along ancestors. Continue until the root of the tree.

            while idx >= 2:

Get the index of the parent node

                idx //= 2

Value of the parent node is the sum of its two children

                self.priority_sum[idx] = self.priority_sum[2 * idx] + self.priority_sum[2 * idx + 1]

$\sum_k p_k^\alpha$

        def _sum(self):

The root node keeps the sum of all values

            return self.priority_sum[1]

$\min_k p_k^\alpha$

        def _min(self):

The root node keeps the minimum of all values

            return self.priority_min[1]

Find largest $i$ such that $\sum_{k=1}^{i} p_k^\alpha \le P$

        def find_prefix_sum_idx(self, prefix_sum):

Start from the root

            idx = 1
            while idx < self.capacity:

If the sum of the left branch is higher than the required sum

                if self.priority_sum[idx * 2] > prefix_sum:

Go to the left branch of the tree

                    idx = 2 * idx
                else:

Otherwise go to the right branch and subtract the sum of the left branch from the required sum

                    prefix_sum -= self.priority_sum[idx * 2]
                    idx = 2 * idx + 1

We are at the leaf node. Subtract the capacity from the tree index to get the index of the actual value

            return idx - self.capacity
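
To see the descent in isolation, here is a standalone sketch of the same search on a hand-built tree. It is an illustration under assumptions, not the class method itself: the module-level function takes the tree as an argument, and the tree values, the seed, and the number of draws are made up for the example.

```python
import random


def find_prefix_sum_idx(tree, prefix_sum):
    """Walk down a flat sum tree (indexed from 1) to the leaf for `prefix_sum`."""
    capacity = len(tree) // 2
    idx = 1
    while idx < capacity:
        left = tree[2 * idx]
        if left > prefix_sum:
            # The target lies inside the left subtree
            idx = 2 * idx
        else:
            # Skip the left subtree and look for the remainder on the right
            prefix_sum -= left
            idx = 2 * idx + 1
    return idx - capacity


# Flat sum tree for the values [1, 2, 3, 4]; index 0 is unused
tree = [0.0, 10.0, 3.0, 7.0, 1.0, 2.0, 3.0, 4.0]

# Draw uniformly in [0, total) and count which leaf each draw lands on
random.seed(0)
counts = [0, 0, 0, 0]
for _ in range(10_000):
    counts[find_prefix_sum_idx(tree, random.random() * tree[1])] += 1
```

Each leaf owns an interval of the cumulative sum whose width equals its value, so leaves are hit in proportion to their values.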

Sample from buffer

        def sample(self, batch_size, beta):

Initialize samples

            samples = {
                'weights': np.zeros(shape=batch_size, dtype=np.float32),
                'indexes': np.zeros(shape=batch_size, dtype=np.int32)
            }

Get sample indexes

            for i in range(batch_size):
                p = random.random() * self._sum()
                idx = self.find_prefix_sum_idx(p)
                samples['indexes'][i] = idx

$\min_i P(i) = \frac{\min_i p_i^\alpha}{\sum_k p_k^\alpha}$

            prob_min = self._min() / self._sum()

$\max_i w_i = \bigg(\frac{1}{N} \frac{1}{\min_i P(i)}\bigg)^\beta$

            max_weight = (prob_min * self.size) ** (-beta)

            for i in range(batch_size):
                idx = samples['indexes'][i]

$P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}$

                prob = self.priority_sum[idx + self.capacity] / self._sum()

$w_i = \bigg(\frac{1}{N} \frac{1}{P(i)}\bigg)^\beta$

                weight = (prob * self.size) ** (-beta)

Normalize by $\frac{1}{\max_i w_i}$, which also cancels off the $\frac{1}{N}$ term

                samples['weights'][i] = weight / max_weight

Get samples data

            for k, v in self.data.items():
                samples[k] = v[samples['indexes']]

            return samples
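
The returned `weights` are meant to scale each transition's contribution to the loss, and $\beta$ is raised as training progresses. The sketch below shows one possible way to do both. It is an assumption-laden illustration: the squared-error loss, the linear schedule, the function names, and the default values are not taken from this implementation.

```python
import numpy as np


def weighted_td_loss(td_errors, weights):
    """Mean of w_i * delta_i^2 (squared error is an assumption for this sketch)."""
    td_errors = np.asarray(td_errors, dtype=np.float64)
    weights = np.asarray(weights, dtype=np.float64)
    return float(np.mean(weights * td_errors ** 2))


def linear_beta(step, total_steps, beta_start=0.4, beta_end=1.0):
    """Hypothetical schedule: move beta linearly from beta_start to beta_end."""
    frac = min(1.0, step / total_steps)
    return beta_start + frac * (beta_end - beta_start)


# Two transitions: the second is down-weighted by its importance-sampling weight
loss = weighted_td_loss(td_errors=[1.0, 2.0], weights=[1.0, 0.5])

# Beta at the start, middle, and end of a 100-step run
betas = [linear_beta(step, total_steps=100) for step in (0, 50, 100)]
```

In a training loop under these assumptions, `linear_beta(step, total_steps)` would be passed as the `beta` argument of `sample`.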

Update priorities

        def update_priorities(self, indexes, priorities):

            for idx, priority in zip(indexes, priorities):

Set current max priority

                self.max_priority = max(self.max_priority, priority)

Calculate $p_i^\alpha$

                priority_alpha = priority ** self.alpha

Update the trees

                self._set_priority_min(idx, priority_alpha)
                self._set_priority_sum(idx, priority_alpha)
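
The `priorities` argument is expected to hold the new $p_i$ values, which under proportional prioritization are $|\delta_i| + \epsilon$. A minimal sketch of that conversion follows; the helper name `priorities_from_td_errors` and the value of `eps` are assumptions for illustration.

```python
import numpy as np


def priorities_from_td_errors(td_errors, eps=1e-6):
    """p_i = |delta_i| + eps, so no transition ends up with zero priority."""
    return np.abs(np.asarray(td_errors, dtype=np.float64)) + eps


# Illustrative TD errors for a batch of three sampled transitions
td_errors = np.array([-0.5, 0.0, 2.0])
new_priorities = priorities_from_td_errors(td_errors)
```

The result could then be passed together with the sampled `indexes` to `update_priorities`. The small `eps` keeps a transition with zero TD error from becoming impossible to sample.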

Whether the buffer is full

        def is_full(self):

            return self.capacity == self.size

labml.ai