Generalized Advantage Estimation (GAE)
This is a PyTorch implementation of the paper Generalized Advantage Estimation.
You can find an experiment that uses it here.
import numpy as np
class GAE:
    def __init__(self, n_workers: int, worker_steps: int, gamma: float, lambda_: float):
        self.lambda_ = lambda_
        self.gamma = gamma
        self.worker_steps = worker_steps
        self.n_workers = n_workers
The $k$-step advantage estimates are
$$\begin{aligned}
\hat{A}_t^{(1)} &= r_t + \gamma V(s_{t+1}) - V(s_t) \\
\hat{A}_t^{(2)} &= r_t + \gamma r_{t+1} + \gamma^2 V(s_{t+2}) - V(s_t) \\
&\;\;\vdots \\
\hat{A}_t^{(\infty)} &= r_t + \gamma r_{t+1} + \gamma^2 r_{t+2} + \dots - V(s_t)
\end{aligned}$$
$\hat{A}_t^{(1)}$ has high bias and low variance, whilst $\hat{A}_t^{(\infty)}$ is unbiased but has high variance.
We take a weighted average of the $\hat{A}_t^{(k)}$ to balance bias and variance. This is called Generalized Advantage Estimation.
$$\hat{A}_t = \hat{A}_t^{GAE} = \frac{\sum_k w_k \hat{A}_t^{(k)}}{\sum_k w_k}$$
Setting $w_k = \lambda^{k-1}$ gives a clean way to calculate $\hat{A}_t$:
$$\begin{aligned}
\delta_t &= r_t + \gamma V(s_{t+1}) - V(s_t) \\
\hat{A}_t &= \delta_t + \gamma\lambda\,\delta_{t+1} + \dots + (\gamma\lambda)^{T-t-1}\,\delta_{T-1} \\
&= \delta_t + \gamma\lambda\,\hat{A}_{t+1}
\end{aligned}$$
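To see why these weights give the $\delta$ form, note that each $k$-step estimate telescopes into a sum of TD errors, $\hat{A}_t^{(k)} = \sum_{l=0}^{k-1} \gamma^l \delta_{t+l}$. Here is a short derivation sketch for the infinite-horizon case, with the weights normalized using $\sum_k \lambda^{k-1} = \frac{1}{1-\lambda}$:

$$\begin{aligned}
\hat{A}_t &= (1-\lambda) \sum_{k=1}^{\infty} \lambda^{k-1} \hat{A}_t^{(k)} \\
&= (1-\lambda) \sum_{k=1}^{\infty} \lambda^{k-1} \sum_{l=0}^{k-1} \gamma^l \delta_{t+l} \\
&= \sum_{l=0}^{\infty} \gamma^l \delta_{t+l} \,(1-\lambda) \sum_{k=l+1}^{\infty} \lambda^{k-1} \\
&= \sum_{l=0}^{\infty} (\gamma\lambda)^l \delta_{t+l}
\end{aligned}$$

since $(1-\lambda) \sum_{k \ge l+1} \lambda^{k-1} = \lambda^l$. Truncating the sum at $\delta_{T-1}$ gives the finite-horizon form above.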
done and rewards have shape (n_workers, worker_steps); values carries one extra time step, (n_workers, worker_steps + 1), so that $V(s_{t+1})$ is available at the final step.
    def __call__(self, done: np.ndarray, rewards: np.ndarray, values: np.ndarray) -> np.ndarray:
advantages table
        advantages = np.zeros((self.n_workers, self.worker_steps), dtype=np.float32)
        last_advantage = 0
$V(s_{t+1})$
        last_value = values[:, -1]

        for t in reversed(range(self.worker_steps)):
mask if episode completed after step t
            mask = 1.0 - done[:, t]
            last_value = last_value * mask
            last_advantage = last_advantage * mask
$\delta_t$
            delta = rewards[:, t] + self.gamma * last_value - values[:, t]
$\hat{A}_t = \delta_t + \gamma\lambda\,\hat{A}_{t+1}$
            last_advantage = delta + self.gamma * self.lambda_ * last_advantage
            advantages[:, t] = last_advantage

            last_value = values[:, t]
$\hat{A}_t$
        return advantages
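Below is a minimal usage sketch. The dimensions, the $\gamma$/$\lambda$ values, and the random rollout data are illustrative assumptions, not part of the original implementation. It also brute-force checks the recursion against the explicit sum $\hat{A}_t = \sum_l (\gamma\lambda)^l \delta_{t+l}$ on a termination-free rollout.

import numpy as np

# Illustrative rollout dimensions and commonly used gamma/lambda values.
n_workers, worker_steps = 4, 128
gae = GAE(n_workers=n_workers, worker_steps=worker_steps, gamma=0.99, lambda_=0.95)

# Fake rollout data; a real experiment fills these from the environment.
done = np.random.rand(n_workers, worker_steps) < 0.01  # episode-termination flags
rewards = np.random.randn(n_workers, worker_steps).astype(np.float32)
# Note the extra column: values[:, t] is V(s_t) for t = 0..worker_steps,
# so values[:, -1] supplies V(s_{t+1}) at the final step.
values = np.random.randn(n_workers, worker_steps + 1).astype(np.float32)

advantages = gae(done, rewards, values)
assert advantages.shape == (n_workers, worker_steps)

# Value targets for the critic can be recovered as R_t = A_t + V(s_t).
returns = advantages + values[:, :-1]

# Brute-force check against the explicit sum, one worker, no terminations.
T = 5
d0 = np.zeros((1, T), dtype=bool)
r0 = np.random.randn(1, T).astype(np.float32)
v0 = np.random.randn(1, T + 1).astype(np.float32)
a0 = GAE(1, T, gamma=0.99, lambda_=0.95)(d0, r0, v0)
delta = r0[0] + 0.99 * v0[0, 1:] - v0[0, :-1]
expected = [sum((0.99 * 0.95) ** l * delta[t + l] for l in range(T - t)) for t in range(T)]
assert np.allclose(a0[0], expected, atol=1e-4)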