
Adam Optimizer with Warmup and Cosine Decay


This extends the AMSGrad optimizer and adds a warmup stage followed by a cosine decay of the learning rate.

import math
from typing import Dict

from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.amsgrad import AMSGrad


Adam Optimizer with Warmup and Cosine Decay

This class extends the AMSGrad optimizer defined in amsgrad.py.

class AdamWarmupCosineDecay(AMSGrad):


Initialize the optimizer

  • params is the list of parameters
  • lr is the learning rate α
  • betas is a tuple of (β₁, β₂)
  • eps is ϵ̂ or ϵ based on optimized_update
  • weight_decay is an instance of class WeightDecay defined in __init__.py
  • optimized_update is a flag indicating whether to optimize the bias correction of the second moment by doing it after adding ϵ
  • amsgrad is a flag indicating whether to use AMSGrad or fall back to plain Adam
  • warmup is the number of warmup steps
  • total_steps is the total number of steps. The cosine decay would reach 0 at this point, but the learning rate is clamped at 10% of lr because we take α · max(0.1, decay)
  • defaults is a dictionary of defaults for group values. This is useful when you want to extend the class AdamWarmupCosineDecay.
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
                 weight_decay: WeightDecay = WeightDecay(),
                 optimized_update: bool = True,
                 amsgrad=False, warmup=0, total_steps=1e10, defaults=None):


        defaults = {} if defaults is None else defaults
        defaults.update(dict(warmup=warmup, total_steps=total_steps))
        super().__init__(params, lr, betas, eps, weight_decay, optimized_update, amsgrad, defaults)
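As a quick illustration of the constructor arguments above, here is a minimal usage sketch. It is not part of the original implementation: the layer sizes and hyper-parameter values are arbitrary, and it assumes the optimizer can be used like any other PyTorch optimizer.

import torch
from torch import nn

model = nn.Linear(16, 4)
# Peak learning rate 2.5e-4, reached after 4,000 linear warmup steps,
# then cosine-decayed over a horizon of 100,000 total steps (illustrative values)
optimizer = AdamWarmupCosineDecay(model.parameters(), lr=2.5e-4,
                                  warmup=4_000, total_steps=100_000)

x, y = torch.randn(8, 16), torch.randn(8, 4)
loss = nn.functional.mse_loss(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()  # one Adam/AMSGrad update at the scheduled learning rate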


Get learning-rate

α · min(1, t/w) during the warmup stage, where w is the number of warmup steps, followed by a cosine decay of the learning rate that is clamped at 0.1 α.
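Putting both stages together, the full schedule can be written as the following sketch (reconstructed from the get_lr code below, with w = warmup and T = total_steps):

$$
\alpha_t =
\begin{cases}
\alpha \cdot \dfrac{t}{w} & \text{if } t < w \\
\alpha \cdot \max\left(0.1,\ \dfrac{1}{2}\left(1 + \cos\left(\pi \cdot \dfrac{t - w}{T - w}\right)\right)\right) & \text{otherwise}
\end{cases}
$$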

    def get_lr(self, state: Dict[str, any], group: Dict[str, any]):


If we are in the warmup stage

        if group['warmup'] > state['step']:


A linearly increasing learning rate from 0 to α

            return 1e-8 + state['step'] * group['lr'] / group['warmup']
        else:


Otherwise, cosine decay of the learning rate, clamped at 10% of α

            progress = (state['step'] - group['warmup']) / max(1, group['total_steps'] - group['warmup'])
            return group['lr'] * max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))
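To make the schedule concrete, here is a small self-contained sketch (not from the original source) that mirrors the arithmetic of get_lr for one illustrative configuration: α = 1e-4, 5,000 warmup steps and 1,000,000 total steps.

import math

lr, warmup, total_steps = 1e-4, 5_000, 1_000_000

def scheduled_lr(step: int) -> float:
    # Linear warmup from ~0 up to lr over the first `warmup` steps
    if step < warmup:
        return 1e-8 + step * lr / warmup
    # Cosine decay from lr towards 0, clamped at 10% of lr
    progress = (step - warmup) / max(1, total_steps - warmup)
    return lr * max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))

for step in (2_500, 5_000, 502_500, 1_000_000):
    print(step, scheduled_lr(step))
# 2,500     -> ~5e-5 (half-way through warmup)
# 5,000     -> 1e-4  (peak, warmup finished)
# 502,500   -> ~5e-5 (half-way through the cosine decay)
# 1,000,000 -> 1e-5  (the 10% floor)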


Plot the learning rate schedule for a sample configuration

def _test_lr():


    import matplotlib.pyplot as plt
    import numpy as np
    from torch import nn

    model = nn.Linear(10, 10)
    opt = AdamWarmupCosineDecay(model.parameters(), warmup=5000, lr=1e-4, total_steps=4e6)

    # Zoom into the warmup region (first 20,000 steps)
    steps = 20_000
    plt.plot(np.arange(1, steps), [opt.get_lr({'step': i}, opt.defaults) for i in range(1, steps)])
    plt.legend(["5000:4e6"])
    plt.title("Learning Rate")
    plt.show()

    # Full schedule, sampled every 1,000 steps
    steps = int(6e6)
    step_size = 1000
    plt.plot(np.arange(1, steps, step_size), [opt.get_lr({'step': i}, opt.defaults) for i in range(1, steps, step_size)])
    plt.legend(["5000:4e6"])
    plt.title("Learning Rate")
    plt.show()


if __name__ == '__main__':
    _test_lr()
