
Adam Optimizer

This is a PyTorch implementation of the popular Adam optimizer from the paper [Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980).

The Adam update is:

$$\begin{aligned}
m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t \\
v_t &\leftarrow \beta_2 v_{t-1} + (1 - \beta_2) \cdot g_t^2 \\
\hat{m}_t &\leftarrow \frac{m_t}{1 - \beta_1^t} \\
\hat{v}_t &\leftarrow \frac{v_t}{1 - \beta_2^t} \\
\theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
\end{aligned}$$

where $\alpha$, $\beta_1$, $\beta_2$ and $\epsilon$ are scalar hyper-parameters. $m_t$ and $v_t$ are the first and second order moments, and $\hat{m}_t$ and $\hat{v}_t$ are the bias-corrected moments. $\epsilon$ is used to avoid a division-by-zero error, but it also acts as a hyper-parameter that works against variance in gradients.
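To make the equations concrete, here is a minimal standalone sketch (my own, not part of the labml source) that transcribes them literally for a single parameter tensor:

```python
import torch

def adam_step(theta, g, m, v, t, alpha=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    # Literal transcription of the update equations above (illustrative sketch).
    m = beta1 * m + (1 - beta1) * g            # first moment m_t
    v = beta2 * v + (1 - beta2) * g * g        # second moment v_t
    m_hat = m / (1 - beta1 ** t)               # bias-corrected m̂_t
    v_hat = v / (1 - beta2 ** t)               # bias-corrected v̂_t
    theta = theta - alpha * m_hat / (v_hat.sqrt() + eps)
    return theta, m, v
```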

The effective step taken, assuming $\epsilon = 0$, is

$$\Delta_t = \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t}}$$

This is bounded by

$$\vert \Delta_t \vert \le \alpha \cdot \frac{1 - \beta_1}{\sqrt{1 - \beta_2}}$$

when $1 - \beta_1 > \sqrt{1 - \beta_2}$, and $\vert \Delta_t \vert \le \alpha$ otherwise. In most common scenarios, $\vert \Delta_t \vert \approx \alpha$.
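For the default hyper-parameters this bound can be evaluated directly; a quick check (my own, illustrative) shows which case applies:

```python
beta1, beta2 = 0.9, 0.999  # the default betas used below

# 1 − β1 = 0.1 and √(1 − β2) ≈ 0.0316, so 1 − β1 > √(1 − β2)
# and the first case of the bound applies.
bound = (1 - beta1) / (1 - beta2) ** 0.5
print(bound)  # ≈ 3.16, i.e. |Δ_t| ≤ 3.16·α in the worst case
```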

```python
import math
from typing import Dict, Any, Tuple, Optional

import torch
from labml import tracker
from torch import nn

from labml_nn.optimizers import GenericAdaptiveOptimizer, WeightDecay
```

Adam Optimizer

We extend the class `GenericAdaptiveOptimizer` defined in `__init__.py` to implement the Adam optimizer.

```python
class Adam(GenericAdaptiveOptimizer):
```

Initialize the optimizer

  • `params` is the list of parameters
  • `lr` is the learning rate $\alpha$
  • `betas` is a tuple of ($\beta_1$, $\beta_2$)
  • `eps` is $\hat{\epsilon}$ or $\epsilon$, based on `optimized_update`
  • `weight_decay` is an instance of class `WeightDecay` defined in `__init__.py`
  • `optimized_update` is a flag whether to optimize the bias correction of the second moment by doing it after adding $\epsilon$
  • `defaults` is a dictionary of defaults for group values. This is useful when you want to extend the class `Adam`.

```python
def __init__(self, params,
             lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-16,
             weight_decay: WeightDecay = WeightDecay(),
             optimized_update: bool = True,
             defaults: Optional[Dict[str, Any]] = None):
```

```python
defaults = {} if defaults is None else defaults
defaults.update(weight_decay.defaults())
super().__init__(params, defaults, lr, betas, eps)

self.weight_decay = weight_decay
self.optimized_update = optimized_update
```
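As a quick illustration (my own, not part of the original source), constructing the optimizer looks like any other PyTorch optimizer; the hyper-parameter values shown are just the signature defaults above:

```python
import torch.nn as nn

# Hypothetical toy model; relies on Adam and WeightDecay as defined in this file.
model = nn.Linear(10, 2)
optimizer = Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-16)
```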

Initialize a parameter state

  • `state` is the optimizer state of the parameter (tensor)
  • `group` stores optimizer attributes of the parameter group
  • `param` is the parameter tensor $\theta_{t-1}$

```python
def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):
```

This is the number of optimizer steps taken on the parameter, $t$

```python
state['step'] = 0
```

Exponential moving average of gradients, $m_t$

```python
state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)
```

Exponential moving average of squared gradient values, $v_t$

```python
state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format)
```
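A small aside (my own illustration, not from the original page): `memory_format=torch.preserve_format` keeps the moment buffers in the same memory layout as the parameter, e.g. for channels-last tensors:

```python
import torch

# A channels-last parameter, e.g. a conv weight moved to channels_last.
p = torch.randn(8, 3, 4, 4).to(memory_format=torch.channels_last)

# preserve_format makes the state buffer match the parameter's layout.
buf = torch.zeros_like(p, memory_format=torch.preserve_format)
assert buf.is_contiguous(memory_format=torch.channels_last)
```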

Calculate $m_t$ and $v_t$

  • `state` is the optimizer state of the parameter (tensor)
  • `group` stores optimizer attributes of the parameter group
  • `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$

```python
def get_mv(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor):
```

Get $\beta_1$ and $\beta_2$

```python
beta1, beta2 = group['betas']
```

Get $m_{t-1}$ and $v_{t-1}$

```python
m, v = state['exp_avg'], state['exp_avg_sq']
```

In-place calculation of $m_t$: $m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t$

```python
m.mul_(beta1).add_(grad, alpha=1 - beta1)
```

In-place calculation of $v_t$: $v_t \leftarrow \beta_2 v_{t-1} + (1 - \beta_2) \cdot g_t^2$

```python
v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

return m, v
```
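To make the in-place tensor operations concrete, here is a small standalone check (my own illustration, not from the source) of what one update does starting from zero-initialized moments:

```python
import torch

g = torch.tensor([0.5, -1.0])          # a gradient g_t
m, v = torch.zeros(2), torch.zeros(2)  # zero-initialized moments
beta1, beta2 = 0.9, 0.999

m.mul_(beta1).add_(g, alpha=1 - beta1)         # m ← β1·m + (1 − β1)·g
v.mul_(beta2).addcmul_(g, g, value=1 - beta2)  # v ← β2·v + (1 − β2)·g²

assert torch.allclose(m, (1 - beta1) * g)
assert torch.allclose(v, (1 - beta2) * g * g)
```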

Get learning rate

This returns the modified learning rate based on the state. For Adam this is just the specified learning rate for the parameter group, $\alpha$.

```python
def get_lr(self, state: Dict[str, Any], group: Dict[str, Any]):
```

```python
return group['lr']
```

Do the Adam parameter update

  • `state` is the optimizer state of the parameter (tensor)
  • `group` stores optimizer attributes of the parameter group
  • `param` is the parameter tensor $\theta_{t-1}$
  • `m` and `v` are the uncorrected first and second moments $m_t$ and $v_t$

This computes the following update:

$$\theta_t \leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$

Since $\alpha$, $\beta_1$, $\beta_2$ and $\epsilon$ are scalars and the others are tensors, we modify this calculation to optimize the computation:

$$\begin{aligned}
\theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} \\
&\leftarrow \theta_{t-1} - \alpha \cdot \frac{m_t / (1 - \beta_1^t)}{\sqrt{v_t / (1 - \beta_2^t)} + \epsilon} \\
&\leftarrow \theta_{t-1} - \alpha \frac{\sqrt{1 - \beta_2^t}}{1 - \beta_1^t} \cdot \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}}
\end{aligned}$$

where $\hat{\epsilon} = \sqrt{1 - \beta_2^t} \, \epsilon$ is what we should specify as the hyper-parameter.
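The last step follows from a one-line algebraic identity (spelled out here for clarity; it is implicit in the original):

$$\sqrt{\frac{v_t}{1 - \beta_2^t}} + \epsilon = \frac{\sqrt{v_t} + \sqrt{1 - \beta_2^t} \, \epsilon}{\sqrt{1 - \beta_2^t}}$$

so dividing $\hat{m}_t = m_t / (1 - \beta_1^t)$ by this moves the factor $\sqrt{1 - \beta_2^t} / (1 - \beta_1^t)$ out in front and replaces $\epsilon$ with $\hat{\epsilon} = \sqrt{1 - \beta_2^t} \, \epsilon$.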

```python
def adam_update(self, state: Dict[str, Any], group: Dict[str, Any], param: torch.nn.Parameter,
                m: torch.Tensor, v: torch.Tensor):
```

Get $\beta_1$ and $\beta_2$

```python
beta1, beta2 = group['betas']
```

Bias correction term for $\hat{m}_t$, $1 - \beta_1^t$

```python
bias_correction1 = 1 - beta1 ** state['step']
```

Bias correction term for $\hat{v}_t$, $1 - \beta_2^t$

```python
bias_correction2 = 1 - beta2 ** state['step']
```

Get learning rate

```python
lr = self.get_lr(state, group)
```

Whether to optimize the computation

```python
if self.optimized_update:
```

$\sqrt{v_t} + \hat{\epsilon}$

```python
denominator = v.sqrt().add_(group['eps'])
```

$\alpha \dfrac{\sqrt{1 - \beta_2^t}}{1 - \beta_1^t}$

```python
step_size = lr * math.sqrt(bias_correction2) / bias_correction1
```

$\theta_t \leftarrow \theta_{t-1} - \alpha \dfrac{\sqrt{1 - \beta_2^t}}{1 - \beta_1^t} \cdot \dfrac{m_t}{\sqrt{v_t} + \hat{\epsilon}}$

```python
param.data.addcdiv_(m, denominator, value=-step_size)
```
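`addcdiv_` performs the division and scaled addition as one fused in-place operation. A quick standalone illustration (mine, not from the source):

```python
import torch

p = torch.zeros(2)
m = torch.tensor([1.0, 2.0])
d = torch.tensor([2.0, 4.0])

p.addcdiv_(m, d, value=-0.1)  # p ← p + (−0.1)·m/d
assert torch.allclose(p, torch.tensor([-0.05, -0.05]))
```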

Computation without optimization

```python
else:
```

$\dfrac{\sqrt{v_t}}{\sqrt{1 - \beta_2^t}} + \epsilon$

```python
denominator = (v.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
```

$\dfrac{\alpha}{1 - \beta_1^t}$

```python
step_size = lr / bias_correction1
```

$\theta_t \leftarrow \theta_{t-1} - \alpha \cdot \dfrac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$

```python
param.data.addcdiv_(m, denominator, value=-step_size)
```
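The two branches compute the same update when `eps` is rescaled accordingly. A quick standalone check (my own sketch, not from the source):

```python
import math
import torch

m, v = torch.tensor([0.1, -0.2]), torch.tensor([0.04, 0.09])
lr, step, beta1, beta2, eps = 1e-3, 10, 0.9, 0.999, 1e-8

bc1 = 1 - beta1 ** step
bc2 = 1 - beta2 ** step

# Optimized path, with eps interpreted as ε̂ = √(1 − β2^t)·ε.
eps_hat = math.sqrt(bc2) * eps
opt = lr * math.sqrt(bc2) / bc1 * m / (v.sqrt() + eps_hat)

# Unoptimized path, with plain ε.
plain = lr / bc1 * m / (v.sqrt() / math.sqrt(bc2) + eps)

assert torch.allclose(opt, plain)
```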

Take an update step for a given parameter tensor

  • `state` is the optimizer state of the parameter (tensor)
  • `group` stores optimizer attributes of the parameter group
  • `grad` is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
  • `param` is the parameter tensor $\theta_{t-1}$

```python
def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.nn.Parameter):
```

Calculate weight decay

```python
grad = self.weight_decay(param, grad, group)
```

Get $m_t$ and $v_t$

```python
m, v = self.get_mv(state, group, grad)
```

Increment $t$, the number of optimizer steps

```python
state['step'] += 1
```

Perform Adam update

```python
self.adam_update(state, group, param, m, v)
```
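Finally, a minimal end-to-end sketch (my own, with a made-up model and data) of how this optimizer plugs into a standard training loop, assuming `GenericAdaptiveOptimizer.step()` invokes `step_param` for each parameter as in the base class:

```python
import torch
import torch.nn as nn

# Hypothetical toy setup; any model and loss would do.
model = nn.Linear(4, 1)
optimizer = Adam(model.parameters(), lr=1e-3)

x, y = torch.randn(32, 4), torch.randn(32, 1)
for _ in range(100):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()  # calls step_param for each parameter
```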
