docs/si/optimizers/adam_warmup_cosine_decay.html
මෙය AMSGrad ප්රශස්තකරණය පුළුල් කරන අතර උනුසුම් අවධියක් එක් කරයි.
11importmath12fromtypingimportDict1314fromlabml\_nn.optimizersimportWeightDecay15fromlabml\_nn.optimizers.amsgradimportAMSGrad
මෙමපන්තිය AMSGrad ප්රශස්තකරණයෙන් අර්ථ දක්වා ඇත amsgrad.py.
18classAdamWarmupCosineDecay(AMSGrad):
params යනු පරාමිතීන් ලැයිස්තුවයිlr යනු ඉගෙනුම් අනුපාතයයි αbetas (β1, β2) ක tuple වේepsϵ^ හෝ මත ϵ පදනම් වේ optimized_updateweight_decay``WeightDecay අර්ථ දක්වා ඇති පන්තියේ අවස්ථාවකි __init__.pyamsgrad ආදම් සරල කිරීම සඳහා AMSGrad හෝ වැටීම භාවිතා කළ යුතුද යන්න දැක්වෙන ධජයකිwarmup උනුසුම් පියවර ගණනtotal_steps මුළු පියවර ගණන. කොසයින් ක්ෂය වීම මේ වන විට 0 දක්වා ළඟා වේ, නමුත් අප ගන්නා lr නිසා 10% ක රැඳී සිටියි α∗max(0.1,decay)defaults කණ්ඩායම් අගයන් සඳහා පෙරනිමි ශබ්ද කෝෂයකි. ඔබට පන්තිය දීර් extend කිරීමට අවශ්ය විට මෙය ප්රයෝජනවත් AdamWarmup වේ.27def\_\_init\_\_(self,params,lr=1e-3,betas=(0.9,0.999),eps=1e-16,28weight\_decay:WeightDecay=WeightDecay(),29optimized\_update:bool=True,30amsgrad=False,warmup=0,total\_steps=1e10,defaults=None):
49defaults={}ifdefaultsisNoneelsedefaults50defaults.update(dict(warmup=warmup,total\_steps=total\_steps))51super().\_\_init\_\_(params,lr,betas,eps,weight\_decay,optimized\_update,amsgrad,defaults)
αmin(1,wt) උණුසුම් කිරීමේ පියවර ගණන w කොහේද?
53defget\_lr(self,state:Dict[str,any],group:Dict[str,any]):
අපිඋනුසුම් අවධියක සිටී නම්
61ifgroup['warmup']\>state['step']:
සිට 0 රේඛීයව වැඩිවන ඉගෙනුම් අනුපාතය α
63return1e-8+state['step']\*group['lr']/group['warmup']64else:
නිරන්තරඉගෙනුම් අනුපාතය α
66progress=(state['step']-group['warmup'])/max(1,group['total\_steps']-group['warmup'])67returngroup['lr']\*max(0.1,0.5\*(1.0+math.cos(math.pi\*progress)))
70def\_test\_lr():
76importmatplotlib.pyplotasplt77importnumpyasnp78fromtorchimportnn7980model=nn.Linear(10,10)81opt=AdamWarmupCosineDecay(model.parameters(),warmup=5000,lr=1e-4,total\_steps=4e6)82steps=20\_00083plt.plot(np.arange(1,steps),[opt.get\_lr({'step':i},opt.defaults)foriinrange(1,steps)])84plt.legend(["5000:4e6","5000:2e6","5000:1e6"])85plt.title("Learning Rate")86plt.show()8788steps=int(6e6)89step\_size=100090plt.plot(np.arange(1,steps,step\_size),[opt.get\_lr({'step':i},opt.defaults)foriinrange(1,steps,step\_size)])91plt.legend(["5000:4e6","5000:2e6","5000:1e6"])92plt.title("Learning Rate")93plt.show()949596if\_\_name\_\_=='\_\_main\_\_':97\_test\_lr()