
Adam Optimizer with Warmup

This extends the AMSGrad optimizer and adds a warmup stage.

from typing import Dict

from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.amsgrad import AMSGrad

Adam Optimizer with Warmup

This class extends the AMSGrad optimizer defined in amsgrad.py.

class AdamWarmup(AMSGrad):

Initialize the optimizer

  • params is the list of parameters
  • lr is the learning rate α
  • betas is a tuple of (β₁, β₂)
  • eps is ϵ̂ or ϵ based on optimized_update
  • weight_decay is an instance of class WeightDecay defined in __init__.py
  • optimized_update is a flag indicating whether to optimize the bias correction of the second moment by doing it after adding ϵ
  • amsgrad is a flag indicating whether to use AMSGrad or fall back to plain Adam
  • warmup is the number of warmup steps
  • defaults is a dictionary of default values for parameter groups. This is useful when you want to extend the class AdamWarmup.
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
                 weight_decay: WeightDecay = WeightDecay(),
                 optimized_update: bool = True,
                 amsgrad=False, warmup=0, defaults=None):

        defaults = {} if defaults is None else defaults
        defaults.update(dict(warmup=warmup))
        super().__init__(params, lr, betas, eps, weight_decay, optimized_update, amsgrad, defaults)
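
A minimal usage sketch (not part of the annotated source), assuming a PyTorch model and assuming the class is importable from labml_nn.optimizers.adam_warmup (module path inferred from this page's location):

import torch.nn as nn

from labml_nn.optimizers.adam_warmup import AdamWarmup

# Any model will do; a toy linear layer keeps the sketch self-contained.
model = nn.Linear(16, 4)

# The learning rate ramps up linearly over the first 2,000 optimizer steps,
# then stays constant at lr.
optimizer = AdamWarmup(model.parameters(), lr=1e-3, betas=(0.9, 0.999), warmup=2_000)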

Get learning rate

α min(1, t/w) where t is the current step and w is the number of warmup steps.

    def get_lr(self, state: Dict[str, any], group: Dict[str, any]):

If we are in the warmup stage

        if group['warmup'] > state['step']:

A linearly increasing learning rate from 0 to α

            return 1e-8 + state['step'] * group['lr'] / group['warmup']
        else:

Constant learning rate α

            return group['lr']
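
As a rough standalone check of the schedule (a sketch, not library code), the rule above can be reproduced and evaluated for a few steps; with α = 1e-3 and w = 1000 the rate is roughly 2.5e-4 at step 250 and stays at 1e-3 from step 1000 onwards:

def warmup_lr(step: int, lr: float = 1e-3, warmup: int = 1000) -> float:
    # Mirrors get_lr above: linear ramp during warmup, constant afterwards.
    if warmup > step:
        return 1e-8 + step * lr / warmup
    return lr

for step in (0, 250, 500, 1000, 5000):
    print(step, warmup_lr(step))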
