[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/optimizers/__init__.py)
This MNIST example uses these optimizers.
This file defines a common base class for Adam and its extensions. The base class helps us implement other optimizers with minimal code, because the reusable parts live here.
We also define a special class for L2 weight decay, so that we don't have to implement it inside each of the optimizers, and can easily extend to other weight decays like L1 without changing the optimizers.
Here are some concepts on PyTorch optimizers:
PyTorch optimizers group parameters into sets called groups. Each group can have its own hyper-parameters like learning rates.
In most common cases there will be only one group. This is when you initialize your optimizer with:

```python
Optimizer(model.parameters())
```
You can define multiple parameter groups when initializing the optimizer:

```python
Optimizer([{'params': model1.parameters()}, {'params': model2.parameters(), 'lr': 2}])
```
Here we pass a list of groups. Each group is a dictionary with its parameters under the key 'params'. You can specify any hyper-parameters as well. If a hyper-parameter is not defined for a group, it falls back to the optimizer-level default.
You can access (and even change) these groups and their hyper-parameters with optimizer.param_groups. Most learning rate schedule implementations I've come across access this and change 'lr'.
The optimizer maintains a state (a dictionary) for each parameter (a tensor), in the dictionary optimizer.state. This is where the optimizer keeps things like exponential moving averages.
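To make that concrete, here is a minimal sketch (the nn.Linear model and the 0.1 factor are just placeholders) of reading and changing the learning rate through param_groups, and peeking at the per-parameter state; this is essentially what a learning rate scheduler does:

```python
import torch
from torch import nn

# A throwaway model and optimizer, purely for illustration
model = nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Every group is a dict; schedulers read and overwrite 'lr' like this
for group in optimizer.param_groups:
    group['lr'] = group['lr'] * 0.1

# After one step, optimizer.state holds per-parameter buffers,
# e.g. Adam's exponential moving averages
model(torch.randn(3, 10)).sum().backward()
optimizer.step()
print(optimizer.state[next(model.parameters())].keys())
```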
```python
from typing import Dict, Tuple, Any

import torch
from torch import nn
from torch.optim.optimizer import Optimizer
```
```python
class GenericAdaptiveOptimizer(Optimizer):
```
- params is the collection of parameters, or set of parameter groups
- defaults is a dictionary of default hyper-parameters
- lr is the learning rate, $\alpha$
- betas is the tuple $(\beta_1, \beta_2)$
- eps is $\epsilon$

```python
def __init__(self, params, defaults: Dict[str, Any], lr: float, betas: Tuple[float, float], eps: float):
```
Check the hyper-parameters
```python
if not 0.0 <= lr:
    raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= eps:
    raise ValueError(f"Invalid epsilon value: {eps}")
if not 0.0 <= betas[0] < 1.0:
    raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
if not 0.0 <= betas[1] < 1.0:
    raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
```
Add the hyper-parameters to the defaults
```python
defaults.update(dict(lr=lr, betas=betas, eps=eps))
```
Initialize the PyTorch optimizer. This will create parameter groups with the default hyper-parameters.
```python
super().__init__(params, defaults)
```
This should be overridden with code to initialize the state for a parameter param. group is the parameter group dictionary to which param belongs.
```python
def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):
    pass
```
This should be overridden to take the optimization step on the parameter tensor $\theta$, where grad is the gradient for that parameter, $g_t$, state is the optimizer state dictionary for that parameter, and group is the parameter group dictionary that param belongs to.
```python
def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.Tensor):
    pass
```
We have created a template method that does the common work every Adam-based optimizer needs.
```python
@torch.no_grad()
def step(self, closure=None):
```
Calculate loss.
🤔 I'm not sure when you need this. I guess it's for when you define a function that calculates the loss, calls loss.backward, and returns the loss; instead of calling it yourself, you pass it to optimizer.step (a typical closure is sketched after the next code fragment). 🤷‍♂️
```python
loss = None
if closure is not None:
    with torch.enable_grad():
        loss = closure()
```
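For reference, here is a minimal sketch of the usual closure pattern (model, loss_fn, data, and target are placeholder names, not part of this file):

```python
def closure():
    # Re-evaluate the model and return the loss; step() calls this
    # under torch.enable_grad() as shown above
    optimizer.zero_grad()
    loss = loss_fn(model(data), target)
    loss.backward()
    return loss

loss = optimizer.step(closure)
```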
Iterate through the parameter groups
```python
for group in self.param_groups:
```
Iterate through the parameters in the parameter group
```python
    for param in group['params']:
```
Skip if the parameter has no gradient
```python
        if param.grad is None:
            continue
```
Get the gradient tensor
```python
        grad = param.grad.data
```
We don't handle sparse gradients
```python
        if grad.is_sparse:
            raise RuntimeError('GenericAdaptiveOptimizer does not support sparse gradients,'
                               ' please consider SparseAdam instead')
```
Get the state for the parameter
```python
        state = self.state[param]
```
Initialize the state if it is uninitialized
```python
        if len(state) == 0:
            self.init_state(state, group, param)
```
Take the optimization step on the parameter
```python
        self.step_param(state, group, grad, param)
```
Return the loss, calculated by the closure
```python
return loss
```
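To make the template concrete, here is a minimal sketch of a hypothetical subclass (not part of this file) that implements an RMSProp-style update by overriding init_state and step_param; all the group handling, gradient checks, and state initialization come from the step method above:

```python
class SimpleRMSProp(GenericAdaptiveOptimizer):
    """A hypothetical adaptive optimizer built on GenericAdaptiveOptimizer."""

    def __init__(self, params, lr: float = 1e-3, betas=(0.9, 0.999), eps: float = 1e-8):
        super().__init__(params, dict(), lr, betas, eps)

    def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):
        # Running exponential average of squared gradients
        state['v'] = torch.zeros_like(param)

    def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.Tensor):
        beta2 = group['betas'][1]
        v = state['v']
        # v <- beta2 * v + (1 - beta2) * grad^2
        v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        # theta <- theta - lr * grad / (sqrt(v) + eps)
        param.data.addcdiv_(grad, v.sqrt().add_(group['eps']), value=-group['lr'])
```

SimpleRMSProp(model.parameters()) can then be stepped like any other PyTorch optimizer.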
```python
class WeightDecay:
```
- weight_decay is the decay coefficient
- weight_decouple is a flag indicating whether to add the weight decay to the gradient, or to decay the parameter directly. If added to the gradient it goes through the normal optimizer update.
- absolute indicates whether the weight decay coefficient is absolute. This is applicable when the decay is performed directly on the parameter. If this is false the actual decay is weight_decay * learning_rate.

```python
def __init__(self, weight_decay: float = 0., weight_decouple: bool = True, absolute: bool = False):
```
Check hyper-parameters
```python
if not 0.0 <= weight_decay:
    raise ValueError(f"Invalid weight_decay value: {weight_decay}")

self.absolute = absolute
self.weight_decouple = weight_decouple
self.weight_decay = weight_decay
```
Return defaults for parameter groups
```python
def defaults(self):
    return dict(weight_decay=self.weight_decay)
```
```python
def __call__(self, param: torch.nn.Parameter, grad: torch.Tensor, group: Dict[str, Any]):
```
If we are doing the decay on the parameter directly
```python
if self.weight_decouple:
```
If the weight decay coefficient is absolute
```python
    if self.absolute:
        param.data.mul_(1.0 - group['weight_decay'])
```
Otherwise,
```python
    else:
        param.data.mul_(1.0 - group['lr'] * group['weight_decay'])
```
Return the unmodified gradient
```python
    return grad
else:
    if group['weight_decay'] != 0:
```
Add the weight decay to the gradient and return the modified gradient
```python
        return grad.add(param.data, alpha=group['weight_decay'])
    else:
        return grad
```
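Finally, a rough sketch of how an optimizer built on GenericAdaptiveOptimizer could plug this helper in. SGDWithDecay and its plain SGD update are hypothetical, for illustration only; the optimizers in this repository use the helper in a similar way inside their step_param implementations.

```python
class SGDWithDecay(GenericAdaptiveOptimizer):
    """Hypothetical optimizer showing where the WeightDecay helper plugs in."""

    def __init__(self, params, lr: float = 1e-3,
                 weight_decay: WeightDecay = WeightDecay()):
        self.weight_decay = weight_decay
        # defaults() adds 'weight_decay' to every parameter group
        super().__init__(params, weight_decay.defaults(), lr, (0.9, 0.999), 1e-8)

    def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):
        pass

    def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.Tensor):
        # Either decays the parameter in place (decoupled) or returns
        # a gradient with the L2 penalty added
        grad = self.weight_decay(param, grad, group)
        # A plain SGD update with the (possibly modified) gradient
        param.data.add_(grad, alpha=-group['lr'])
```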