
Noam Optimizer


This is a PyTorch implementation of the optimizer introduced in the paper Attention Is All You Need.

```python
from typing import Dict

from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.amsgrad import AMSGrad
```

Noam Optimizer

This class extends the AMSGrad optimizer defined in amsgrad.py (which in turn builds on the Adam optimizer in adam.py).

```python
class Noam(AMSGrad):
```


Initialize the optimizer

  • `params` is the list of parameters
  • `lr` is the learning rate α
  • `betas` is a tuple of (β₁, β₂)
  • `eps` is ϵ̂ or ϵ, based on `optimized_update`
  • `weight_decay` is an instance of class WeightDecay defined in __init__.py
  • `optimized_update` is a flag whether to optimize the bias correction of the second moment by doing it after adding ϵ
  • `amsgrad` is a flag indicating whether to use AMSGrad or fall back to plain Adam
  • `warmup` is the number of warmup steps
  • `d_model` is the model size; i.e. the number of dimensions in the transformer
  • `defaults` is a dictionary of defaults for group values. This is useful when you want to extend the class Noam.
```python
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
                 weight_decay: WeightDecay = WeightDecay(),
                 optimized_update: bool = True,
                 amsgrad=False,
                 warmup=0, d_model=512, defaults=None):
```

```python
        defaults = {} if defaults is None else defaults
        defaults.update(dict(warmup=warmup))
        super().__init__(params, lr, betas, eps, weight_decay, optimized_update, amsgrad, defaults)
        self.d_model = d_model
```

Get learning rate

$$\alpha \cdot \frac{1}{\sqrt{d_{model}}} \cdot \min\left(\frac{1}{\sqrt{t}},\; \frac{t}{w^{3/2}}\right)$$

where $w$ is the number of warmup steps and $t$ is the current training step.
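As a quick sanity check of this schedule, here is a standalone sketch (`noam_lr` is a hypothetical helper written for illustration, not part of the labml API):

```python
def noam_lr(step: int, warmup: int, d_model: int, lr: float = 1.0) -> float:
    # Noam schedule: linear warmup followed by inverse-square-root decay.
    return lr * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

# The two terms inside min() cross at step == warmup, so the peak
# learning rate is lr / sqrt(d_model * warmup).
peak = noam_lr(4000, warmup=4000, d_model=512)
print(f"peak lr: {peak:.2e}")  # ~6.99e-04
```

During warmup the linear term `step * warmup ** -1.5` is the smaller of the two, so the rate grows linearly; afterwards `step ** -0.5` takes over and the rate decays.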

```python
    def get_lr(self, state: Dict[str, any], group: Dict[str, any]):
```

$$\min\left(\frac{1}{\sqrt{t}},\; \frac{t}{w^{3/2}}\right)$$

```python
        factor = min(state['step'] ** (-0.5), state['step'] * group['warmup'] ** (-1.5))
```

$$\alpha \cdot \frac{1}{\sqrt{d_{model}}} \cdot \min\left(\frac{1}{\sqrt{t}},\; \frac{t}{w^{3/2}}\right)$$

```python
        return group['lr'] * self.d_model ** (-0.5) * factor
```

Plot learning rate for different warmups and model sizes

```python
def _test_noam_lr():
    import matplotlib.pyplot as plt
    import numpy as np
    from torch import nn

    model = nn.Linear(10, 10)
    opts = [Noam(model.parameters(), d_model=512, warmup=4000, lr=1),
            Noam(model.parameters(), d_model=512, warmup=8000, lr=1),
            Noam(model.parameters(), d_model=2048, warmup=2000, lr=1)]
    plt.plot(np.arange(1, 20000),
             [[opt.get_lr({'step': i}, opt.defaults) for opt in opts] for i in range(1, 20000)])
    plt.legend(["512:4000", "512:8000", "2048:2000"])
    plt.title("Learning Rate")
    plt.show()


if __name__ == '__main__':
    _test_noam_lr()
```
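For a numeric reading of the plotted curves (computed here directly from the schedule formula rather than through the optimizer class): the peak rate is reached at `step == warmup` and equals `lr / sqrt(d_model * warmup)`, so two of the three configurations peak at the same height.

```python
import math

for d_model, warmup in [(512, 4000), (512, 8000), (2048, 2000)]:
    # Peak of the Noam schedule: the two min() terms are equal at step == warmup.
    peak = 1.0 / math.sqrt(d_model * warmup)
    print(f"d_model={d_model}, warmup={warmup}: peak lr = {peak:.2e}")
```

Note that the 512:8000 and 2048:2000 configurations share the same peak (d_model · warmup = 4,096,000 in both), differing only in where the peak occurs and how fast it is reached.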

labml.ai