#

Sophia Optimizer

This is a PyTorch implementation of Sophia-G from the paper Sophia: A Scalable Stochastic Second-order Optimizer for Language Model Pre-training. The official implementation is available at Liuhong99/Sophia.

Sophia is more adaptive to heterogeneous curvatures than Adam, more resistant to non-convexity and rapid changes of the Hessian than Newton's method, and uses a low-cost pre-conditioner.

Sophia maintains an EMA of diagonal Hessian estimates across iterations. The diagonal Hessian estimate $\hat{h}_t$ is calculated every $k$ steps.

$$h_t = \begin{cases} \beta_2 h_{t-k} + (1 - \beta_2) \hat{h}_t & \text{if } t \bmod k = 1 \\ h_{t-1} & \text{otherwise} \end{cases}$$

Sophia uses an EMA of the gradients $m_t$, considers only the positive entries of the diagonal Hessian, and applies per-coordinate clipping to the update.

$$\begin{aligned} m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ \theta_{t+1} &\leftarrow \theta_t - \eta \cdot \operatorname{clip}\Big(\frac{m_t}{\max\{h_t, \epsilon\}}, \rho\Big) \end{aligned}$$

where $\epsilon$ is a very small value to prevent division by $0$.
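For intuition, the clipping is applied element-wise, so a coordinate with a tiny (or zero) curvature estimate still moves by at most $\eta \cdot \rho$. A minimal sketch with made-up numbers (not from the paper or this file):

import torch

# Illustrative values: m is the gradient EMA, h is the diagonal Hessian EMA
m = torch.tensor([0.002, -0.5, 0.001])
h = torch.tensor([0.1, 0.0, 1e-6])
eps, rho = 1e-12, 0.05

# clip(m / max(h, eps), rho), applied per coordinate
update = (m / torch.clamp(h, min=eps)).clamp(-rho, rho)
# tensor([0.0200, -0.0500, 0.0500]) - the last two coordinates hit the +/- rho bound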

Gauss-Newton-Bartlett (GNB) estimator

$$\begin{aligned} \hat{L}(\theta) &= \frac{1}{B} \sum_{b=1}^{B} \ell_{CE}\big(f(\theta, x_b), \hat{y}_b\big) \\ \hat{h}_t &= B \cdot \nabla_\theta \hat{L}(\theta) \odot \nabla_\theta \hat{L}(\theta) \end{aligned}$$

where $x_b$ are the inputs, $B$ is the batch size (the number of inputs/tokens), $\ell_{CE}$ is the cross-entropy loss, and $\hat{y}_b$ are labels sampled from the logits $f(\theta, x_b)$.

Note that this Hessian estimate is always non-negative, so we can replace $\max\{h_t, \epsilon\}$ with $h_t + \epsilon$.
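The following is a minimal sketch of how the GNB estimate might be computed for a classifier; model and x are hypothetical stand-ins, and the Sophia class below only consumes the resulting gradients (via update_hessian), so treat this as an illustration rather than the official recipe:

import torch
import torch.nn.functional as F

def gnb_hessian_estimate(model: torch.nn.Module, x: torch.Tensor):
    # Sketch of the GNB estimator for a model that returns logits of shape [B, n_classes]
    model.zero_grad()
    logits = model(x)  # f(theta, x_b)
    b = logits.shape[0]  # batch size B

    # Sample labels y_hat_b from the model's own output distribution
    probs = F.softmax(logits.detach(), dim=-1)
    y_hat = torch.multinomial(probs, num_samples=1).squeeze(-1)

    # Cross-entropy loss on the sampled labels, then back-propagate
    F.cross_entropy(logits, y_hat).backward()

    # h_hat = B * grad ⊙ grad, per parameter
    return {name: b * p.grad.detach() ** 2
            for name, p in model.named_parameters() if p.grad is not None}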

Sophia with the Gauss-Newton-Bartlett (GNB) estimator is called Sophia-G.

Here is an experiment that uses Sophia-G to train a transformer.
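The following is a hedged sketch of how the Sophia class defined below is typically driven from a training loop; model, data_loader, and the interval k are assumptions, and the hyper-parameter values are illustrative only:

import torch
import torch.nn.functional as F

optimizer = Sophia(model.parameters(), lr=2e-4, betas=(0.9, 0.95), rho=0.03)
k = 10  # refresh the Hessian estimate every k steps

for step, (x, y) in enumerate(data_loader):
    # Regular training step
    optimizer.zero_grad()
    loss = F.cross_entropy(model(x), y)
    loss.backward()
    optimizer.step()

    # Every k steps, update the Hessian EMA using the GNB estimate
    if step % k == k - 1:
        optimizer.zero_grad()
        logits = model(x)
        # Labels sampled from the model's own logits (as in the GNB estimator sketch above)
        y_hat = torch.multinomial(F.softmax(logits.detach(), dim=-1), 1).squeeze(-1)
        F.cross_entropy(logits, y_hat).backward()
        optimizer.update_hessian(x.shape[0])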

from typing import Dict, Any, Tuple, Optional

import torch
from torch import nn

from labml_nn.optimizers import GenericAdaptiveOptimizer, WeightDecay

#

Sophia-G Optimizer

We extend the class GenericAdaptiveOptimizer defined in __init__.py to implement the Sophia optimizer.

class Sophia(GenericAdaptiveOptimizer):

#

Initialize the optimizer

  • params is the list of parameters
  • lr is the maximum learning rate $\eta \rho$
  • betas is a tuple of ($\beta_1$, $\beta_2$)
  • eps is $\epsilon$
  • rho is $\rho$
  • weight_decay is an instance of class WeightDecay defined in __init__.py
  • defaults is a dictionary of defaults for group values. This is useful when you want to extend the class Sophia.
    def __init__(self, params,
                 lr: float = 1e-4, betas: Tuple[float, float] = (0.9, 0.95), eps: float = 1e-12,
                 rho: float = 0.03,
                 weight_decay: WeightDecay = WeightDecay(),
                 defaults: Optional[Dict[str, Any]] = None):

#

        defaults = {} if defaults is None else defaults
        defaults.update(weight_decay.defaults())
        defaults.update(dict(rho=rho))
        super().__init__(params, defaults, lr, betas, eps)

        self.weight_decay = weight_decay
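For example, a possible construction using the L2 weight decay helper might look like the following; the hyper-parameter values are illustrative, model is a hypothetical torch.nn.Module, and the WeightDecay constructor argument is an assumption that should be checked against __init__.py:

from labml_nn.optimizers import WeightDecay

opt = Sophia(model.parameters(),
             lr=1e-4, betas=(0.9, 0.95), eps=1e-12, rho=0.03,
             weight_decay=WeightDecay(weight_decay=0.1))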

#

Initialize a parameter state

  • state is the optimizer state of the parameter (tensor)
  • group stores optimizer attributes of the parameter group
  • param is the parameter tensor $\theta_{t-1}$

    def init_state(self, state: Dict[str, any], group: Dict[str, any], param: nn.Parameter):

#

This is the number of optimizer steps taken on the parameter, $t$

        state['step'] = 0

#

Exponential moving average of gradients, $m_t$

        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)

#

Exponential moving average of the Hessian diagonal, $h_t$

        state['hessian'] = torch.zeros_like(param, memory_format=torch.preserve_format)

#

Update the EMA of the Hessian diagonal $h_t$

  • n_tokens_training_batch is the number of tokens/inputs in the batch, $B$

$$\begin{aligned} \hat{h}_t &= B \cdot \nabla_\theta \hat{L}(\theta) \odot \nabla_\theta \hat{L}(\theta) \\ h_t &= \beta_2 h_{t-k} + (1 - \beta_2) \hat{h}_t \end{aligned}$$

    def update_hessian(self, n_tokens_training_batch):

#

Iterate through parameter groups

        for group in self.param_groups:

#

$\beta_2$

            _, beta2 = group['betas']

#

Iterate through parameters

            for p in group['params']:

#

Skip parameters without gradients

                if p.grad is None:
                    continue

#

Get optimizer state

                state = self.state[p]

#

Initialize state if empty

                if len(state) == 0:
                    self.init_state(state, group, p)

#

Update EMA Hessian diagonal

$$\begin{aligned} \hat{h}_t &= B \cdot \nabla_\theta \hat{L}(\theta) \odot \nabla_\theta \hat{L}(\theta) \\ h_t &= \beta_2 h_{t-k} + (1 - \beta_2) \hat{h}_t \end{aligned}$$

                state['hessian'].mul_(beta2).addcmul_(p.grad, p.grad, value=(1 - beta2) * n_tokens_training_batch)

#

Take an update step for a given parameter tensor

  • state is the optimizer state of the parameter (tensor)
  • group stores optimizer attributes of the parameter group
  • grad is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
  • param is the parameter tensor $\theta_{t-1}$

We do the following parameter update,

$$\theta_{t+1} \leftarrow \theta_t - \eta \cdot \operatorname{clip}\Big(\frac{m_t}{h_t + \epsilon}, \rho\Big)$$

    def step_param(self, state: Dict[str, any], group: Dict[str, any], grad: torch.Tensor, param: torch.nn.Parameter):

#

Calculate weight decay

        grad = self.weight_decay(param, grad, group)

#

Get $\beta_1$ and $\beta_2$

        beta1, beta2 = group['betas']

#

Get $\rho$

        rho = group['rho']

#

Get $m_{t-1}$ and $h_t$

        m, hessian = state['exp_avg'], state['hessian']

#

In-place calculation of $m_t$: $m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t$

        m.mul_(beta1).add_(grad, alpha=1 - beta1)

#

Increment $t$, the number of optimizer steps

        state['step'] += 1

#

Get maximum learning rate $\eta \rho$

        lr = group['lr']

#

$\eta$. Since the update is clipped to $[-\rho, \rho]$, the maximum per-coordinate step size is $\eta \cdot \rho$, which is lr.

        eta = lr / rho

#

$\operatorname{clip}\Big(\frac{m_t}{h_t + \epsilon}, \rho\Big)$

        ratio = (m / (hessian + group['eps'])).clamp(-rho, rho)

#

$\theta_{t+1} \leftarrow \theta_t - \eta \cdot \operatorname{clip}\Big(\frac{m_t}{h_t + \epsilon}, \rho\Big)$

        param.data.add_(ratio, alpha=-eta)
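As a quick sanity check, here is a small hand-checkable sketch (assuming the default WeightDecay() is a no-op); with a fresh state the Hessian EMA is zero, so every coordinate is clipped to $\rho$ and moves by about $\eta \cdot \rho =$ lr:

import torch

w = torch.nn.Parameter(torch.tensor([1.0, 1.0]))
opt = Sophia([w], lr=0.01, rho=0.05)

w.grad = torch.tensor([0.5, -0.5])
opt.step()
print(w.data)  # expected roughly tensor([0.9900, 1.0100])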
