docs/optimizers/amsgrad.html
This is a PyTorch implementation of the paper On the Convergence of Adam and Beyond.
We implement this as an extension to our Adam optimizer implementation. The implementation itself is quite small since it's very similar to Adam.
We also have an implementation of the synthetic example described in the paper where Adam fails to converge.
from typing import Dict

import torch
from torch import nn

from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.adam import Adam
This class extends the Adam optimizer defined in adam.py, which in turn extends the GenericAdaptiveOptimizer class defined in __init__.py.
class AMSGrad(Adam):
* params is the list of parameters
* lr is the learning rate $\alpha$
* betas is a tuple of ($\beta_1$, $\beta_2$)
* eps is $\hat{\epsilon}$ or $\epsilon$ based on optimized_update
* weight_decay is an instance of class WeightDecay defined in __init__.py
* amsgrad is a flag indicating whether to use AMSGrad or fall back to plain Adam
* defaults is a dictionary of defaults for group values. This is useful when you want to extend the class Adam.

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
                 weight_decay: WeightDecay = WeightDecay(),
                 optimized_update: bool = True,
                 amsgrad=True, defaults=None):
        defaults = {} if defaults is None else defaults
        defaults.update(dict(amsgrad=amsgrad))

        super().__init__(params, lr, betas, eps, weight_decay, optimized_update, defaults)
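As a quick illustration of the constructor arguments above, here is a minimal usage sketch (not part of this file). The toy model, data, and hyper-parameter values are made up for the example, and the module path labml_nn.optimizers.amsgrad is assumed from this page's location; only the AMSGrad class and its argument names come from the implementation above.

import torch
from torch import nn

from labml_nn.optimizers.amsgrad import AMSGrad

# A toy model and batch, just so there is something to optimize (hypothetical)
model = nn.Linear(4, 1)
data, target = torch.randn(8, 4), torch.randn(8, 1)

# Construct the optimizer; setting amsgrad=False would fall back to plain Adam
optimizer = AMSGrad(model.parameters(), lr=1e-3, betas=(0.9, 0.999), amsgrad=True)

# One training step
loss = nn.functional.mse_loss(model(data), target)
loss.backward()
optimizer.step()
optimizer.zero_grad()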
* state is the optimizer state of the parameter (tensor)
* group stores optimizer attributes of the parameter group
* param is the parameter tensor $\theta_{t-1}$

    def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):
Call init_state of the Adam optimizer, which we are extending
        super().init_state(state, group, param)
If the amsgrad flag is True for this parameter group, we maintain the maximum of the exponential moving average of the squared gradient
        if group['amsgrad']:
            state['max_exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format)
* state is the optimizer state of the parameter (tensor)
* group stores optimizer attributes of the parameter group
* grad is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$

    def get_mv(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor):
Get $m_t$ and $v_t$ from Adam
        m, v = super().get_mv(state, group, grad)
If this parameter group is using amsgrad
        if group['amsgrad']:
Get $\max(v_1, v_2, ..., v_{t-1})$.
🗒 The paper uses the notation $\hat{v}_t$ for this, which we don't use here because it would be confused with Adam's use of the same notation for the bias-corrected exponential moving average.
            v_max = state['max_exp_avg_sq']
Calculate $\max(v_1, v_2, ..., v_{t-1}, v_t)$.
🤔 I feel you should be taking/maintaining the max of the bias-corrected second moment, i.e. the bias-corrected exponential moving average of the squared gradient. But this is how it's implemented in PyTorch as well. It probably doesn't matter much, since bias correction only increases the value, and it only makes an actual difference during the first few steps of training.
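To make the note above concrete, here is the relation between the corrected and uncorrected averages (a sketch of the standard Adam bias correction, not code from this file):

$$\hat{v}_t = \frac{v_t}{1 - \beta_2^t}, \qquad 0 < 1 - \beta_2^t < 1 \;\Rightarrow\; \hat{v}_t \ge v_t$$

The correction only scales $v_t$ up, and since the factor $\frac{1}{1 - \beta_2^t}$ shrinks toward $1$ as $t$ grows, the scaling is significant only early in training, which is the basis of the note above.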
            torch.maximum(v_max, v, out=v_max)

            return m, v_max
        else:
Fall back to Adam if the parameter group is not using amsgrad
            return m, v
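For reference, with the $m_t$ and $v^{\max}_t$ returned here, the update computed by the parent Adam class corresponds (roughly, ignoring weight decay and the $\hat{\epsilon}$ optimized-update rearrangement) to the AMSGrad step

$$\theta_t \leftarrow \theta_{t-1} - \alpha \cdot \frac{m_t / (1 - \beta_1^t)}{\sqrt{v^{\max}_t / (1 - \beta_2^t)} + \epsilon}$$

whereas plain Adam uses $v_t$ itself in the denominator.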
This is the synthetic experiment described in the paper, which shows a scenario where Adam fails.
The paper (and Adam) formulates the optimization problem as minimizing the expected value of a function, $\mathbb{E}[f(\theta)]$, with respect to the parameters $\theta$. In the stochastic training setting we do not get hold of the function $f$ itself; that is, when you are optimizing a neural network, $f$ would be the function over the entire batch of data. What we actually evaluate is a mini-batch, so the actual function is a realization of the stochastic $f$. This is why we are talking about an expected value. So let the function realizations be $f_1, f_2, ..., f_T$ for each time step of training.
We measure the performance of the optimizer as the regret, $$R(T) = \sum_{t=1}^T \big[ f_t(\theta_t) - f_t(\theta^*) \big]$$ where $\theta_t$ are the parameters at time step $t$, and $\theta^*$ are the optimal parameters that minimize $\mathbb{E}[f(\theta)]$.
Now let's define the synthetic problem,
$$f_t(x) = \begin{cases} 1010\,x, & \text{for } t \bmod 101 = 1 \\ -10\,x, & \text{otherwise} \end{cases}$$
where $-1 \le x \le +1$. The optimal solution is $x = -1$.
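To see why (a short worked check): over one cycle of $101$ steps, the large positive slope appears once and the small negative slope appears $100$ times, so

$$\sum_{t=1}^{101} f_t(x) = 1010\,x + 100 \cdot (-10\,x) = 10\,x$$

i.e. the expected per-step value is $\frac{10}{101}\,x$, which over $-1 \le x \le +1$ is minimized at $x = -1$. However, in $100$ out of every $101$ steps the gradient is $-10$, pushing $x$ toward $+1$; Adam's exponential averaging lets these frequent small gradients wash out the rare large one, which is the failure mode the paper constructs.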
This code will try running Adam and AMSGrad on this problem.
def _synthetic_experiment(is_adam: bool):
Define x parameter
    x = nn.Parameter(torch.tensor([.0]))
Optimal, $x^* = -1$
    x_star = nn.Parameter(torch.tensor([-1]), requires_grad=False)
    def func(t: int, x_: nn.Parameter):
        if t % 101 == 1:
            return (1010 * x_).sum()
        else:
            return (-10 * x_).sum()
Initialize the relevant optimizer
    if is_adam:
        optimizer = Adam([x], lr=1e-2, betas=(0.9, 0.99))
    else:
        optimizer = AMSGrad([x], lr=1e-2, betas=(0.9, 0.99))
R(T)
    total_regret = 0

    from labml import monit, tracker, experiment
Create experiment to record results
    with experiment.record(name='synthetic', comment='Adam' if is_adam else 'AMSGrad'):
Run for $10^7$ steps
        for step in monit.loop(10_000_000):
$f_t(\theta_t) - f_t(\theta^*)$
            regret = func(step, x) - func(step, x_star)
$R(T) = \sum_{t=1}^T \big[ f_t(\theta_t) - f_t(\theta^*) \big]$
            total_regret += regret.item()
Track results every 1,000 steps
            if (step + 1) % 1000 == 0:
                tracker.save(loss=regret, x=x, regret=total_regret / (step + 1))
Calculate gradients
            regret.backward()
Optimize
            optimizer.step()
Clear gradients
            optimizer.zero_grad()
Make sure −1≤x≤+1
            x.data.clamp_(-1., +1.)


if __name__ == '__main__':
Run the synthetic experiment with Adam. You can see that Adam converges to $x = +1$
    _synthetic_experiment(True)
Run the synthetic experiment with AMSGrad. You can see that AMSGrad converges to the true optimum $x = -1$
    _synthetic_experiment(False)