
[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/rl/ppo/__init__.py)

# Proximal Policy Optimization - PPO

This is a PyTorch implementation of Proximal Policy Optimization - PPO.

PPO is a policy gradient method for reinforcement learning. Simple policy gradient methods perform a single gradient update per sample (or per batch of samples). Taking multiple gradient steps on the same sample causes problems because the policy deviates too much from the policy that generated the data, producing a bad policy. PPO lets us do multiple gradient updates per sample by keeping the updated policy close to the policy that was used to sample the data. It does so by clipping the objective (and hence the gradient) when the updated policy is not close to the policy used to sample the data.

You can find an experiment that uses it here. The experiment uses Generalized Advantage Estimation.

```python
import torch

from labml_nn.rl.ppo.gae import GAE
from torch import nn
```

## PPO Loss

Here's how the PPO update rule is derived.

We want to maximize the policy reward,

$$\max_\theta J(\pi_\theta) = \mathbb{E}_{\tau \sim \pi_\theta} \Biggl[\sum_{t=0}^\infty \gamma^t r_t \Biggr]$$

where $r$ is the reward, $\pi$ is the policy, $\tau$ is a trajectory sampled from the policy, and $\gamma$ is the discount factor in $[0, 1]$.

$$\begin{aligned}
\mathbb{E}_{\tau \sim \pi_\theta}\Biggl[\sum_{t=0}^\infty \gamma^t A^{\pi_{OLD}}(s_t, a_t)\Biggr]
&= \mathbb{E}_{\tau \sim \pi_\theta}\Biggl[\sum_{t=0}^\infty \gamma^t \Bigl(Q^{\pi_{OLD}}(s_t, a_t) - V^{\pi_{OLD}}(s_t)\Bigr)\Biggr] \\
&= \mathbb{E}_{\tau \sim \pi_\theta}\Biggl[\sum_{t=0}^\infty \gamma^t \Bigl(r_t + \gamma V^{\pi_{OLD}}(s_{t+1}) - V^{\pi_{OLD}}(s_t)\Bigr)\Biggr] \\
&= \mathbb{E}_{\tau \sim \pi_\theta}\Biggl[\sum_{t=0}^\infty \gamma^t r_t\Biggr] - \mathbb{E}_{\tau \sim \pi_\theta}\Bigl[V^{\pi_{OLD}}(s_0)\Bigr] \\
&= J(\pi_\theta) - J(\pi_{\theta_{OLD}})
\end{aligned}$$

So,

$$\max_\theta J(\pi_\theta) = \max_\theta \mathbb{E}_{\tau \sim \pi_\theta}\Biggl[\sum_{t=0}^\infty \gamma^t A^{\pi_{OLD}}(s_t, a_t)\Biggr]$$

Define the discounted-future state distribution,

$$d^\pi(s) = (1-\gamma) \sum_{t=0}^\infty \gamma^t P(s_t = s \mid \pi)$$
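As a sanity check, the weights $(1-\gamma)\gamma^t$ form a geometric series summing to one, so $d^\pi$ is a valid probability distribution. A quick numeric check (plain Python, illustrative only; the truncation horizon is an arbitrary choice):

```python
# Numerically check that the weights (1 - gamma) * gamma^t of the
# discounted-future state distribution sum to 1 (truncated at T = 10_000,
# where the remaining tail is negligible for gamma = 0.99)
gamma = 0.99
total = sum((1 - gamma) * gamma ** t for t in range(10_000))
print(total)  # ≈ 1.0
```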

Then,

$$\begin{aligned}
J(\pi_\theta) - J(\pi_{\theta_{OLD}})
&= \mathbb{E}_{\tau \sim \pi_\theta}\Biggl[\sum_{t=0}^\infty \gamma^t A^{\pi_{OLD}}(s_t, a_t)\Biggr] \\
&= \frac{1}{1-\gamma} \mathbb{E}_{s \sim d^{\pi_\theta},\, a \sim \pi_\theta}\Bigl[A^{\pi_{OLD}}(s, a)\Bigr]
\end{aligned}$$

Importance sampling $a$ from $\pi_{\theta_{OLD}}$,

$$\begin{aligned}
J(\pi_\theta) - J(\pi_{\theta_{OLD}})
&= \frac{1}{1-\gamma} \mathbb{E}_{s \sim d^{\pi_\theta},\, a \sim \pi_\theta}\Bigl[A^{\pi_{OLD}}(s, a)\Bigr] \\
&= \frac{1}{1-\gamma} \mathbb{E}_{s \sim d^{\pi_\theta},\, a \sim \pi_{\theta_{OLD}}}\Biggl[\frac{\pi_\theta(a \mid s)}{\pi_{\theta_{OLD}}(a \mid s)} A^{\pi_{OLD}}(s, a)\Biggr]
\end{aligned}$$

Then we assume $d^{\pi_\theta}(s)$ and $d^{\pi_{\theta_{OLD}}}(s)$ are similar. The error this assumption introduces to $J(\pi_\theta) - J(\pi_{\theta_{OLD}})$ is bounded by the KL divergence between $\pi_\theta$ and $\pi_{\theta_{OLD}}$. Constrained Policy Optimization shows the proof of this. I haven't read it.

$$\begin{aligned}
J(\pi_\theta) - J(\pi_{\theta_{OLD}})
&= \frac{1}{1-\gamma} \mathbb{E}_{s \sim d^{\pi_\theta},\, a \sim \pi_{\theta_{OLD}}}\Biggl[\frac{\pi_\theta(a \mid s)}{\pi_{\theta_{OLD}}(a \mid s)} A^{\pi_{OLD}}(s, a)\Biggr] \\
&\approx \frac{1}{1-\gamma} \mathbb{E}_{s \sim d^{\pi_{\theta_{OLD}}},\, a \sim \pi_{\theta_{OLD}}}\Biggl[\frac{\pi_\theta(a \mid s)}{\pi_{\theta_{OLD}}(a \mid s)} A^{\pi_{OLD}}(s, a)\Biggr] \\
&= \frac{1}{1-\gamma} \mathcal{L}^{CPI}
\end{aligned}$$

```python
class ClippedPPOLoss(nn.Module):
```


```python
    def __init__(self):
        super().__init__()
```


```python
    def forward(self, log_pi: torch.Tensor, sampled_log_pi: torch.Tensor,
                advantage: torch.Tensor, clip: float) -> torch.Tensor:
```


ratio $r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{OLD}}(a_t \mid s_t)}$; this is different from the rewards $r_t$.

```python
        ratio = torch.exp(log_pi - sampled_log_pi)
```
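Working in log space avoids underflow with small probabilities; exponentiating the difference of log-probabilities recovers the probability ratio. A quick check with hypothetical probabilities (0.3 and 0.2 are made-up numbers):

```python
import math

# Hypothetical action probabilities under the new and old policies
log_pi = math.log(0.3)          # log pi_theta(a|s)
sampled_log_pi = math.log(0.2)  # log pi_theta_old(a|s)

# exp of the log difference recovers the probability ratio
ratio = math.exp(log_pi - sampled_log_pi)
print(ratio)  # ≈ 1.5 == 0.3 / 0.2
```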


Clipping the policy ratio,

$$\mathcal{L}^{CLIP}(\theta) = \mathbb{E}_{a_t, s_t \sim \pi_{\theta_{OLD}}}\Bigl[\min\Bigl(r_t(\theta) \bar{A}_t,\ \mathrm{clip}\bigl(r_t(\theta), 1-\epsilon, 1+\epsilon\bigr) \bar{A}_t\Bigr)\Bigr]$$

The ratio is clipped to be close to 1. We take the minimum so that the gradient only pulls $\pi_\theta$ towards $\pi_{\theta_{OLD}}$ if the ratio is not between $1-\epsilon$ and $1+\epsilon$. This keeps the KL divergence between $\pi_\theta$ and $\pi_{\theta_{OLD}}$ constrained. Large deviations can cause performance collapse, where policy performance drops and doesn't recover because we are sampling from a bad policy.

Using the normalized advantage $\bar{A}_t = \frac{\hat{A}_t - \mu(\hat{A}_t)}{\sigma(\hat{A}_t)}$ introduces a bias to the policy gradient estimator, but it reduces variance a lot.
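A minimal sketch of this normalization on a hypothetical batch of advantage estimates (the `1e-8` term is my addition to guard against a zero standard deviation; it is not part of the formula above):

```python
import torch

# Hypothetical advantage estimates for a batch of four samples
adv = torch.tensor([1.0, 2.0, 3.0, 4.0])

# Normalize: subtract the batch mean, divide by the batch standard deviation
norm_adv = (adv - adv.mean()) / (adv.std() + 1e-8)

print(norm_adv.mean().item())  # ≈ 0
print(norm_adv.std().item())   # ≈ 1
```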

```python
        clipped_ratio = ratio.clamp(min=1.0 - clip,
                                    max=1.0 + clip)
        policy_reward = torch.min(ratio * advantage,
                                  clipped_ratio * advantage)

        self.clip_fraction = (abs((ratio - 1.0)) > clip).to(torch.float).mean()

        return -policy_reward.mean()
```
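A self-contained numeric sketch of the same clipped-objective computation on hand-picked tensors (the numbers are hypothetical), showing that a ratio outside $[1-\epsilon, 1+\epsilon]$ with a positive advantage has its reward capped, which is what zeroes its gradient:

```python
import torch

clip = 0.2
# Two samples: one ratio inside the clip range, one outside it
ratio = torch.tensor([1.1, 1.5])
advantage = torch.tensor([1.0, 1.0])

clipped_ratio = ratio.clamp(min=1.0 - clip, max=1.0 + clip)
# Elementwise minimum of unclipped and clipped surrogate terms
policy_reward = torch.min(ratio * advantage, clipped_ratio * advantage)
# Fraction of samples whose ratio fell outside [1 - clip, 1 + clip]
clip_fraction = (abs(ratio - 1.0) > clip).to(torch.float).mean()

print(policy_reward.tolist())  # ≈ [1.1, 1.2]; second sample capped at 1 + eps
print(clip_fraction.item())    # 0.5; half the samples were clipped
```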

## Clipped Value Function Loss

Similarly, we clip the value function update as well.

$$\begin{aligned}
V^{\pi_\theta}_{CLIP}(s_t) &= \hat{V}_t + \mathrm{clip}\Bigl(V^{\pi_\theta}(s_t) - \hat{V}_t,\ -\epsilon,\ +\epsilon\Bigr) \\
\mathcal{L}^{VF}(\theta) &= \frac{1}{2} \mathbb{E}\Bigl[\max\Bigl(\bigl(V^{\pi_\theta}(s_t) - R_t\bigr)^2,\ \bigl(V^{\pi_\theta}_{CLIP}(s_t) - R_t\bigr)^2\Bigr)\Bigr]
\end{aligned}$$

where $\hat{V}_t$ is the value estimate at sampling time and $R_t$ is the sampled return.

Clipping makes sure the value function $V_\theta$ doesn't deviate significantly from $V_{\theta_{OLD}}$.

```python
class ClippedValueFunctionLoss(nn.Module):
```


```python
    def forward(self, value: torch.Tensor, sampled_value: torch.Tensor,
                sampled_return: torch.Tensor, clip: float):
        clipped_value = sampled_value + (value - sampled_value).clamp(min=-clip, max=clip)
        vf_loss = torch.max((value - sampled_return) ** 2,
                            (clipped_value - sampled_return) ** 2)
        return 0.5 * vf_loss.mean()
```
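A numeric sketch of the same clipped value loss on hand-picked values (the numbers are hypothetical), showing the clamp bounding how far the value estimate can move from its sampling-time value:

```python
import torch

clip = 0.2
value = torch.tensor([1.5])           # current value estimate V_theta(s)
sampled_value = torch.tensor([1.0])   # value estimate at sampling time
sampled_return = torch.tensor([2.0])  # empirical return R

# Clamp the change in the value estimate to [-clip, +clip]
clipped_value = sampled_value + (value - sampled_value).clamp(min=-clip, max=clip)
# Take the elementwise max of unclipped and clipped squared errors
vf_loss = torch.max((value - sampled_return) ** 2,
                    (clipped_value - sampled_return) ** 2)
loss = 0.5 * vf_loss.mean()

print(clipped_value.item())  # ≈ 1.2; update clamped to sampled_value + clip
print(loss.item())           # ≈ 0.32 = 0.5 * max(0.25, 0.64)
```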
