This is based on the official AdaBelief implementation of the paper AdaBelief Optimizer: Adapting Stepsizes by the Belief in Observed Gradients.
This is implemented in PyTorch as an extension to RAdam.
The main difference between the Adam optimizer and AdaBelief is how the adaptive learning rate is calculated: instead of dividing by the exponential moving average of the squared gradients, AdaBelief divides by the exponential moving average of the variance.
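$$
\begin{aligned}
m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t \\
s_t &\leftarrow \beta_2 s_{t-1} + (1 - \beta_2) \cdot (g_t - m_t)^2 \\
\hat{m}_t &\leftarrow \frac{m_t}{1 - \beta_1^t} \\
\hat{s}_t &\leftarrow \frac{s_t + \epsilon}{1 - \beta_2^t} \\
\theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{s}_t} + \epsilon}
\end{aligned}
$$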
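As a rough illustration of the update rule, here is a minimal sketch for a single scalar parameter in plain Python. The function name `adabelief_step` and its argument layout are made up for this example and are not part of the implementation on this page; the default values are assumptions chosen to match the constructor defaults.

```python
import math


def adabelief_step(theta, grad, m, s, t, lr=1e-3, betas=(0.9, 0.999), eps=1e-16):
    """One AdaBelief update for a scalar parameter (hypothetical helper).

    `t` is the 1-based step count; `m` and `s` are the averages from step t - 1.
    """
    beta1, beta2 = betas
    # m_t <- beta1 * m_{t-1} + (1 - beta1) * g_t
    m = beta1 * m + (1 - beta1) * grad
    # s_t <- beta2 * s_{t-1} + (1 - beta2) * (g_t - m_t)^2
    s = beta2 * s + (1 - beta2) * (grad - m) ** 2
    # Bias-corrected estimates
    m_hat = m / (1 - beta1 ** t)
    s_hat = (s + eps) / (1 - beta2 ** t)
    # theta_t <- theta_{t-1} - lr * m_hat / (sqrt(s_hat) + eps)
    theta = theta - lr * m_hat / (math.sqrt(s_hat) + eps)
    return theta, m, s


# First step from zero-initialized averages with a gradient of 1.0
theta, m, s = adabelief_step(theta=1.0, grad=1.0, m=0.0, s=0.0, t=1)
print(theta, m, s)
```

With these assumed defaults, the first step has a magnitude of roughly `lr / 0.9` whatever the gradient scale, because $\hat{m}_1 = g_1$ and $\hat{s}_1 \approx 0.81 \, g_1^2$.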
🤔 The paper calculates variance as $(g_t - m_t)^2$, but I feel it should use the bias-corrected momentum, $(g_t - \hat{m}_t)^2$. I guess this doesn't affect things much because the bias correction is $\approx 1$ after the initial training steps.
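To get a feel for how quickly the bias correction stops mattering, the sketch below prints the factor $1 - \beta_1^t$ for a few step counts. It assumes $\beta_1 = 0.9$; the helper `bias_correction` is only for illustration.

```python
def bias_correction(beta, t):
    """Return 1 - beta^t, the factor m_t is divided by to obtain m_hat_t."""
    return 1 - beta ** t


beta1 = 0.9  # assumed value for illustration
factors = {t: bias_correction(beta1, t) for t in (1, 10, 50, 100)}
for t, c in factors.items():
    print(f"t={t:>3}  1 - beta1^t = {c:.5f}")
```

Under this assumption the factor is already above 0.99 by step 50, so $m_t$ and $\hat{m}_t$ are nearly identical from then on.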
    from typing import Dict, Any

    import torch
    from torch import nn

    from labml_nn.optimizers import WeightDecay
    from labml_nn.optimizers.radam import RAdam
This class extends the RAdam optimizer defined in `radam.py`.

    class AdaBelief(RAdam):
- `params` is the list of parameters
- `lr` is the learning rate $\alpha$
- `betas` is a tuple of ($\beta_1$, $\beta_2$)
- `eps` is $\hat{\epsilon}$ or $\epsilon$ based on `optimized_update`
- `weight_decay` is an instance of class `WeightDecay` defined in `__init__.py`
- `optimized_update` is a flag indicating whether to optimize the bias correction of the second moment by doing it after adding $\epsilon$
- `amsgrad` is a flag indicating whether to use AMSGrad or fall back to plain Adam
- `degenerate_to_sgd` is whether to use SGD when the rectification term $r_t$ is intractable
- `rectify` is whether to use the RAdam update
- `defaults` is a dictionary of default values for parameter groups. This is useful when you want to extend the class `AdaBelief`.

        def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
                     weight_decay: WeightDecay = WeightDecay(), amsgrad=False,
                     degenerate_to_sgd=True,
                     rectify=True, defaults=None):
            defaults = {} if defaults is None else defaults
            super().__init__(params, lr, betas, eps, weight_decay, amsgrad, degenerate_to_sgd, defaults)
            self.rectify = rectify
- `state` is the optimizer state of the parameter (tensor)
- `group` stores optimizer attributes of the parameter group
- `param` is the parameter tensor $\theta_{t-1}$

        def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):
            state['step'] = 0
Exponential moving average of gradient values

            state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)

Exponential moving average of variance

            state['exp_avg_var'] = torch.zeros_like(param, memory_format=torch.preserve_format)

If the `amsgrad` flag is `True` for this parameter group, we maintain the maximum of the exponential moving average of variance

            if group['amsgrad']:

Maintains the maximum of all exponential moving averages of variance

                state['max_exp_avg_var'] = torch.zeros_like(param, memory_format=torch.preserve_format)
- `state` is the optimizer state of the parameter (tensor)
- `group` stores optimizer attributes of the parameter group
- `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$

        def get_ms(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor):
Get $\beta_1$ and $\beta_2$

            beta1, beta2 = group['betas']

Get $m_{t-1}$ and $s_{t-1}$

            m, s = state['exp_avg'], state['exp_avg_var']

In-place calculation of $m_t$: $m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t$

            m.mul_(beta1).add_(grad, alpha=1 - beta1)

Difference between gradient and momentum

            grad_residual = grad - m

In-place calculation of $s_t$: $s_t \leftarrow \beta_2 s_{t-1} + (1 - \beta_2) \cdot (g_t - m_t)^2$

            s.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1 - beta2)
If this parameter group is using `amsgrad`

            if group['amsgrad']:

Get $\max(s_1, s_2, ..., s_{t-1})$.

                s_max = state['max_exp_avg_var']

Calculate $\max(s_1, s_2, ..., s_{t-1}, s_t)$.

                torch.maximum(s_max, s, out=s_max)

                return m, s_max
            else:

$m_t$ and $s_t$ otherwise

                return m, s
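The same bookkeeping can be sketched without tensors. The version below uses plain Python lists in place of the state tensors; `get_ms_sketch` and its arguments are hypothetical stand-ins for illustration, not the method above.

```python
def get_ms_sketch(state, grad, betas=(0.9, 0.999), amsgrad=False):
    """Update the running averages in place and return (m, s) or (m, s_max).

    `state` is a dict of plain lists standing in for the optimizer's tensors.
    """
    beta1, beta2 = betas
    m, s = state["exp_avg"], state["exp_avg_var"]
    for i, g in enumerate(grad):
        # m_t <- beta1 * m_{t-1} + (1 - beta1) * g_t
        m[i] = beta1 * m[i] + (1 - beta1) * g
        # s_t <- beta2 * s_{t-1} + (1 - beta2) * (g_t - m_t)^2
        residual = g - m[i]
        s[i] = beta2 * s[i] + (1 - beta2) * residual * residual
    if amsgrad:
        # Keep the element-wise maximum of every s seen so far
        s_max = state["max_exp_avg_var"]
        for i in range(len(s)):
            s_max[i] = max(s_max[i], s[i])
        return m, s_max
    return m, s


state = {"exp_avg": [0.0], "exp_avg_var": [0.0], "max_exp_avg_var": [0.0]}
get_ms_sketch(state, [1.0], amsgrad=True)             # large residual, s grows
m, s_max = get_ms_sketch(state, [0.1], amsgrad=True)  # gradient matches m, s decays
print(state["exp_avg_var"], s_max)
```

In the second call the gradient equals the running mean, so $s_t$ decays slightly while the AMSGrad maximum keeps the earlier, larger value.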
- `state` is the optimizer state of the parameter (tensor)
- `group` stores optimizer attributes of the parameter group
- `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
- `param` is the parameter tensor $\theta_{t-1}$

        def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.nn.Parameter):
Calculate weight decay

            grad = self.weight_decay(param, grad, group)

Get $m_t$ and $s_t$

            m, s = self.get_ms(state, group, grad)

Increment $t$, the number of optimizer steps

            state['step'] += 1

            if not self.rectify:
Perform the Adam update, defined in `adam.py`, with $s_t + \epsilon$ in place of $v_t$.

                self.adam_update(state, group, param, m, s + group['eps'])
            else:

Perform the Rectified Adam update, defined in `radam.py`, with $s_t + \epsilon$ in place of $v_t$.

                self.r_adam_update(state, group, param, m, s + group['eps'])
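The reuse of the Adam update can be illustrated with a scalar sketch. `adam_style_update` below is a hypothetical, simplified stand-in and not the `adam_update` method from `adam.py`, whose details (for example the optimized update path) may differ. It assumes the plain form $\theta_t \leftarrow \theta_{t-1} - \alpha \cdot \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon)$.

```python
import math


def adam_style_update(theta, m, v, step, lr=1e-3, betas=(0.9, 0.999), eps=1e-16):
    """Generic Adam-style update for a scalar (hypothetical helper)."""
    beta1, beta2 = betas
    # Bias-correct both moment estimates
    m_hat = m / (1 - beta1 ** step)
    v_hat = v / (1 - beta2 ** step)
    return theta - lr * m_hat / (math.sqrt(v_hat) + eps)


eps = 1e-16
m, s = 0.1, 0.00081  # running averages after one step with a gradient of 1.0
# AdaBelief passes s + eps where Adam would pass v
theta = adam_style_update(1.0, m, s + eps, step=1, eps=eps)
print(theta)
```

Passing `s + eps` as the second moment is what turns a generic Adam-style update into the AdaBelief rule, so no separate update routine is needed.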