This is a PyTorch implementation of Sophia-G from the paper Sophia: A Scalable Stochastic Second-order Optimizer for Language Model Pre-training. The official implementation is available at Liuhong99/Sophia.
Sophia is more adaptive to heterogeneous curvatures than Adam, more resistant to non-convexity and rapid changes of the Hessian than Newton's method, and also uses a low-cost pre-conditioner.
Sophia maintains an exponential moving average (EMA) of diagonal Hessian estimates across iterations. The diagonal Hessian estimate $\hat{h}_t$ is calculated every $k$ steps.
$$
h_t =
\begin{cases}
\beta_2 h_{t-k} + (1 - \beta_2) \hat{h}_t & \text{if } t \bmod k = 1 \\
h_{t-1} & \text{otherwise}
\end{cases}
$$
Sophia uses an EMA of gradients $m_t$, considers only the positive entries of the diagonal Hessian, and applies per-coordinate clipping to the update.
$$
\begin{aligned}
m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
\theta_{t+1} &\leftarrow \theta_t - \eta \cdot \operatorname{clip}\left(\frac{m_t}{\max\{h_t, \epsilon\}},\, \rho\right)
\end{aligned}
$$
where $\epsilon$ is a very small value to prevent division by $0$, and $\operatorname{clip}(x, \rho)$ clamps each coordinate of $x$ to the range $[-\rho, \rho]$.
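As a concrete illustration of what the clipped, pre-conditioned update does per coordinate (a minimal sketch with made-up values, not part of the implementation below):

import torch

eta, rho, eps = 1.0, 0.04, 1e-12

# Made-up values for the EMA of gradients and the diagonal Hessian estimate
m = torch.tensor([0.5, -0.001, 0.02])
h = torch.tensor([1.0, 1e-8, 0.1])

# Per-coordinate pre-conditioned step, clipped to [-rho, rho]
update = (m / h.clamp(min=eps)).clamp(-rho, rho)  # tensor([ 0.0400, -0.0400,  0.0400])

theta = torch.zeros(3)
theta = theta - eta * update  # each coordinate moves by at most eta * rho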
$$
\begin{aligned}
\hat{L}(\theta) &= \frac{1}{B} \sum_{b=1}^{B} \ell_{CE}\big(f(\theta, x_b), \hat{y}_b\big) \\
\hat{h}_t &= B \cdot \nabla_\theta \hat{L}(\theta) \odot \nabla_\theta \hat{L}(\theta)
\end{aligned}
$$
where $x_b$ are the inputs, $B$ is the batch size (the number of inputs/tokens), $\ell_{CE}$ is the cross-entropy loss, and $\hat{y}_b$ are labels sampled from the logits $f(\theta, x_b)$.
Note that this Hessian estimate is always non-negative, so we can replace $\max\{h_t, \epsilon\}$ with $h_t + \epsilon$.
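To spell this out (a short check based on the definitions above, with $h_0 = 0$): since

$$
\hat{h}_t = B \cdot \nabla_\theta \hat{L}(\theta) \odot \nabla_\theta \hat{L}(\theta) \ge 0
\quad\text{and}\quad
h_t = \beta_2 h_{t-k} + (1 - \beta_2) \hat{h}_t,
$$

every $h_t$ is an EMA of non-negative terms, so $h_t \ge 0$ and the quantities $\max\{h_t, \epsilon\}$ and $h_t + \epsilon$ differ by at most $\epsilon$ per coordinate.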
Sophia with the Gauss-Newton-Bartlett (GNB) estimator is called Sophia-G.
Here is an experiment that uses Sophia-G to train a transformer.
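The optimizer below implements only the parameter update and the Hessian EMA; computing the GNB estimator is left to the training loop. The following is a minimal sketch of how such a loop might look, assuming a language model model that returns logits, a hypothetical train_loader, and an already constructed optimizer of the Sophia class defined below (the value of k is illustrative):

import torch
import torch.nn.functional as F

k = 10  # update the Hessian estimate every k steps (illustrative)

for t, (inputs, targets) in enumerate(train_loader):
    # Regular training step with the true labels
    logits = model(inputs)
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Every k steps, re-estimate the diagonal Hessian with the GNB estimator
    if (t + 1) % k == 0:
        logits = model(inputs)
        probs = F.softmax(logits.view(-1, logits.size(-1)), dim=-1)
        # Sample labels from the model's own output distribution
        samples = torch.multinomial(probs, num_samples=1).squeeze(-1)
        sampled_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), samples)
        optimizer.zero_grad()
        sampled_loss.backward()
        optimizer.update_hessian(n_tokens_training_batch=samples.numel())
        optimizer.zero_grad()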
from typing import Dict, Any, Tuple, Optional

import torch
from torch import nn

from labml_nn.optimizers import GenericAdaptiveOptimizer, WeightDecay
We extend the class GenericAdaptiveOptimizer defined in __init__.py to implement the Sophia optimizer.
class Sophia(GenericAdaptiveOptimizer):
* params is the list of parameters
* lr is the maximum learning rate $\eta \rho$
* betas is a tuple of ($\beta_1$, $\beta_2$)
* eps is $\epsilon$
* rho is $\rho$
* weight_decay is an instance of class WeightDecay defined in __init__.py
* defaults is a dictionary of defaults for group values. This is useful when you want to extend the class Sophia.

def __init__(self, params,
             lr: float = 1e-4, betas: Tuple[float, float] = (0.9, 0.95), eps: float = 1e-12,
             rho: float = 0.03,
             weight_decay: WeightDecay = WeightDecay(),
             defaults: Optional[Dict[str, Any]] = None):
    defaults = {} if defaults is None else defaults
    defaults.update(weight_decay.defaults())
    defaults.update(dict(rho=rho))
    super().__init__(params, defaults, lr, betas, eps)

    self.weight_decay = weight_decay
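For example, with the imports and the Sophia class above, the optimizer can be constructed like any other PyTorch optimizer (a usage sketch; the model and hyper-parameter values are illustrative):

model = nn.Linear(16, 4)
optimizer = Sophia(model.parameters(), lr=1e-4, betas=(0.9, 0.95), rho=0.03)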
init_state initializes the optimizer state for a parameter.

* state is the optimizer state of the parameter (tensor)
* group stores optimizer attributes of the parameter group
* param is the parameter tensor $\theta_{t-1}$

def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):
This is the number of optimizer steps taken on the parameter, t
    state['step'] = 0
Exponential moving average of gradients, $m_t$

    state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)
Exponential moving average of the Hessian diagonal, $h_t$

    state['hessian'] = torch.zeros_like(param, memory_format=torch.preserve_format)
update_hessian updates the EMA of the diagonal Hessian using the GNB estimator; it is meant to be called once every $k$ steps, after computing gradients of the sampled loss $\hat{L}(\theta)$.

* n_tokens_training_batch is the number of tokens/inputs in the batch $B$

$$
\hat{h}_t = B \cdot \nabla_\theta \hat{L}(\theta) \odot \nabla_\theta \hat{L}(\theta)
$$
$$
h_t = \beta_2 h_{t-k} + (1 - \beta_2) \hat{h}_t
$$
def update_hessian(self, n_tokens_training_batch):
Iterate through parameter groups
    for group in self.param_groups:
$\beta_2$

        _, beta2 = group['betas']
Iterate through parameters
        for p in group['params']:
Skip parameters without gradients
            if p.grad is None:
                continue
Get optimizer state
            state = self.state[p]
Initialize state if empty
            if len(state) == 0:
                self.init_state(state, group, p)
Update EMA Hessian diagonal
$$
\hat{h}_t = B \cdot \nabla_\theta \hat{L}(\theta) \odot \nabla_\theta \hat{L}(\theta)
$$
$$
h_t = \beta_2 h_{t-k} + (1 - \beta_2) \hat{h}_t
$$
            state['hessian'].mul_(beta2).addcmul_(p.grad, p.grad, value=(1 - beta2) * n_tokens_training_batch)
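Written out, this single in-place call is equivalent to the following expanded form (shown only for clarity; the in-place version above is what the optimizer uses):

            # beta2 * h_{t-k} + (1 - beta2) * B * grad * grad, where B = n_tokens_training_batch
            state['hessian'] = beta2 * state['hessian'] + (1 - beta2) * n_tokens_training_batch * p.grad * p.grad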
* state is the optimizer state of the parameter (tensor)
* group stores optimizer attributes of the parameter group
* grad is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
* param is the parameter tensor $\theta_{t-1}$

We do the following parameter update,
$$
\theta_{t+1} \leftarrow \theta_t - \eta \cdot \operatorname{clip}\left(\frac{m_t}{h_t + \epsilon},\, \rho\right)
$$
def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):
Calculate weight decay
    grad = self.weight_decay(param, grad, group)
Get β1 and β2
    beta1, beta2 = group['betas']
Get ρ
    rho = group['rho']
Get $m_{t-1}$ and $h_t$

    m, hessian = state['exp_avg'], state['hessian']
In-place calculation of $m_t$:
$$
m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t
$$
    m.mul_(beta1).add_(grad, alpha=1 - beta1)
Increment $t$, the number of optimizer steps

    state['step'] += 1
Get the maximum learning rate $\eta \rho$

    lr = group['lr']
Calculate $\eta$ from the maximum learning rate, $\eta = \text{lr} / \rho$

    eta = lr / rho
$$
\operatorname{clip}\left(\frac{m_t}{h_t + \epsilon},\, \rho\right)
$$

    ratio = (m / (hessian + group['eps'])).clamp(-rho, rho)
$$
\theta_{t+1} \leftarrow \theta_t - \eta \cdot \operatorname{clip}\left(\frac{m_t}{h_t + \epsilon},\, \rho\right)
$$

    param.data.add_(ratio, alpha=-eta)
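Since each coordinate of the clipped ratio is bounded by $\rho$ and $\eta = \text{lr} / \rho$, no coordinate of $\theta$ moves by more than $\eta \rho = \text{lr}$ in a single step, which is why lr was described above as the maximum learning rate. A small sanity-check sketch (made-up values, not part of the implementation):

import torch

lr, rho, eps = 1e-4, 0.03, 1e-12
eta = lr / rho

m = torch.randn(1000)
hessian = torch.rand(1000)  # made-up non-negative Hessian diagonal
ratio = (m / (hessian + eps)).clamp(-rho, rho)

# The per-coordinate update magnitude never exceeds lr (= eta * rho)
assert (eta * ratio).abs().max() <= lr + 1e-12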