Rectified Adam (RAdam) optimizer

This implementation is based on the official implementation of the paper On the Variance of the Adaptive Learning Rate and Beyond.

We have implemented it in PyTorch as an extension to our AMSGrad implementation, so only the RAdam-specific modifications need to be implemented here.

The Adam optimizer sometimes converges to a bad local optimum during the initial stages of training, especially when training transformers. Researchers use warm-ups to counter this: for the initial training steps (the warm-up stage) they use a low learning rate. This paper identifies the problem as the high variance of the adaptive learning rate during the early stages of training, and counters it with a new rectification term that reduces the variance.

The paper also evaluates two variance reduction mechanisms:

  • Adam-2k: only compute the adaptive learning rate ($v_t$ in Adam) during the first 2k steps, without changing parameters or calculating momentum ($m_t$).
  • Adam-eps: Adam with a large $\epsilon \approx 10^{-4}$ (a minimal sketch follows below).
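As a rough illustration of the Adam-eps baseline (a sketch, not the paper's code), one can construct a standard PyTorch Adam optimizer with a large epsilon; the model and learning rate below are placeholder assumptions.

```python
import torch

# Hypothetical model, used only for illustration
model = torch.nn.Linear(10, 2)

# Adam-eps: plain Adam with a large epsilon (~1e-4), which damps the
# adaptive learning rate during the first steps of training
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, eps=1e-4)
```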

Rectified Adam

Let $\sigma(g_1, ..., g_t)$ and $\psi(g_1, ..., g_t)$ be the functions to calculate momentum and adaptive learning rate. For Adam, they are

$$\sigma(g_1, ..., g_t) = \frac{(1 - \beta_1) \sum_{i=1}^t \beta_1^{t-i} g_i}{1 - \beta_1^t}$$

$$\psi(g_1, ..., g_t) = \sqrt{\frac{1 - \beta_2^t}{(1 - \beta_2) \sum_{i=1}^t \beta_2^{t-i} g_i^2}}$$
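As a minimal reference sketch (not part of this implementation), the two functions can be evaluated directly from a sequence of scalar gradients; the function names and the use of plain Python floats are illustrative assumptions.

```python
import math
from typing import List


def sigma(grads: List[float], beta1: float) -> float:
    """Bias-corrected momentum sigma(g_1, ..., g_t) for scalar gradients."""
    t = len(grads)
    numerator = (1 - beta1) * sum(beta1 ** (t - i) * g for i, g in enumerate(grads, start=1))
    return numerator / (1 - beta1 ** t)


def psi(grads: List[float], beta2: float) -> float:
    """Adaptive learning rate psi(g_1, ..., g_t) for scalar gradients."""
    t = len(grads)
    denominator = (1 - beta2) * sum(beta2 ** (t - i) * g * g for i, g in enumerate(grads, start=1))
    return math.sqrt((1 - beta2 ** t) / denominator)
```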

Exponential moving average as simple moving average

The distribution of the exponential moving average can be approximated as a simple moving average.

$$p\left(\frac{(1-\beta_2) \sum_{i=1}^t \beta_2^{t-i} g_i^2}{1 - \beta_2^t}\right) \approx p\left(\frac{\sum_{i=1}^{f(t,\beta_2)} g_{t+1-i}^2}{f(t,\beta_2)}\right)$$

Here we are taking the simple moving average of the last $f(t,\beta_2)$ gradients. $f(t,\beta_2)$ satisfies the following,

$$\frac{(1-\beta_2) \sum_{i=1}^t \beta_2^{t-i} \cdot i}{1 - \beta_2^t} = \frac{\sum_{i=1}^{f(t,\beta_2)} (t+1-i)}{f(t,\beta_2)}$$

which gives,

$$f(t,\beta_2) = \frac{2}{1-\beta_2} - 1 - \frac{2 t \beta_2^t}{1 - \beta_2^t}$$
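A quick numeric check of this formula (an illustrative sketch; the helper name is arbitrary): $f(t, \beta_2)$ starts small and grows towards $\frac{2}{1-\beta_2} - 1$ as $t$ increases.

```python
def simple_moving_average_length(t: int, beta2: float) -> float:
    """f(t, beta2): the equivalent simple-moving-average length."""
    beta2_t = beta2 ** t
    return 2 / (1 - beta2) - 1 - 2 * t * beta2_t / (1 - beta2_t)


# For beta2 = 0.999 the limiting value 2 / (1 - beta2) - 1 is 1999
for t in (10, 100, 1_000, 10_000):
    print(t, simple_moving_average_length(t, 0.999))
```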

Scaled inverse chi-squared

From above we have

$$p\big(\psi^2(g_1, ..., g_t)\big) \approx p\left(\frac{f(t,\beta_2)}{\sum_{i=1}^{f(t,\beta_2)} g_{t+1-i}^2}\right)$$

where $g_i \sim \mathcal{N}(0, \sigma^2)$. Note that $\sigma$ here is the standard deviation and is different from the $\sigma(.)$ used for momentum.

Scaled inverse chi-squared is the distribution of the inverse of the mean of the squares of $p$ normally distributed random variables, so

$$p\left(\frac{f(t,\beta_2)}{\sum_{i=1}^{f(t,\beta_2)} g_{t+1-i}^2}\right) \sim \text{Scale-inv-}\chi^2\left(\rho, \frac{1}{\sigma^2}\right)$$

where $\rho = f(t,\beta_2)$.
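A small Monte-Carlo sketch of this claim (illustrative only; the values of $\rho$ and $\sigma$ below are arbitrary assumptions): for i.i.d. $g_i \sim \mathcal{N}(0, \sigma^2)$, the empirical mean of $\rho / \sum_{i=1}^{\rho} g_i^2$ should be close to the $\text{Scale-inv-}\chi^2\left(\rho, \frac{1}{\sigma^2}\right)$ mean $\frac{\rho}{\sigma^2 (\rho - 2)}$.

```python
import numpy as np

rng = np.random.default_rng(0)
rho, sigma = 20, 0.5  # arbitrary illustrative values

# Many batches of rho gradients; compute rho / sum(g^2) for each batch
g = rng.normal(0.0, sigma, size=(100_000, rho))
samples = rho / (g ** 2).sum(axis=1)

print(samples.mean())                  # empirical mean
print(rho / (sigma ** 2 * (rho - 2)))  # mean of Scale-inv-chi^2(rho, 1/sigma^2)
```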

Rectification

They prove that the variance of $\psi(.)$ decreases with $\rho$ when $\psi^2(.) \sim \text{Scale-inv-}\chi^2\left(\rho, \frac{1}{\sigma^2}\right)$.

Therefore the variance is minimized at the maximal $\rho$, which is $\rho_\infty = \frac{2}{1-\beta_2} - 1$. Let the minimum variance be $C_{\text{var}}$.

In order to ensure that the adaptive learning rate $\psi(.)$ has consistent variance, we rectify the variance with $r$,

$$r = \sqrt{\frac{C_{\text{var}}}{\text{Var}\big[\psi(.)\big]}}$$
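Since $r$ is a deterministic function of the step (through $\rho_t$), scaling $\psi(.)$ by $r$ gives

$$\text{Var}\big[r\,\psi(.)\big] = r^2 \, \text{Var}\big[\psi(.)\big] = \frac{C_{\text{var}}}{\text{Var}\big[\psi(.)\big]} \, \text{Var}\big[\psi(.)\big] = C_{\text{var}}$$

so the rectified adaptive learning rate has the same (minimum) variance at every step.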

Approximating Var[ψ(.)]

They estimate $\text{Var}\big[\psi(.)\big] \approx \frac{\text{Var}\big[\psi^2(.)\big]}{4 \, \mathbb{E}\big[\psi^2(.)\big]}$ based on a first-order expansion of $\sqrt{\psi^2(.)}$; a sketch of that expansion follows.
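One way to obtain this (a sketch of the step, not spelled out in the paper) is a first-order Taylor expansion of the square root around $\mathbb{E}\big[\psi^2(.)\big]$,

$$\psi(.) = \sqrt{\psi^2(.)} \approx \sqrt{\mathbb{E}\big[\psi^2(.)\big]} + \frac{\psi^2(.) - \mathbb{E}\big[\psi^2(.)\big]}{2 \sqrt{\mathbb{E}\big[\psi^2(.)\big]}}$$

Taking the variance of both sides, the constant terms drop out and $\text{Var}\big[\psi(.)\big] \approx \frac{\text{Var}\big[\psi^2(.)\big]}{4 \, \mathbb{E}\big[\psi^2(.)\big]}$.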

From the $\text{Scale-inv-}\chi^2$ distribution we have,

$$\mathbb{E}\big[\psi^2(.)\big] = \frac{\rho / \sigma^2}{\rho - 2}$$

$$\text{Var}\big[\psi^2(.)\big] = \frac{2 \rho^2 / \sigma^4}{(\rho-2)^2 (\rho-4)}$$

which gives,

$$\text{Var}\big[\psi(.)\big] \approx \frac{\rho}{2 (\rho-2) (\rho-4) \sigma^2}$$

Rectification term

We have

$$r = \sqrt{\frac{C_{\text{var}}}{\text{Var}\big[\psi(.)\big]}}$$

$$\text{Var}\big[\psi(.)\big] \approx \frac{\rho}{2 (\rho-2) (\rho-4) \sigma^2}$$

where $C_{\text{var}}$ is $\text{Var}\big[\psi(.)\big]$ for $\rho_\infty$. Let the value of $\rho$ at step $t$ be $\rho_t$, and let $r_t$ be the rectification term at step $t$.

$$C_{\text{var}} \approx \frac{\rho_\infty}{2 (\rho_\infty-2) (\rho_\infty-4) \sigma^2}$$

$$\text{Var}\big[\psi(g_1,...,g_t)\big] \approx \frac{\rho_t}{2 (\rho_t-2) (\rho_t-4) \sigma^2}$$

This gives,

$$r_t = \sqrt{\frac{(\rho_t-2)(\rho_t-4)\,\rho_\infty}{(\rho_\infty-2)(\rho_\infty-4)\,\rho_t}}$$
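As a sanity check on this expression: early in training $\rho_t$ is well below $\rho_\infty$, so $r_t < 1$ and the step size is shrunk much like a warm-up, while as $t$ grows $\rho_t \to \rho_\infty$ and

$$r_t = \sqrt{\frac{(\rho_t-2)(\rho_t-4)\,\rho_\infty}{(\rho_\infty-2)(\rho_\infty-4)\,\rho_t}} \to 1$$

so the update approaches plain Adam.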

```python
import math
from typing import Dict, Optional

import torch

from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.amsgrad import AMSGrad
```

Rectified Adam Optimizer

This class extends from the AMSGrad optimizer defined in amsgrad.py.

```python
class RAdam(AMSGrad):
```

Initialize the optimizer

  • params is the list of parameters
  • lr is the learning rate $\alpha$
  • betas is a tuple of ($\beta_1$, $\beta_2$)
  • eps is $\hat{\epsilon}$ or $\epsilon$ based on optimized_update
  • weight_decay is an instance of class WeightDecay defined in __init__.py
  • optimized_update is a flag whether to optimize the bias correction of the second moment by doing it after adding $\epsilon$
  • amsgrad is a flag indicating whether to use AMSGrad or fall back to plain Adam
  • degenerated_to_sgd is a flag indicating whether to fall back to SGD when the rectification term $r_t$ is intractable
  • defaults is a dictionary of default values for parameter groups. This is useful when you want to extend the class RAdam.

```python
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay: WeightDecay = WeightDecay(),
                 optimized_update: bool = True,
                 amsgrad=False,
                 degenerated_to_sgd=True, defaults=None):
```

```python
        self.degenerated_to_sgd = degenerated_to_sgd
        super().__init__(params, lr, betas, eps, weight_decay, optimized_update, amsgrad, defaults)
```
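A minimal usage sketch (not part of this file; it assumes labml_nn is installed, and the model and hyper-parameters are illustrative placeholders):

```python
import torch
from labml_nn.optimizers.radam import RAdam

# Hypothetical model, used only for illustration
model = torch.nn.Sequential(torch.nn.Linear(10, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))

# Construct the optimizer with the defaults shown above
optimizer = RAdam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))

# One training step on random data
x, y = torch.randn(16, 10), torch.randint(0, 2, (16,))
loss = torch.nn.functional.cross_entropy(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```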

Take an update step for a given parameter tensor

  • state is the optimizer state of the parameter (tensor)
  • group stores optimizer attributes of the parameter group
  • grad is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
  • param is the parameter tensor $\theta_{t-1}$

```python
    def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
```

Calculate weight decay

```python
        grad = self.weight_decay(param, grad, group)
```

Get $m_t$ and $v_t$; i.e. $\sigma(.)$ and $\psi(.)$ without bias correction

```python
        m, v = self.get_mv(state, group, grad)
```

Increment $t$, the number of optimizer steps

```python
        state['step'] += 1
```

Perform RAdam update

```python
        self.r_adam_update(state, group, param, m, v)
```

Calculate rectification term $r_t$

```python
    @staticmethod
    def calc_rectification_term(beta2: float, step: int) -> Optional[float]:
```

$\beta_2^t$

```python
        beta2_t = beta2 ** step
```

$\rho_\infty = \frac{2}{1 - \beta_2} - 1$

```python
        rho_inf = 2 / (1 - beta2) - 1
```

$\rho_t = \frac{2}{1 - \beta_2} - 1 - \frac{2 t \beta_2^t}{1 - \beta_2^t}$

```python
        rho = rho_inf - 2 * step * beta2_t / (1 - beta2_t)
```

$r_t$ is tractable when $\rho_t \geq 4$. We are being a little more conservative since it's an approximated value.

```python
        if rho >= 5:
```

$r_t = \sqrt{\frac{(\rho_t-2)(\rho_t-4)\,\rho_\infty}{(\rho_\infty-2)(\rho_\infty-4)\,\rho_t}}$

```python
            r2 = (rho - 4) / (rho_inf - 4) * (rho - 2) / rho * rho_inf / (rho_inf - 2)
            return math.sqrt(r2)
        else:
            return None
```

Do the RAdam parameter update

  • state is the optimizer state of the parameter (tensor)
  • group stores optimizer attributes of the parameter group
  • param is the parameter tensor $\theta_{t-1}$
  • m and v are the uncorrected first and second moments $m_t$ and $v_t$; i.e. $\sigma(.)$ and $\psi(.)$ without bias correction

```python
    def r_adam_update(self, state: Dict[str, any], group: Dict[str, any], param: torch.nn.Parameter,
                      m: torch.Tensor, v: torch.Tensor):
```

Get $\beta_1$ and $\beta_2$

```python
        beta1, beta2 = group['betas']
```

Bias correction term for $\hat{m}_t$, $1 - \beta_1^t$

```python
        bias_correction1 = 1 - beta1 ** state['step']
```

Bias correction term for $\hat{v}_t$, $1 - \beta_2^t$

```python
        bias_correction2 = 1 - beta2 ** state['step']

        r = self.calc_rectification_term(beta2, state['step'])
```

Get learning rate

```python
        lr = self.get_lr(state, group)
```

If $r_t$ is tractable

```python
        if r is not None:
```

Whether to optimize the computation by combining scalar computations

```python
            if self.optimized_update:
```

Denominator $\sqrt{v_t} + \hat{\epsilon}$

```python
                denominator = v.sqrt().add_(group['eps'])
```

Step size $\alpha r_t \cdot \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t}$

```python
                step_size = lr * math.sqrt(bias_correction2) * r / bias_correction1
```

Update parameters $\theta_t \leftarrow \theta_{t-1} - \alpha r_t \cdot \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}}$

```python
                param.data.addcdiv_(m, denominator, value=-step_size)
```

Computation without optimization

```python
            else:
```

Denominator $\frac{\sqrt{v_t}}{\sqrt{1-\beta_2^t}} + \epsilon$

```python
                denominator = (v.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
```

Step size $\frac{\alpha r_t}{1-\beta_1^t}$

```python
                step_size = lr * r / bias_correction1
```

Update parameters $\theta_t \leftarrow \theta_{t-1} - \alpha r_t \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$

```python
                param.data.addcdiv_(m, denominator, value=-step_size)
```
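A short check on why the two branches agree (assuming the convention $\hat{\epsilon} = \epsilon \sqrt{1-\beta_2^t}$, which appears to be how the base Adam implementation here relates the two epsilons):

$$\frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}} = \frac{1}{1-\beta_1^t} \cdot \frac{m_t}{\frac{\sqrt{v_t}}{\sqrt{1-\beta_2^t}} + \frac{\hat{\epsilon}}{\sqrt{1-\beta_2^t}}} = \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$

Both branches then multiply this quantity by $\alpha r_t$, so they produce the same parameter update.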

If $r_t$ is intractable, do an SGD update with momentum

```python
        elif self.degenerated_to_sgd:
```

Step size $\frac{\alpha}{1-\beta_1^t}$

```python
            step_size = lr / bias_correction1
```

Update parameters $\theta_t \leftarrow \theta_{t-1} - \alpha \cdot \hat{m}_t$

```python
            param.data.add_(m, alpha=-step_size)
```

Plot $r_t$ against $t$ for various $\beta_2$

```python
def _test_rectification_term():
```

```python
    import matplotlib.pyplot as plt
    import numpy as np

    beta2 = [0.9999, 0.999, 0.99, 0.9, 0.8, 0.6, 0.5]
    plt.plot(np.arange(1, 5_000), [[RAdam.calc_rectification_term(b, i) for b in beta2] for i in range(1, 5_000)])
    plt.legend(beta2)
    plt.title("Optimizer")
    plt.show()


if __name__ == '__main__':
    _test_rectification_term()
```
