Adam Optimizer with Warmup
This extends the AMSGrad optimizer and adds a warmup stage.
from typing import Dict

from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.amsgrad import AMSGrad
This class extends the AMSGrad optimizer defined in amsgrad.py.
class AdamWarmup(AMSGrad):
- params is the list of parameters
- lr is the learning rate α
- betas is a tuple of (β1, β2)
- eps is ϵ^ or ϵ based on optimized_update
- weight_decay is an instance of class WeightDecay defined in __init__.py
- amsgrad is a flag indicating whether to use AMSGrad or fall back to plain Adam
- warmup is the number of warmup steps
- defaults is a dictionary of defaults for group values. This is useful when you want to extend the class AdamWarmup.

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
                 weight_decay: WeightDecay = WeightDecay(),
                 optimized_update: bool = True,
                 amsgrad=False, warmup=0, defaults=None):
        defaults = {} if defaults is None else defaults
        defaults.update(dict(warmup=warmup))
        super().__init__(params, lr, betas, eps, weight_decay, optimized_update, amsgrad, defaults)
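As a usage sketch (assuming labml_nn and PyTorch are installed, and that the module path labml_nn.optimizers.adam_warmup matches the file this page documents; the model and hyper-parameter values below are illustrative, not taken from this module), the optimizer is constructed like any other PyTorch optimizer, with the extra warmup argument:

import torch.nn as nn

from labml_nn.optimizers.adam_warmup import AdamWarmup

# Illustrative model; any torch.nn.Module works
model = nn.Linear(16, 4)
# lr=2.5e-4 and warmup=2000 are arbitrary example values
optimizer = AdamWarmup(model.parameters(), lr=2.5e-4, warmup=2_000)

Because warmup is merged into defaults above, every parameter group carries a 'warmup' entry, which get_lr later reads as group['warmup'].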
The learning rate is α min(1, t / w), where t is the current step and w is the number of warmup steps.
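For example, with α = 1e-3 and w = 1000 warmup steps, the learning rate is roughly 1e-4 at step t = 100, reaches α at step 1000, and stays constant afterwards.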
    def get_lr(self, state: Dict[str, any], group: Dict[str, any]):
If we are in the warmup stage
        if group['warmup'] > state['step']:
A linearly increasing learning rate from 0 to α
            return 1e-8 + state['step'] * group['lr'] / group['warmup']
        else:
Constant learning rate α
            return group['lr']
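To see the schedule in isolation, here is a minimal self-contained sketch that reproduces the same piecewise rule outside the optimizer (the function warmup_lr and the printed values are illustrative, not part of labml_nn):

def warmup_lr(step: int, lr: float, warmup: int) -> float:
    # Linear warmup from ~0 to lr over `warmup` steps, then a constant lr
    if warmup > step:
        return 1e-8 + step * lr / warmup
    return lr


# With lr=1e-3 and warmup=1000: 1e-8 at step 0, ~5e-4 at step 500, 1e-3 from step 1000 onwards
for step in (0, 500, 1_000, 10_000):
    print(step, warmup_lr(step, lr=1e-3, warmup=1_000))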