from typing import Dict, Tuple, Optional, Any

import torch
from torch import nn
from torch.optim import Optimizer
from torch.cuda.amp import grad_scaler
from collections import defaultdict, abc

from labml_nn.optimizers import WeightDecay
from labml_nn.optimizers.adam import Adam
We extend the Adam optimizer, but store the gradients, the moments, and a master copy of the parameters in FP32, while the model parameters themselves stay in FP16.
class AdamFP16(Adam):
    def __init__(self, params, lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-16,
                 weight_decay: WeightDecay = WeightDecay(), optimized_update: bool = True,
                 defaults: Optional[Dict[str, Any]] = None):
Dictionary to store the 32-bit gradients per parameter. This gets populated by the GradScaler defined below.
        self.grad_fp32 = {}
Call the Adam Optimizer initializer
        super().__init__(params, lr, betas, eps, weight_decay, optimized_update, defaults)
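Why keep FP32 copies at all? In FP16 the spacing between representable numbers around 1.0 is about 1e-3, so a typical small Adam update applied directly to an FP16 weight can be rounded away entirely. A minimal standalone sketch (not part of this implementation) demonstrating the effect:

import torch

w_fp16 = torch.tensor(1.0, dtype=torch.float16)
w_fp32 = torch.tensor(1.0, dtype=torch.float32)

# A small update of the size Adam often produces
update = 1e-4

print(w_fp16 + update)  # tensor(1., dtype=torch.float16) -- the update is rounded away
print(w_fp32 + update)  # tensor(1.0001) -- preserved in FP32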
* state is the optimizer state of the parameter (tensor)
* group stores optimizer attributes of the parameter group
* param is the parameter tensor $\theta_{t-1}$

All the state tensors use FP32.
    def init_state(self, state: Dict[str, Any], group: Dict[str, Any], param: nn.Parameter):
This is the number of optimizer steps taken on the parameter, $t$
        state['step'] = 0
Exponential moving average of gradients, $m_t$
        state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format, dtype=torch.float)
Exponential moving average of squared gradient values, $v_t$
        state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format, dtype=torch.float)
Maintain an FP32 copy of the parameters
        state['fp32_copy'] = param.to(torch.float)
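A small check of the state this creates for an FP16 parameter; all the state tensors should come out as FP32. This is only an illustrative sketch: it calls init_state directly, whereas the optimizer normally initializes state lazily on the first step.

p = nn.Parameter(torch.zeros(8, dtype=torch.float16))
opt = AdamFP16([p])
state, group = opt.state[p], opt.param_groups[0]
opt.init_state(state, group, p)

assert state['exp_avg'].dtype == torch.float32
assert state['exp_avg_sq'].dtype == torch.float32
assert state['fp32_copy'].dtype == torch.float32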
* state is the optimizer state of the parameter (tensor)
* group stores optimizer attributes of the parameter group
* grad is the current gradient tensor $g_t$ for the parameter $\theta_{t-1}$
* param is the parameter tensor $\theta_{t-1}$

    def step_param(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor, param: torch.nn.Parameter):
Get the FP32 parameters
        param_fp32 = state['fp32_copy']
Get the FP32 gradients if available
        grad_fp32 = self.grad_fp32.get(param, None)
        if grad_fp32 is not None:
            del self.grad_fp32[param]
            grad = grad_fp32
        else:
Otherwise, convert the gradients to FP32
            grad = grad.to(torch.float)
Calculate weight decay
        grad = self.weight_decay(param_fp32, grad, group)
Get $m_t$ and $v_t$
        m, v = self.get_mv(state, group, grad)
Increment $t$, the number of optimizer steps
        state['step'] += 1
Perform the Adam update
        self.adam_update(state, group, param_fp32, m, v)
Set the parameters by casting the updated FP32 copy back to the parameter's dtype
        param.data = param_fp32.to(param.dtype)
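Note that this final cast can round the update away in the FP16 copy, while the FP32 master copy still accumulates it, so repeated small updates eventually become visible in FP16. A minimal standalone sketch of that behaviour:

master = torch.tensor([1.0])           # FP32 master copy, like state['fp32_copy']
param = master.to(torch.float16)       # FP16 working copy

master -= 1e-4                         # update applied in FP32
param.data = master.to(param.dtype)    # cast back, rounds to the nearest FP16 value

print(master)  # tensor([0.9999]) -- the master copy keeps the update
print(param)   # tensor([1.], dtype=torch.float16) -- rounded away this step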
We extend the PyTorch gradient scaler so that it passes FP32 gradients to the AdamFP16 optimizer.
class GradScalerFP16(grad_scaler.GradScaler):
    def _unscale_grads_(self, optimizer: Optimizer, inv_scale: torch.Tensor, found_inf: torch.Tensor,
                        allow_fp16: bool) -> Dict[torch.device, torch.Tensor]:
        per_device_inv_scale = grad_scaler._MultiDeviceReplicator(inv_scale)
        per_device_found_inf = grad_scaler._MultiDeviceReplicator(found_inf)

        per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))  # type: ignore[var-annotated]

        with torch.no_grad():
Loop through parameters
            for group in optimizer.param_groups:
                for param in group["params"]:
Skip parameters that have no gradient
                    if param.grad is None:
                        continue
Not implemented for sparse tensors
                    if param.grad.is_sparse:
                        raise NotImplementedError
If we are using the AdamFP16 optimizer, set optimizer.grad_fp32[param] to the FP32 gradients
                    if isinstance(optimizer, AdamFP16):
                        grad = param.grad.to(torch.float)
                        optimizer.grad_fp32[param] = grad
Otherwise, do not convert the gradients to FP32
                    else:
                        grad = param.grad

                    per_device_and_dtype_grads[grad.device][grad.dtype].append(grad)
Unscale all the gradients
        for device, per_dtype_grads in per_device_and_dtype_grads.items():
            for grads in per_dtype_grads.values():
                torch._amp_foreach_non_finite_check_and_unscale_(grads,
                                                                 per_device_found_inf.get(device),
                                                                 per_device_inv_scale.get(device))
        return per_device_found_inf._per_device_tensors
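Here is a minimal sketch of how the two classes might be wired together in a mixed-precision training loop. It assumes a CUDA device, an FP16 model, and a hypothetical data_loader that yields FP16 batches on the GPU; none of these are part of the implementation above.

model = nn.Linear(16, 4).to(torch.float16).cuda()
optimizer = AdamFP16(model.parameters(), lr=1e-3)
scaler = GradScalerFP16()

for x, y in data_loader:  # hypothetical loader of FP16 tensors on the GPU
    optimizer.zero_grad()
    loss = (model(x) - y).pow(2).mean()
    # Scale the loss to avoid FP16 gradient underflow, then backpropagate
    scaler.scale(loss).backward()
    # step() unscales via _unscale_grads_, which also fills optimizer.grad_fp32
    scaler.step(optimizer)
    scaler.update()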