Adam Optimizer
This is a PyTorch implementation of the popular Adam optimizer from the paper Adam: A Method for Stochastic Optimization.
The Adam update is:
$$\begin{aligned}
m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t \\
v_t &\leftarrow \beta_2 v_{t-1} + (1 - \beta_2) \cdot g_t^2 \\
\hat{m}_t &\leftarrow \frac{m_t}{1 - \beta_1^t} \\
\hat{v}_t &\leftarrow \frac{v_t}{1 - \beta_2^t} \\
\theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
\end{aligned}$$
where $\alpha$, $\beta_1$, $\beta_2$ and $\epsilon$ are scalar hyper-parameters. $m_t$ and $v_t$ are the first and second order moments, and $\hat{m}_t$ and $\hat{v}_t$ are the bias-corrected moments. $\epsilon$ is used as a fix for division by zero, but it also acts as a hyper-parameter that works against variance in gradients.
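As a concrete reference, the short sketch below applies these equations directly to a single tensor with plain PyTorch operations; the names (`theta`, `g`) and hyper-parameter values are illustrative only, not taken from the implementation documented below.

```python
import torch

# Illustrative hyper-parameters
alpha, beta1, beta2, eps = 1e-3, 0.9, 0.999, 1e-8

theta = torch.randn(4)            # parameter θ
m = torch.zeros_like(theta)       # first moment m_0
v = torch.zeros_like(theta)       # second moment v_0

for t in range(1, 11):
    g = torch.randn_like(theta)                 # stand-in for a gradient g_t
    m = beta1 * m + (1 - beta1) * g             # m_t
    v = beta2 * v + (1 - beta2) * g ** 2        # v_t
    m_hat = m / (1 - beta1 ** t)                # bias-corrected m̂_t
    v_hat = v / (1 - beta2 ** t)                # bias-corrected v̂_t
    theta = theta - alpha * m_hat / (v_hat.sqrt() + eps)
```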
The effective step taken, assuming $\epsilon = 0$, is
$$\Delta_t = \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t}}$$
This is bounded by
$$\vert \Delta_t \vert \le \alpha \cdot \frac{1 - \beta_1}{\sqrt{1 - \beta_2}}$$
when $1 - \beta_1 > \sqrt{1 - \beta_2}$, and $\vert \Delta_t \vert \le \alpha$ otherwise. In most common scenarios, $\vert \Delta_t \vert \approx \alpha$.
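As a quick sanity check of this bound (illustrative numbers only): with the common defaults $\beta_1 = 0.9$ and $\beta_2 = 0.999$ we have $1 - \beta_1 = 0.1 > \sqrt{1 - \beta_2} \approx 0.0316$, so the first case applies.

```python
import math

alpha, beta1, beta2 = 1e-3, 0.9, 0.999
# Worst-case step magnitude under the first bound
bound = alpha * (1 - beta1) / math.sqrt(1 - beta2)
print(bound)  # ≈ 3.16e-3, i.e. at most a few multiples of α
```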
```python
import math
from typing import Dict, Any, Tuple, Optional

import torch
from labml import tracker
from torch import nn

from labml_nn.optimizers import GenericAdaptiveOptimizer, WeightDecay
```
We extend the class GenericAdaptiveOptimizer defined in __init__.py to implement the Adam optimizer.
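The sketch below shows the contract we assume that base class provides: a `step()` loop that lazily initializes per-parameter state with `init_state` and delegates the actual update rule to `step_param`. This is only an orientation sketch, not the actual `labml_nn` source.

```python
import torch

class GenericAdaptiveOptimizerSketch(torch.optim.Optimizer):
    """Assumed shape of the base class (a sketch, not the labml_nn code)."""

    @torch.no_grad()
    def step(self, closure=None):
        # Simplified: closure handling omitted
        for group in self.param_groups:
            for param in group['params']:
                if param.grad is None:
                    continue
                state = self.state[param]
                if len(state) == 0:
                    # Lazily create per-parameter state on the first step
                    self.init_state(state, group, param)
                # Delegate the update rule to the subclass (e.g. Adam)
                self.step_param(state, group, param.grad, param)

    def init_state(self, state, group, param):
        raise NotImplementedError

    def step_param(self, state, group, grad, param):
        raise NotImplementedError
```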
```python
class Adam(GenericAdaptiveOptimizer):
```
* params is the list of parameters
* lr is the learning rate $\alpha$
* betas is a tuple of ($\beta_1$, $\beta_2$)
* eps is $\hat{\epsilon}$ or $\epsilon$ based on optimized_update
* weight_decay is an instance of class WeightDecay defined in __init__.py
* optimized_update is a flag whether to optimize the bias correction of the second moment by doing it after adding $\epsilon$
* defaults is a dictionary of default values for parameter groups. This is useful when you want to extend the class Adam.

```python
    def __init__(self, params,
                 lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-16,
                 weight_decay: WeightDecay = WeightDecay(),
                 optimized_update: bool = True,
                 defaults: Optional[Dict[str, Any]] = None):
```
```python
        defaults = {} if defaults is None else defaults
        defaults.update(weight_decay.defaults())
        super().__init__(params, defaults, lr, betas, eps)

        self.weight_decay = weight_decay
        self.optimized_update = optimized_update
```
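Construction then looks like any other PyTorch optimizer. The model below is a placeholder, and the import path for Adam is assumed from this file's location:

```python
import torch
from torch import nn
from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.adam import Adam  # module path assumed from this file

model = nn.Linear(10, 2)                       # placeholder model
optimizer = Adam(model.parameters(), lr=1e-3,
                 betas=(0.9, 0.999), eps=1e-16,
                 weight_decay=WeightDecay(),   # defaults, as in the signature above
                 optimized_update=True)
```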
* state is the optimizer state of the parameter (tensor)
* group stores optimizer attributes of the parameter group
* param is the parameter tensor $\theta_{t-1}$

```python
    def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):
```
This is the number of optimizer steps taken on the parameter, $t$
```python
        state['step'] = 0
```
Exponential moving average of gradients, $m_t$
```python
        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)
```
Exponential moving average of squared gradient values, $v_t$
```python
        state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format)
```
* state is the optimizer state of the parameter (tensor)
* group stores optimizer attributes of the parameter group
* grad is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$

```python
    def get_mv(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor):
```
Get $\beta_1$ and $\beta_2$
```python
        beta1, beta2 = group['betas']
```
Get $m_{t-1}$ and $v_{t-1}$
```python
        m, v = state['exp_avg'], state['exp_avg_sq']
```
In-place calculation of $m_t$:
$$m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t$$
```python
        m.mul_(beta1).add_(grad, alpha=1 - beta1)
```
In-place calculation of $v_t$:
$$v_t \leftarrow \beta_2 v_{t-1} + (1 - \beta_2) \cdot g_t^2$$
```python
        v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        return m, v
```
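The fused in-place calls are just the moment equations above written with PyTorch's in-place ops; here is a small equivalence check with made-up tensors:

```python
import torch

beta1, beta2 = 0.9, 0.999
m = torch.randn(3)
v = torch.rand(3)
grad = torch.randn(3)

m_ref = beta1 * m + (1 - beta1) * grad               # m_t from the formula
v_ref = beta2 * v + (1 - beta2) * grad ** 2          # v_t from the formula

m.mul_(beta1).add_(grad, alpha=1 - beta1)            # in-place m_t
v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # in-place v_t

assert torch.allclose(m, m_ref) and torch.allclose(v, v_ref)
```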
This returns the modified learning rate based on the state. For Adam, this is just the specified learning rate for the parameter group, $\alpha$.
```python
    def get_lr(self, state: Dict[str, any], group: Dict[str, any]):
        return group['lr']
```
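Keeping the learning rate behind a method makes it easy for subclasses to schedule it. For example, a hypothetical warm-up variant (a sketch, not part of the library) could override it like this:

```python
class AdamWarmup(Adam):
    """Hypothetical subclass that scales α linearly over the first `warmup` steps."""

    def __init__(self, params, warmup: int = 1_000, **kwargs):
        self.warmup = warmup
        super().__init__(params, **kwargs)

    def get_lr(self, state, group):
        # state['step'] has already been incremented when this is called
        factor = min(1.0, state['step'] / self.warmup)
        return group['lr'] * factor
```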
* state is the optimizer state of the parameter (tensor)
* group stores optimizer attributes of the parameter group
* param is the parameter tensor $\theta_{t-1}$
* m and v are the uncorrected first and second moments $m_t$ and $v_t$

This computes the following:
$$\theta_t \leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$
Since $\alpha$, $\beta_1$, $\beta_2$ and $\epsilon$ are scalars and the others are tensors, we rearrange this calculation to optimize the computation:
$$\begin{aligned}
\theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} \\
\theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{m_t / (1 - \beta_1^t)}{\sqrt{v_t / (1 - \beta_2^t)} + \epsilon} \\
\theta_t &\leftarrow \theta_{t-1} - \alpha \frac{\sqrt{1 - \beta_2^t}}{1 - \beta_1^t} \cdot \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}}
\end{aligned}$$
where $\hat{\epsilon} = \sqrt{1 - \beta_2^t}\,\epsilon$ is what we should specify as the hyper-parameter.
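A small numeric check of this rearrangement with arbitrary values: when $\hat{\epsilon} = \sqrt{1 - \beta_2^t}\,\epsilon$, the plain and rearranged forms give the same step.

```python
import math
import torch

alpha, beta1, beta2, eps = 1e-3, 0.9, 0.999, 1e-8
t = 10
m = torch.randn(3)
v = torch.rand(3)

# Plain form: α · m̂_t / (sqrt(v̂_t) + ε)
m_hat = m / (1 - beta1 ** t)
v_hat = v / (1 - beta2 ** t)
plain = alpha * m_hat / (v_hat.sqrt() + eps)

# Rearranged form with ε̂ = sqrt(1 - β2^t) · ε
eps_hat = math.sqrt(1 - beta2 ** t) * eps
rearranged = (alpha * math.sqrt(1 - beta2 ** t) / (1 - beta1 ** t)
              * m / (v.sqrt() + eps_hat))

assert torch.allclose(plain, rearranged)
```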
```python
    def adam_update(self, state: Dict[str, any], group: Dict[str, any], param: torch.nn.Parameter,
                    m: torch.Tensor, v: torch.Tensor):
```
Get $\beta_1$ and $\beta_2$
```python
        beta1, beta2 = group['betas']
```
Bias correction term for $\hat{m}_t$, $1 - \beta_1^t$
```python
        bias_correction1 = 1 - beta1 ** state['step']
```
Bias correction term for $\hat{v}_t$, $1 - \beta_2^t$
```python
        bias_correction2 = 1 - beta2 ** state['step']
```
Get learning rate
```python
        lr = self.get_lr(state, group)
```
Whether to optimize the computation
```python
        if self.optimized_update:
```
$\sqrt{v_t} + \hat{\epsilon}$
```python
            denominator = v.sqrt().add_(group['eps'])
```
$\alpha \dfrac{\sqrt{1 - \beta_2^t}}{1 - \beta_1^t}$
```python
            step_size = lr * math.sqrt(bias_correction2) / bias_correction1
```
$$\theta_t \leftarrow \theta_{t-1} - \alpha \frac{\sqrt{1 - \beta_2^t}}{1 - \beta_1^t} \cdot \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}}$$
```python
            param.data.addcdiv_(m, denominator, value=-step_size)
```
Computation without optimization
```python
        else:
```
$\dfrac{\sqrt{v_t}}{\sqrt{1 - \beta_2^t}} + \epsilon$
```python
            denominator = (v.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
```
$\dfrac{\alpha}{1 - \beta_1^t}$
```python
            step_size = lr / bias_correction1
```
$$\theta_t \leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$
```python
            param.data.addcdiv_(m, denominator, value=-step_size)
```
* state is the optimizer state of the parameter (tensor)
* group stores optimizer attributes of the parameter group
* grad is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
* param is the parameter tensor $\theta_{t-1}$

```python
    def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
```
Calculate weight decay
```python
        grad = self.weight_decay(param, grad, group)
```
Get $m_t$ and $v_t$
```python
        m, v = self.get_mv(state, group, grad)
```
Increment $t$, the number of optimizer steps
```python
        state['step'] += 1
```
Perform Adam update
```python
        self.adam_update(state, group, param, m, v)
```
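Putting it together, a minimal training step with this optimizer looks like any other PyTorch optimizer; the model and batch below are placeholders, and the import path is assumed from this file's location.

```python
import torch
from torch import nn
from labml_nn.optimizers.adam import Adam  # module path assumed from this file

model = nn.Linear(8, 1)
optimizer = Adam(model.parameters(), lr=1e-3)

x, y = torch.randn(32, 8), torch.randn(32, 1)   # dummy batch
loss = nn.functional.mse_loss(model(x), y)

optimizer.zero_grad()
loss.backward()
optimizer.step()  # calls step_param (and hence adam_update) for each parameter
```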