docs/optimizers/configs.html
from typing import Tuple

import torch

from labml.configs import BaseConfigs, option, meta_config
from labml_nn.optimizers import WeightDecay
class OptimizerConfigs(BaseConfigs):
    """
    Configurable set of optimizer hyper-parameters.

    `BaseConfigs` turns every annotated attribute below into a lazily
    computed configuration; `optimizer` is the primary option (see
    `__init__`) and is built by one of the `@option` factories that
    follow this class.
    """
    # Optimizer instance (built by an `@option` factory keyed on a name
    # such as 'Adam', 'SGD', 'Noam', ...)
    optimizer: torch.optim.Adam
    # Weight decay helper object (see the 'L2' option below)
    weight_decay_obj: WeightDecay
    # Whether weight decay is decoupled; i.e. weight decay is not added to gradients
    weight_decouple: bool = True
    # Weight decay coefficient
    weight_decay: float = 0.0
    # Whether weight decay is absolute or should be multiplied by learning rate
    weight_decay_absolute: bool = False
    # Whether the adam update is optimized (different epsilon)
    optimized_adam_update: bool = True
    # Parameters to be optimized
    # NOTE(review): annotation is the builtin `any`, not `typing.Any` —
    # works as an annotation object but is likely unintended; confirm
    # before changing, since BaseConfigs inspects annotations.
    parameters: any
    # Learning rate α
    learning_rate: float = 0.01
    # Beta values (β1, β2) for Adam
    betas: Tuple[float, float] = (0.9, 0.999)
    # Epsilon ϵ for adam
    eps: float = 1e-08
    # Momentum for SGD
    momentum: float = 0.5
    # Whether to use AMSGrad
    amsgrad: bool = False
    # Number of warmup optimizer steps
    warmup: int = 2_000
    # Total number of optimizer steps (for cosine decay)
    total_steps: int = int(1e10)
    # Whether to degenerate to SGD in AdaBelief
    degenerate_to_sgd: bool = True
    # Whether to use Rectified Adam in AdaBelief
    rectify: bool = True
    # Model embedding size for Noam optimizer
    d_model: int
    # rho, consumed by the 'Sophia' optimizer option (no default)
    rho: float

    def __init__(self):
        # `optimizer` is the primary config: `OptimizerConfigs` can be
        # set directly to an option name string such as 'Adam'.
        super().__init__(_primary='optimizer')


# `parameters` is passed in by the caller (model parameters); mark it as
# meta so it is not treated as a tunable hyper-parameter.
meta_config(OptimizerConfigs.parameters)
@option(OptimizerConfigs.weight_decay_obj, 'L2')
def _weight_decay(c: OptimizerConfigs):
    """Build the L2 weight-decay helper from the scalar config values."""
    return WeightDecay(c.weight_decay, c.weight_decay_absolute, c.weight_decouple) if False else \
        WeightDecay(c.weight_decay, c.weight_decouple, c.weight_decay_absolute)


@option(OptimizerConfigs.optimizer, 'SGD')
def _sgd_optimizer(c: OptimizerConfigs):
    """Plain `torch.optim.SGD`; note it takes the raw `weight_decay` float,
    not `weight_decay_obj`."""
    return torch.optim.SGD(c.parameters, c.learning_rate, c.momentum,
                           weight_decay=c.weight_decay)


@option(OptimizerConfigs.optimizer, 'Adam')
def _adam_optimizer(c: OptimizerConfigs):
    """Adam, or AMSGrad when `amsgrad` is set.

    Imports are kept local so unused optimizer modules are never loaded.
    """
    if c.amsgrad:
        from labml_nn.optimizers.amsgrad import AMSGrad
        return AMSGrad(c.parameters,
                       lr=c.learning_rate, betas=c.betas, eps=c.eps,
                       optimized_update=c.optimized_adam_update,
                       weight_decay=c.weight_decay_obj, amsgrad=c.amsgrad)
    else:
        from labml_nn.optimizers.adam import Adam
        return Adam(c.parameters,
                    lr=c.learning_rate, betas=c.betas, eps=c.eps,
                    optimized_update=c.optimized_adam_update,
                    weight_decay=c.weight_decay_obj)


@option(OptimizerConfigs.optimizer, 'AdamW')
def _adam_warmup_optimizer(c: OptimizerConfigs):
    """Adam with linear learning-rate warmup over `warmup` steps."""
    from labml_nn.optimizers.adam_warmup import AdamWarmup
    return AdamWarmup(c.parameters,
                      lr=c.learning_rate, betas=c.betas, eps=c.eps,
                      weight_decay=c.weight_decay_obj, amsgrad=c.amsgrad, warmup=c.warmup)


@option(OptimizerConfigs.optimizer, 'RAdam')
def _radam_optimizer(c: OptimizerConfigs):
    """Rectified Adam; can fall back to SGD via `degenerate_to_sgd`."""
    from labml_nn.optimizers.radam import RAdam
    return RAdam(c.parameters,
                 lr=c.learning_rate, betas=c.betas, eps=c.eps,
                 weight_decay=c.weight_decay_obj, amsgrad=c.amsgrad,
                 degenerated_to_sgd=c.degenerate_to_sgd)


@option(OptimizerConfigs.optimizer, 'AdaBelief')
def _ada_belief_optimizer(c: OptimizerConfigs):
    """AdaBelief, optionally rectified (`rectify`) and/or degenerating to SGD."""
    from labml_nn.optimizers.ada_belief import AdaBelief
    return AdaBelief(c.parameters,
                     lr=c.learning_rate, betas=c.betas, eps=c.eps,
                     weight_decay=c.weight_decay_obj, amsgrad=c.amsgrad,
                     degenerate_to_sgd=c.degenerate_to_sgd,
                     rectify=c.rectify)


@option(OptimizerConfigs.optimizer, 'Noam')
def _noam_optimizer(c: OptimizerConfigs):
    """Noam schedule (warmup then inverse-sqrt decay) scaled by `d_model`."""
    from labml_nn.optimizers.noam import Noam
    return Noam(c.parameters,
                lr=c.learning_rate, betas=c.betas, eps=c.eps,
                weight_decay=c.weight_decay_obj, amsgrad=c.amsgrad, warmup=c.warmup,
                d_model=c.d_model)


@option(OptimizerConfigs.optimizer, 'Sophia')
def _sophia_optimizer(c: OptimizerConfigs):
    """Sophia optimizer; `rho` must be set on the configs."""
    from labml_nn.optimizers.sophia import Sophia
    return Sophia(c.parameters,
                  lr=c.learning_rate, betas=c.betas, eps=c.eps,
                  weight_decay=c.weight_decay_obj, rho=c.rho)


@option(OptimizerConfigs.optimizer, 'AdamWarmupCosineDecay')
def _adam_warmup_cosine_decay_optimizer(c: OptimizerConfigs):
    """Adam with warmup followed by cosine decay over `total_steps`.

    Bug fix: this factory was previously also named `_noam_optimizer`,
    shadowing the 'Noam' factory above at module level; renamed so each
    option has a unique function name.
    """
    from labml_nn.optimizers.adam_warmup_cosine_decay import AdamWarmupCosineDecay
    return AdamWarmupCosineDecay(c.parameters,
                                 lr=c.learning_rate, betas=c.betas, eps=c.eps,
                                 weight_decay=c.weight_decay_obj, amsgrad=c.amsgrad,
                                 warmup=c.warmup, total_steps=c.total_steps)