AdaBelief Optimizer

This is based on the official AdaBelief implementation of the paper AdaBelief Optimizer: Adapting Stepsizes by the Belief in Observed Gradients.

This is implemented in PyTorch as an extension to RAdam.

The main difference between the Adam optimizer and AdaBelief is how the adaptive learning rate is calculated: instead of dividing by an exponential moving average of the squared gradients, AdaBelief divides by an exponential moving average of the variance, i.e. the squared difference between the gradient and its momentum.

$$\begin{aligned}
m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t \\
s_t &\leftarrow \beta_2 s_{t-1} + (1 - \beta_2) \cdot (g_t - m_t)^2 \\
\hat{m}_t &\leftarrow \frac{m_t}{1 - \beta_1^t} \\
\hat{s}_t &\leftarrow \frac{s_t + \epsilon}{1 - \beta_2^t} \\
\theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{s}_t} + \epsilon}
\end{aligned}$$

🤔 The paper calculates the variance as $(g_t - m_t)^2$, but I feel it should use the bias-corrected momentum, $(g_t - \hat{m}_t)^2$. I guess this doesn't affect things much because the bias correction is ≈ 1 after the initial training steps.
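
To make the contrast concrete, here is a minimal sketch in plain PyTorch (values and variable names are invented for illustration) comparing Adam's second moment, an exponential moving average of $g_t^2$, with AdaBelief's $s_t$, an exponential moving average of $(g_t - m_t)^2$:

```python
import torch

# Invented values, for illustration only
beta1, beta2 = 0.9, 0.999
g = torch.tensor([0.5, -1.0, 2.0])  # current gradient g_t
m = torch.zeros(3)                  # first moment m_{t-1}
v = torch.zeros(3)                  # Adam's second moment v_{t-1}
s = torch.zeros(3)                  # AdaBelief's second moment s_{t-1}

m = beta1 * m + (1 - beta1) * g              # m_t
v = beta2 * v + (1 - beta2) * g ** 2         # Adam: EMA of g_t^2
s = beta2 * s + (1 - beta2) * (g - m) ** 2   # AdaBelief: EMA of (g_t - m_t)^2

# When gradients stay close to their moving average, s_t is much smaller than
# v_t, so AdaBelief takes larger steps; when gradients deviate strongly from
# m_t, s_t grows and the step shrinks.
print(v, s)
```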

```python
from typing import Dict, Any

import torch
from torch import nn

from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.radam import RAdam
```

AdaBelief Optimizer

This class extends the RAdam optimizer defined in radam.py.

```python
class AdaBelief(RAdam):
```

Initialize the optimizer

  • params is the list of parameters
  • lr is the learning rate $\alpha$
  • betas is a tuple of ($\beta_1$, $\beta_2$)
  • eps is $\hat{\epsilon}$ or $\epsilon$ based on optimized_update
  • weight_decay is an instance of the class WeightDecay defined in __init__.py
  • optimized_update is a flag indicating whether to optimize the bias correction of the second moment by doing it after adding $\epsilon$
  • amsgrad is a flag indicating whether to use AMSGrad or fall back to plain Adam
  • degenerate_to_sgd is whether to use SGD when the rectification term $r_t$ is intractable
  • rectify is whether to use the RAdam update
  • defaults is a dictionary of default values for parameter groups. This is useful when you want to extend the class AdaBelief.

```python
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
             weight_decay: WeightDecay = WeightDecay(), amsgrad=False,
             degenerate_to_sgd=True,
             rectify=True, defaults=None):
    defaults = {} if defaults is None else defaults
    super().__init__(params, lr, betas, eps, weight_decay, amsgrad, degenerate_to_sgd, defaults)
    self.rectify = rectify
```
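
As a usage sketch (assuming labml_nn is installed; the module path labml_nn.optimizers.ada_belief, the WeightDecay constructor argument, and the hyper-parameter values are assumptions for illustration), constructing the optimizer might look like:

```python
import torch.nn as nn

from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.ada_belief import AdaBelief  # assumed module path

model = nn.Linear(10, 2)  # placeholder model for illustration

optimizer = AdaBelief(model.parameters(),
                      lr=1e-3,
                      betas=(0.9, 0.999),
                      eps=1e-16,
                      weight_decay=WeightDecay(weight_decay=1e-2),  # assumed constructor argument
                      amsgrad=False,
                      rectify=True)
```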

Initialize a parameter state

  • state is the optimizer state of the parameter (tensor)
  • group stores optimizer attributes of the parameter group
  • param is the parameter tensor $\theta_{t-1}$

```python
def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):
```

Number of optimizer steps taken on the parameter, $t$

```python
state['step'] = 0
```

Exponential moving average of gradient values

```python
state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)
```

Exponential moving average of variance

```python
state['exp_avg_var'] = torch.zeros_like(param, memory_format=torch.preserve_format)
```

If the amsgrad flag is True for this parameter group, we maintain the maximum of the exponential moving average of the variance

```python
if group['amsgrad']:
```

Maintain the maximum of all exponential moving averages of the variance

```python
    state['max_exp_avg_var'] = torch.zeros_like(param, memory_format=torch.preserve_format)
```
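
To make the resulting state layout concrete, here is a hedged stand-alone sketch that mirrors the initialization above without depending on the optimizer class (the function name is my own):

```python
import torch

def init_state_sketch(param: torch.Tensor, amsgrad: bool) -> dict:
    # Mirrors init_state: step counter, first moment, second moment (variance),
    # and optionally the running maximum of the variance for AMSGrad.
    state = {
        'step': 0,
        'exp_avg': torch.zeros_like(param, memory_format=torch.preserve_format),
        'exp_avg_var': torch.zeros_like(param, memory_format=torch.preserve_format),
    }
    if amsgrad:
        state['max_exp_avg_var'] = torch.zeros_like(param, memory_format=torch.preserve_format)
    return state

print(init_state_sketch(torch.randn(3, 3), amsgrad=True).keys())
```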

Calculate $m_t$ and $s_t$ or $\max(s_1, s_2, ..., s_{t-1}, s_t)$

  • state is the optimizer state of the parameter (tensor)
  • group stores optimizer attributes of the parameter group
  • grad is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$

```python
def get_ms(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor):
```

Get $\beta_1$ and $\beta_2$

```python
beta1, beta2 = group['betas']
```

Get $m_{t-1}$ and $s_{t-1}$

```python
m, s = state['exp_avg'], state['exp_avg_var']
```

In-place calculation of $m_t$: $m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t$

```python
m.mul_(beta1).add_(grad, alpha=1 - beta1)
```

Difference between the gradient and the momentum

```python
grad_residual = grad - m
```

In-place calculation of $s_t$: $s_t \leftarrow \beta_2 s_{t-1} + (1 - \beta_2) \cdot (g_t - m_t)^2$

```python
s.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1 - beta2)
```
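
The chained in-place call above is equivalent to the element-wise expression $s_t = \beta_2 s_{t-1} + (1 - \beta_2)(g_t - m_t)^2$; a small check with invented values:

```python
import torch

beta2 = 0.999
s = torch.tensor([0.1, 0.2, 0.3])                # s_{t-1}, invented values
grad_residual = torch.tensor([0.5, -1.0, 2.0])   # g_t - m_t, invented values

expected = beta2 * s + (1 - beta2) * grad_residual ** 2
s.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1 - beta2)
assert torch.allclose(s, expected)
```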

If this parameter group is using amsgrad

```python
if group['amsgrad']:
```

Get $\max(s_1, s_2, ..., s_{t-1})$.

```python
    s_max = state['max_exp_avg_var']
```

Calculate $\max(s_1, s_2, ..., s_{t-1}, s_t)$.

```python
    torch.maximum(s_max, s, out=s_max)

    return m, s_max
else:
```

$m_t$ and $s_t$ otherwise

```python
    return m, s
```
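
In the AMSGrad branch, torch.maximum with out= keeps an element-wise running maximum in place; a tiny check with invented values:

```python
import torch

s_max = torch.tensor([0.4, 0.1, 0.3])  # max(s_1, ..., s_{t-1}), invented values
s = torch.tensor([0.2, 0.5, 0.3])      # s_t, invented values

torch.maximum(s_max, s, out=s_max)     # now holds max(s_1, ..., s_t)
assert torch.equal(s_max, torch.tensor([0.4, 0.5, 0.3]))
```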

Take an update step for a given parameter tensor

  • state is the optimizer state of the parameter (tensor)
  • group stores optimizer attributes of the parameter group
  • grad is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
  • param is the parameter tensor $\theta_{t-1}$

```python
def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.nn.Parameter):
```

Calculate weight decay

```python
grad = self.weight_decay(param, grad, group)
```
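
The call above delegates to the WeightDecay helper. As a hedged sketch only (not a copy of the class in __init__.py), such a helper typically implements one of two behaviors, decoupled decay applied directly to the parameter or classic L2 regularization folded into the gradient:

```python
import torch

def weight_decay_sketch(param: torch.Tensor, grad: torch.Tensor,
                        weight_decay: float, lr: float,
                        decoupled: bool = True) -> torch.Tensor:
    # Illustrative only; the actual behavior is configured by the WeightDecay
    # instance and the parameter group settings.
    if decoupled:
        # AdamW-style decoupled decay: shrink the parameter directly and
        # leave the gradient untouched.
        param.data.mul_(1 - lr * weight_decay)
        return grad
    # Classic L2 regularization: add the decay term to the gradient.
    return grad + weight_decay * param.data
```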

Get $m_t$ and $s_t$

```python
m, s = self.get_ms(state, group, grad)
```

Increment $t$, the number of optimizer steps

```python
state['step'] += 1

if not self.rectify:
```

Perform an Adam update, defined in adam.py, with $s_t + \epsilon$ in place of $v_t$.

```python
    self.adam_update(state, group, param, m, s + group['eps'])
else:
```

Perform a Rectified Adam update, defined in radam.py, with $s_t + \epsilon$ in place of $v_t$.

```python
    self.r_adam_update(state, group, param, m, s + group['eps'])
```
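
Tying this back to the update equations at the top, here is a hedged sketch of the non-rectified branch written in plain PyTorch (the function and variable names are my own, not the library's adam_update):

```python
import torch

def adabelief_step_sketch(param: torch.Tensor, m: torch.Tensor, s: torch.Tensor,
                          step: int, lr: float = 1e-3,
                          betas=(0.9, 0.999), eps: float = 1e-16):
    # Plain Adam-style update with s_t + eps in place of v_t:
    #   m_hat = m_t / (1 - beta1^t)
    #   s_hat = (s_t + eps) / (1 - beta2^t)
    #   theta_t = theta_{t-1} - lr * m_hat / (sqrt(s_hat) + eps)
    beta1, beta2 = betas
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    s_hat = (s + eps) / bias_correction2
    denom = s_hat.sqrt() + eps
    param.data.add_(m / denom, alpha=-lr / bias_correction1)

# Invented values, for illustration only
p = torch.nn.Parameter(torch.tensor([1.0, -2.0, 0.5]))
m = torch.tensor([0.05, -0.1, 0.2])
s = torch.tensor([1e-4, 4e-4, 9e-4])
adabelief_step_sketch(p, m, s, step=1)
print(p.data)
```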
