# Noam Optimizer
This is a PyTorch implementation of the Noam optimizer introduced in the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762).
```python
from typing import Dict

from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.amsgrad import AMSGrad
```
This class extends the AMSGrad optimizer defined in amsgrad.py, which in turn builds on the Adam optimizer defined in adam.py.
```python
class Noam(AMSGrad):
```
* `params` is the list of parameters
* `lr` is the learning rate $\alpha$
* `betas` is a tuple of ($\beta_1$, $\beta_2$)
* `eps` is $\hat{\epsilon}$ or $\epsilon$ based on `optimized_update`
* `weight_decay` is an instance of the class `WeightDecay` defined in `__init__.py`
* `amsgrad` is a flag indicating whether to use AMSGrad or fall back to plain Adam
* `warmup` is the number of warmup steps
* `d_model` is the model size, i.e. the number of dimensions in the transformer
* `defaults` is a dictionary of defaults for group values. This is useful when you want to extend the class `Noam`.

```python
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
                 weight_decay: WeightDecay = WeightDecay(),
                 optimized_update: bool = True,
                 amsgrad=False,
                 warmup=0, d_model=512, defaults=None):
        defaults = {} if defaults is None else defaults
        defaults.update(dict(warmup=warmup))
        super().__init__(params, lr, betas, eps, weight_decay, optimized_update, amsgrad, defaults)
        self.d_model = d_model
```
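For reference, here is a minimal usage sketch; the toy model, batch size, and dummy loss are illustrative assumptions, not part of the original code:

```python
import torch
from torch import nn

# Toy model, assumed for illustration; any iterable of parameters works.
model = nn.Linear(512, 512)

# lr acts as the multiplier alpha on the schedule; lr=1. gives the rate from the paper.
optimizer = Noam(model.parameters(), lr=1., warmup=4000, d_model=512)

# One training step with a dummy squared-activation loss.
loss = model(torch.randn(8, 512)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```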
The learning rate schedule is

$$\alpha \frac{1}{\sqrt{d_{model}}} \min\left(\frac{1}{\sqrt{t}}, \frac{t}{w^{3/2}}\right)$$

where $w$ is the number of warmup steps and $t$ is the current step; see the sanity check after this function.
```python
    def get_lr(self, state: Dict[str, any], group: Dict[str, any]):
```
$\min\left(\frac{1}{\sqrt{t}}, \frac{t}{w^{3/2}}\right)$
```python
        factor = min(state['step'] ** (-0.5), state['step'] * group['warmup'] ** (-1.5))
```
$\alpha \frac{1}{\sqrt{d_{model}}} \min\left(\frac{1}{\sqrt{t}}, \frac{t}{w^{3/2}}\right)$
```python
        return group['lr'] * self.d_model ** (-0.5) * factor
```
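At $t = w$ the two arguments of the $\min$ coincide at $\frac{1}{\sqrt{w}}$, so the schedule peaks at $\frac{\alpha}{\sqrt{d_{model} \cdot w}}$; for $\alpha = 1$, $d_{model} = 512$ and $w = 4000$ this is roughly $7 \times 10^{-4}$. A small sanity-check sketch (the toy model is an assumption):

```python
from torch import nn

# Toy parameters, assumed for illustration only.
opt = Noam(nn.Linear(10, 10).parameters(), lr=1., d_model=512, warmup=4000)

# The peak learning rate at step == warmup should be 1 / sqrt(d_model * warmup).
peak = opt.get_lr({'step': 4000}, opt.defaults)
assert abs(peak - (512 * 4000) ** -0.5) < 1e-12
```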
Plot the learning rate for different numbers of warmup steps and model sizes.

```python
def _test_noam_lr():
    import matplotlib.pyplot as plt
    import numpy as np
    from torch import nn

    model = nn.Linear(10, 10)
    opts = [Noam(model.parameters(), d_model=512, warmup=4000, lr=1),
            Noam(model.parameters(), d_model=512, warmup=8000, lr=1),
            Noam(model.parameters(), d_model=2048, warmup=2000, lr=1)]
    plt.plot(np.arange(1, 20000), [[opt.get_lr({'step': i}, opt.defaults) for opt in opts] for i in range(1, 20000)])
    plt.legend(["512:4000", "512:8000", "2048:2000"])
    plt.title("Learning Rate")
    plt.show()


if __name__ == '__main__':
    _test_noam_lr()
```