Back to Annotated Deep Learning Paper Implementations

Generalized Advantage Estimation (GAE)

This is an implementation of the paper Generalized Advantage Estimation. The advantage computation itself is written with NumPy.

You can find an experiment that uses it here.

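```python
import numpy as np


class GAE:
    def __init__(self, n_workers: int, worker_steps: int, gamma: float, lambda_: float):
        # GAE parameter lambda, which trades bias for variance
        self.lambda_ = lambda_
        # Discount factor gamma
        self.gamma = gamma
        # Number of steps each worker collects per rollout
        self.worker_steps = worker_steps
        # Number of parallel workers (one row of the advantage table each)
        self.n_workers = n_workers
```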

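Calculate advantages

The $k$-step advantage estimators are

$$
\begin{aligned}
\hat{A}_t^{(1)} &= r_t + \gamma V(s_{t+1}) - V(s_t) \\
\hat{A}_t^{(2)} &= r_t + \gamma r_{t+1} + \gamma^2 V(s_{t+2}) - V(s_t) \\
&\;\;\vdots \\
\hat{A}_t^{(\infty)} &= r_t + \gamma r_{t+1} + \gamma^2 r_{t+2} + \dots - V(s_t)
\end{aligned}
$$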
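As a minimal sketch of these estimators for a single trajectory, the hypothetical helper `k_step_advantage` below assumes 1-D NumPy arrays, `rewards` of length $T$ and `values` of length $T + 1$ (the last entry standing in for a bootstrap value), no episode boundaries, and $t + k \le T$.

```python
import numpy as np


def k_step_advantage(rewards: np.ndarray, values: np.ndarray, t: int, k: int, gamma: float) -> float:
    """Hypothetical helper: the k-step advantage estimate A_t^(k) for one trajectory.

    rewards has length T; values has length T + 1 (values[T] is a bootstrap value).
    Assumes no episode ends inside the trajectory and t + k <= T.
    """
    T = len(rewards)
    assert len(values) == T + 1 and t + k <= T
    # Discounted sum of the k sampled rewards r_t, ..., r_{t+k-1}
    discounts = gamma ** np.arange(k)
    n_step_return = float(np.sum(discounts * rewards[t:t + k]))
    # Bootstrap with gamma^k V(s_{t+k}), then subtract the baseline V(s_t)
    return n_step_return + gamma ** k * float(values[t + k]) - float(values[t])


rewards = np.array([1.0, 0.0, 2.0])
values = np.array([0.5, 0.4, 0.3, 0.2])  # V(s_0), ..., V(s_3)
for k in (1, 2, 3):
    print(k, k_step_advantage(rewards, values, t=0, k=k, gamma=0.9))
```

Here $\hat{A}_0^{(1)} = 1 + 0.9 \cdot 0.4 - 0.5 = 0.86$, which is exactly the TD residual $\delta_0$; larger $k$ replaces more of the bootstrapped estimate with sampled rewards.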

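$\hat{A}_t^{(1)}$ has high bias and low variance: it contains only one sampled reward, but relies on the estimate $V(s_{t+1})$, which may be inaccurate. $\hat{A}_t^{(\infty)}$ is unbiased regardless of the accuracy of $V$, since $V(s_t)$ only acts as a baseline, but it has high variance because the noise of every future reward adds up.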

We take a weighted average of the $\hat{A}_t^{(k)}$ to balance bias and variance. This is called Generalized Advantage Estimation:

$$\hat{A}_t = \hat{A}_t^{GAE} = \frac{\sum_k w_k \hat{A}_t^{(k)}}{\sum_k w_k}$$

We set $w_k = \lambda^{k-1}$, which gives a clean recursive calculation for $\hat{A}_t$:

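$$
\begin{aligned}
\delta_t &= r_t + \gamma V(s_{t+1}) - V(s_t) \\
\hat{A}_t &= \delta_t + \gamma\lambda\delta_{t+1} + \dots + (\gamma\lambda)^{T-t-1}\delta_{T-1} \\
&= \delta_t + \gamma\lambda\hat{A}_{t+1}
\end{aligned}
$$

where $T$ is the last step of the rollout, so the recursion starts from $\hat{A}_T = 0$.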
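The equivalence between the weighted average and the recursion can be checked numerically. This sketch computes $\hat{A}_0$ for one trajectory three ways: as the normalized average with $w_k = \lambda^{k-1}$, as the explicit sum of discounted $\delta_l$, and with the backward recursion. Since a rollout of length $T$ only supports $k \le T$, it assumes that the weight of every $k > T$ is given to $\hat{A}_0^{(T)}$; all function names are hypothetical.

```python
import numpy as np


def k_step_estimate(rewards, values, k, gamma):
    # A_0^(k) = r_0 + gamma r_1 + ... + gamma^(k-1) r_(k-1) + gamma^k V(s_k) - V(s_0)
    return float(np.sum(gamma ** np.arange(k) * rewards[:k]) + gamma ** k * values[k] - values[0])


def gae_weighted_average(rewards, values, gamma, lam):
    # Normalized weights (1 - lam) lam^(k-1); the tail weight lam^(T-1) of all k >= T
    # goes to A_0^(T), the longest estimate this trajectory supports (an assumption)
    T = len(rewards)
    total = 0.0
    for k in range(1, T + 1):
        weight = (1 - lam) * lam ** (k - 1) if k < T else lam ** (T - 1)
        total += weight * k_step_estimate(rewards, values, k, gamma)
    return total


def gae_delta_sum(rewards, values, gamma, lam):
    # A_0 = sum_l (gamma lam)^l delta_l, with delta_l = r_l + gamma V(s_(l+1)) - V(s_l)
    deltas = rewards + gamma * values[1:] - values[:-1]
    return float(np.sum((gamma * lam) ** np.arange(len(deltas)) * deltas))


def gae_recursive(rewards, values, gamma, lam):
    # Backward recursion A_t = delta_t + gamma lam A_(t+1), starting from A_T = 0
    advantage = 0.0
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        advantage = delta + gamma * lam * advantage
    return float(advantage)  # A_0 once the loop has finished


rng = np.random.default_rng(0)
rewards = rng.normal(size=6)  # r_0, ..., r_5
values = rng.normal(size=7)   # V(s_0), ..., V(s_6); the last entry is the bootstrap value
results = [f(rewards, values, gamma=0.99, lam=0.95)
           for f in (gae_weighted_average, gae_delta_sum, gae_recursive)]
print(results)
```

Under this convention the three numbers agree up to floating-point rounding for any $\lambda \in [0, 1]$: $\lambda = 0$ reduces to $\delta_0$, and $\lambda = 1$ to the discounted sum of rewards bootstrapped with $V(s_T)$, minus $V(s_0)$.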

```python
    def __call__(self, done: np.ndarray, rewards: np.ndarray, values: np.ndarray) -> np.ndarray:
```

Advantages table, with one row per worker and one column per step:

```python
        advantages = np.zeros((self.n_workers, self.worker_steps), dtype=np.float32)
        last_advantage = 0
```

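$V(s_{t+1})$ for the final step. `values` holds `worker_steps + 1` columns per worker: `values[:, t]` is $V(s_t)$, and the extra last column is the bootstrap value of the state reached after the final step. The loop then runs backwards over the steps, since $\hat{A}_t$ depends on $\hat{A}_{t+1}$.

```python
        last_value = values[:, -1]

        for t in reversed(range(self.worker_steps)):
```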

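Mask if the episode completed after step $t$. In that case $V(s_{t+1})$ and $\hat{A}_{t+1}$ belong to the next episode, so both are zeroed.

```python
            mask = 1.0 - done[:, t]
            last_value = last_value * mask
            last_advantage = last_advantage * mask
```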
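To see the mask at work, the sketch below runs the same loop for a single worker on a four-step rollout in which an episode ends at step 1. `gae_single_worker` is a hypothetical 1-D rewrite of `GAE.__call__`, and all value estimates are set to zero so the arithmetic is easy to follow.

```python
import numpy as np


def gae_single_worker(done, rewards, values, gamma, lam):
    """Hypothetical single-worker version of GAE.__call__.

    done and rewards have length T; values has length T + 1 (values[T] is the bootstrap value).
    """
    T = len(rewards)
    advantages = np.zeros(T, dtype=np.float64)
    last_advantage = 0.0
    last_value = values[T]
    for t in reversed(range(T)):
        # If the episode ended at step t, nothing from step t + 1 onwards may leak back
        mask = 1.0 - done[t]
        last_value = last_value * mask
        last_advantage = last_advantage * mask
        delta = rewards[t] + gamma * last_value - values[t]
        last_advantage = delta + gamma * lam * last_advantage
        advantages[t] = last_advantage
        last_value = values[t]
    return advantages


done = np.array([0.0, 1.0, 0.0, 0.0])  # the episode ends after step 1
rewards = np.array([1.0, 1.0, 5.0, 5.0])
values = np.zeros(5)  # all value estimates are zero for simplicity
adv = gae_single_worker(done, rewards, values, gamma=0.5, lam=1.0)
print(adv)
```

The result is $[1.5, 1.0, 7.5, 5.0]$. Without the mask, $\hat{A}_1$ would be $1 + 0.5 \cdot 7.5 = 4.75$, because it would pick up the rewards of the next episode.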

$\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$

```python
            delta = rewards[:, t] + self.gamma * last_value - values[:, t]
```
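Only the advantage recursion is sequential. Because at step $t$ the loop's `last_value` equals `values[:, t + 1]` multiplied by `1 - done[:, t]`, every $\delta_t$ can also be computed in one vectorized expression. A sketch under the same shape conventions, with `td_residuals` a hypothetical helper:

```python
import numpy as np


def td_residuals(done, rewards, values, gamma):
    """Hypothetical helper: every delta_t at once, shape (n_workers, worker_steps).

    values has worker_steps + 1 columns; done and rewards have worker_steps columns.
    """
    # V(s_{t+1}), zeroed wherever the episode ended at step t
    next_values = values[:, 1:] * (1.0 - done)
    return rewards + gamma * next_values - values[:, :-1]


done = np.array([[0.0, 1.0, 0.0]])
rewards = np.array([[1.0, 2.0, 3.0]])
values = np.array([[0.5, 1.0, 1.5, 2.0]])
deltas = td_residuals(done, rewards, values, gamma=0.5)
print(deltas)
```

Here the episode ends at step 1, so $\delta_1 = 2 - 1.0 = 1.0$ ignores $V(s_2)$, and the result is $[1.0, 1.0, 2.5]$.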

$\hat{A}_t = \delta_t + \gamma\lambda\hat{A}_{t+1}$

```python
            last_advantage = delta + self.gamma * self.lambda_ * last_advantage
```

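Store $\hat{A}_t$. Each step is written into its own column, so the table ends up in time order even though it is filled backwards. Then $V(s_t)$ becomes the $V(s_{t+1})$ of the next iteration, step $t - 1$.

```python
            advantages[:, t] = last_advantage

            last_value = values[:, t]
```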

$\hat{A}_t$ for every worker and step:

```python
        return advantages
```
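For a runnable end-to-end check, here are the fragments above assembled into one listing (comments dropped), followed by a small two-worker example. The inputs are made up for illustration; `values` has `worker_steps + 1` columns, the last one being the bootstrap value.

```python
import numpy as np


class GAE:
    """The fragments on this page, assembled into one listing."""

    def __init__(self, n_workers: int, worker_steps: int, gamma: float, lambda_: float):
        self.lambda_ = lambda_
        self.gamma = gamma
        self.worker_steps = worker_steps
        self.n_workers = n_workers

    def __call__(self, done: np.ndarray, rewards: np.ndarray, values: np.ndarray) -> np.ndarray:
        advantages = np.zeros((self.n_workers, self.worker_steps), dtype=np.float32)
        last_advantage = 0
        last_value = values[:, -1]
        for t in reversed(range(self.worker_steps)):
            mask = 1.0 - done[:, t]
            last_value = last_value * mask
            last_advantage = last_advantage * mask
            delta = rewards[:, t] + self.gamma * last_value - values[:, t]
            last_advantage = delta + self.gamma * self.lambda_ * last_advantage
            advantages[:, t] = last_advantage
            last_value = values[:, t]
        return advantages


# Two workers with three steps each (illustrative numbers only)
gae = GAE(n_workers=2, worker_steps=3, gamma=0.5, lambda_=0.5)
done = np.array([[0.0, 0.0, 0.0],
                 [0.0, 0.0, 1.0]])  # worker 1's episode ends at the last step
rewards = np.ones((2, 3))
values = np.array([[0.0, 0.0, 0.0, 2.0],   # last column: bootstrap value
                   [0.0, 0.0, 0.0, 2.0]])
advantages = gae(done, rewards, values)
print(advantages)
```

Worker 0 gets $[1.375, 1.5, 2.0]$. Worker 1's episode ends at the last step, so its bootstrap value is masked out and it gets $[1.3125, 1.25, 1.0]$.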

labml.ai