# Zero3 memory optimizations
[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/scaling/zero3/__init__.py)
This is an implementation of Zero-DP introduced in the paper [ZeRO: Memory Optimizations Toward Training Trillion Parameter Models](https://arxiv.org/abs/1910.02054).

It keeps shards of the optimizer state, gradients and parameters across multiple devices/nodes. It reduces the per-device memory consumption to $\frac{(2 + 2 + K)\Psi}{N_d}$ from the original $(2 + 2 + K)\Psi$, where $\Psi$ is the number of parameters, $N_d$ is the number of shards, and $K$ is the number of optimizer bytes per parameter. The $2 + 2$ accounts for parameter and gradient memory assuming 16-bit precision, i.e. 2 bytes each per parameter. $K = 12$ for the Adam optimizer, because it maintains an fp32 copy of the parameters plus two moments per parameter, also in fp32.

The communication volume of Zero-DP is $\mathcal{O}(3\Psi)$. For comparison, data-parallel training has a communication volume of $\mathcal{O}(2\Psi)$.
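As a back-of-the-envelope check of the memory formula, here is a small sketch with hypothetical numbers (a 1B-parameter model on 64 devices; these figures are illustrative, not from the paper):

```python
# Hypothetical: 1B parameters, Adam optimizer (K = 12), 64 shards
psi = 1_000_000_000
K = 12
n_d = 64

baseline_bytes = (2 + 2 + K) * psi        # unsharded: 16 GB
per_device_bytes = baseline_bytes / n_d   # with Zero-DP: 0.25 GB per device
print(f'{baseline_bytes / 1e9:.2f} GB -> {per_device_bytes / 1e9:.2f} GB per device')
```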
Although this is named `Zero3`, we have only implemented the Zero-DP part of the paper and not the Zero-R memory optimizations, which target residual memory consumption. Our implementation supports training only a subset of the parameters.
This implementation is inspired by Fairscale FSDP.
Here's a script to fine-tune GPT NeoX using Zero-DP memory optimization.
```python
import functools
from typing import List, Optional, Tuple

import torch
import torch.distributed as dist
from torch import nn
```
Each layer of the model (or a combination of a few consecutive layers) should be wrapped in this module.
```python
class Zero3Layer(nn.Module):
```
Each shard keeps its parameters in the `chunk` list. `chunk[0]` holds the trainable parameters and `chunk[1]` holds the fixed parameters.
```python
    chunk: List[nn.Parameter]
```
The sizes of the chunks in the `chunk` list.
```python
    chunk_size: List[int]
```
The first chunk is for trainable parameters.
```python
    TRAINING_PARAMS_IDX = 0
```
The parameters of the layer, split into two lists: trainable and fixed.
```python
    param_refs: List[List[nn.Parameter]]
```
CUDA stream to fetch parameters
```python
    fetch_stream: Optional[torch.cuda.Stream]
```
CUDA stream to back up/accumulate gradients
```python
    backup_stream: Optional[torch.cuda.Stream]
```
List of layers right before this layer
```python
    prev_layer: List['Zero3Layer']
```
List of layers right after this layer
```python
    next_layer: List['Zero3Layer']
```
The position of the current layer; used for debugging logs
```python
    layer_idx: int
```
Whether parameters have been fetched
```python
    is_fetched: bool
```
Device of the layer
```python
    device: torch.device
```
Data type of the layer
```python
    dtype: torch.dtype
```
The module to be wrapped
```python
    module: nn.Module
```
Number of nodes/devices the data is sharded across
```python
    world_size: int
```
* `module` The module to be wrapped.
* `rank` The rank of the current node.
* `world_size` The number of nodes/devices the data is sharded across.
* `device` The device of the layer.
* `dtype` The data type of the layer.

```python
    def __init__(self, module: nn.Module, rank: int, world_size: int, device: torch.device, dtype: torch.dtype):
```
```python
        super().__init__()
```
Initialize the properties
```python
        self.device = device
        self.dtype = dtype
        self.module = module
        self.prev_layer = []
        self.next_layer = []
        self.is_fetched = False
        self.world_size = world_size
        self.layer_idx = -1
        self.fetch_stream = None
        self.backup_stream = None

        with torch.no_grad():
```
Collect all the parameters of the layer
```python
            all_param_refs = [p for p in self.parameters()]
```
Store the shape of the parameters because we need it later to reconstruct them
```python
            for p in all_param_refs:
                p._orig_shape = p.shape
```
All parameters should have the same type
```python
            for p in all_param_refs:
                assert p.dtype == dtype, "All parameters should have same dtype"
```
Separate parameters as trainable and fixed
```python
            self.param_refs = [[p for p in all_param_refs if p.requires_grad],
                               [p for p in all_param_refs if not p.requires_grad]]
            del all_param_refs
```
The rank = 0 node will calculate the size each device/node should store, and distribute the parameters accordingly.
```python
            if rank == 0:
```
Merge and pad trainable (`merged_params[0]`) and fixed (`merged_params[1]`) parameters
```python
                merged_params = [self._merge_and_pad_params(ps) for ps in self.param_refs]
```
Calculate the chunk sizes of trainable and fixed params
```python
                self.chunk_size = [(len(p) // world_size if p is not None else 0) for p in merged_params]
```
Broadcast the sizes
```python
                dist.broadcast(torch.tensor(self.chunk_size, device=device), src=0)
            else:
```
Create an empty tensor to receive the sizes
```python
                chunk_size = torch.tensor([0, 0], device=device)
```
Receive the sizes
```python
                dist.broadcast(chunk_size, src=0)
                self.chunk_size = chunk_size.tolist()
```
Create parameters for trainable (`self.chunk[0]`) and fixed (`self.chunk[1]`) parameters to be stored on the current device/node
```python
            self.chunk = [nn.Parameter(self._empty((s,)), requires_grad=i == self.TRAINING_PARAMS_IDX)
                          for i, s in enumerate(self.chunk_size)]
```
An empty tensor to receive the trainable and fixed parameters combined
```python
            chunk = self._empty((sum(self.chunk_size),))

            if rank == 0:
```
Concatenate both trainable and fixed params
```python
                all_params = torch.cat([p.view(world_size, -1) for p in merged_params], dim=-1).view(-1)
                del merged_params
```
Scatter them to all the nodes/devices
```python
                dist.scatter(chunk, list(all_params.split(sum(self.chunk_size))))
                del all_params
            else:
```
Receive the parameters
```python
                dist.scatter(chunk)
```
Collect the chunk data
```python
            chunk = chunk.split(self.chunk_size)
            for i, c in enumerate(chunk):
                self.chunk[i].data[:] = c
            del chunk
```
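To see the interleaved layout that this produces, here is a single-process simulation of the merge-and-scatter step with hypothetical sizes (no `torch.distributed` involved; the values are made distinguishable on purpose):

```python
import torch

# Hypothetical: 8 trainable and 4 fixed (already merged and padded) values, 2 shards
world_size = 2
merged_params = [torch.arange(8.), torch.arange(4.) + 100]
chunk_size = [len(p) // world_size for p in merged_params]  # [4, 2]

# Same layout trick as above: each row holds one rank's trainable + fixed slices
all_params = torch.cat([p.view(world_size, -1) for p in merged_params], dim=-1).view(-1)
shards = list(all_params.split(sum(chunk_size)))
# shards[0] -> [0, 1, 2, 3, 100, 101]; shards[1] -> [4, 5, 6, 7, 102, 103]
```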
Clean up the normal parameters
```python
        self._cleanup_params()
```
Add a backward hook. This gets called when the gradients with respect to the module are computed.
```python
        self._backward_hook_ref = self.register_full_backward_hook(self._backward_hook)  # type: ignore
```
world_size .168def\_merge\_and\_pad\_params(self,params:List[nn.Parameter])-\>torch.Tensor:
Total number of parameters
```python
        size = sum(p.shape.numel() for p in params)
```
If it is not divisible by `world_size`, pad it
```python
        if size % self.world_size != 0:
            padding_fixed = self.world_size - (size % self.world_size)
```
Otherwise, no need to pad
```python
        else:
            padding_fixed = 0
```
Create an empty padding tensor
```python
        padding = self._empty((padding_fixed,))
```
Concatenate all the parameters and pad them
```python
        return torch.cat([p.view(-1) for p in params] + [padding], dim=0)
```
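A quick illustrative check of the padding arithmetic (hypothetical sizes):

```python
# Hypothetical: 10 parameter elements sharded across 4 devices
size, world_size = 10, 4
padding = world_size - (size % world_size) if size % world_size != 0 else 0
assert padding == 2 and (size + padding) % world_size == 0
```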
Get the trainable chunk of the parameters. This is what we pass on to the optimizer on the current node.
```python
    def get_trainable_chunk(self) -> List[nn.Parameter]:
```
Return an empty list if there are no trainable parameters
```python
        if len(self.chunk[self.TRAINING_PARAMS_IDX]) == 0:
            return []
```
Return the trainable chunk as a list
```python
        return [self.chunk[self.TRAINING_PARAMS_IDX]]
```
Create an empty tensor on the layer's device with its data type.

```python
    def _empty(self, shape: Tuple[int, ...]) -> torch.Tensor:
```
```python
        return torch.empty(shape, device=self.device, dtype=self.dtype)
```
This will release all the memory used by the layer parameters.
```python
    @torch.no_grad()
    def _cleanup_params(self):
```
Set the flag to indicate that the parameters are not fetched
```python
        self.is_fetched = False
```
Iterate through all parameters
```python
        for ps in self.param_refs:
            for p in ps:
```
Wait for operations on the parameters to complete before any new operations
```python
                p.data.record_stream(torch.cuda.current_stream())
```
Check to make sure the parameter is not sharing storage with anything else
```python
                assert p.data.storage_offset() == 0, "The tensor is not the sole occupant of the storage."
```
Resize the storage to 0. This will release the memory used by the parameter.
Setting `p.data` will not release the memory, since the autograd graph keeps a reference to it.
```python
                p.data.storage().resize_(0)  # This is what actually clears the memory
```
Make sure the parameter has no gradient data
```python
                assert p.grad is None, 'Gradients should be None'
```
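The storage trick can be seen in isolation in the sketch below (recent PyTorch versions; `Tensor.storage()` may emit a deprecation warning in favor of `untyped_storage()` on the newest releases):

```python
import torch
from torch import nn

p = nn.Parameter(torch.ones(1024))
p.data.storage().resize_(0)        # frees the backing memory; tensor metadata survives
assert p.data.storage().size() == 0
p.data.storage().resize_(1024)     # re-allocate (uninitialized) before restoring values
```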
This will fetch all the parameter data from all the nodes and rebuild the parameters on each node.
```python
    @torch.no_grad()
    def fetch_params(self):
```
Skip if already fetched
```python
        if self.is_fetched:
            return
```
Set the flag
```python
        self.is_fetched = True
```
Skip if there's nothing to fetch or share.
```python
        if sum(self.chunk_size) == 0:
            return
```
Use `fetch_stream` to fetch the parameters from all the shards
```python
        with torch.cuda.stream(self.fetch_stream):
```
Create an empty tensor to receive the parameters
```python
            buffer = self._empty((self.world_size * sum(self.chunk_size),))
```
Split the contiguous buffer by the number of nodes. These splits are views of `buffer`.
```python
            buffers = list(buffer.split(sum(self.chunk_size)))
```
Concatenate both trainable and fixed chunks
```python
            chunk = torch.cat(self.chunk, dim=0)
```
Gather the parameters from all the nodes/devices
```python
            dist.all_gather(buffers, chunk)
```
Split the gathered parameters into the trainable and fixed chunks
```python
            params = buffer.view(-1, sum(self.chunk_size)).split(self.chunk_size, dim=1)
```
Wait for the gather operation to complete and then clear the references to the buffers
```python
            buffer.record_stream(self.fetch_stream)
            for b in buffers:
                b.record_stream(self.fetch_stream)
            del buffer
            del buffers
```
Reshape the trainable and fixed parameters to contiguous tensors
```python
            params = [p.reshape(-1) for p in params]
```
Collect the individual parameter tensors
```python
            for cont, ps in zip(params, self.param_refs):
```
If there are no parameters, skip
```python
                if not ps:
                    continue
```
Offset into the contiguous tensor
```python
                offset = 0
```
Iterate through the model parameters and assign the values from the contiguous tensor
```python
                for p in ps:
```
Original parameter shape
```python
                    shape = p._orig_shape  # type: ignore[attr-defined]
```
Change the storage size of the parameter. This was set to 0 when we cleaned up the parameters.
```python
                    p.data.storage().resize_(shape.numel())
```
Assign the values from the contiguous tensor
```python
                    p.data[:] = cont[offset:offset + shape.numel()].reshape(shape)
```
Wait for the operations to complete before other operations can be performed
```python
                    p.data.record_stream(self.fetch_stream)
```
Update the offset
```python
                    offset += shape.numel()
```
Wait for the operation to complete before other operations can be performed
```python
                cont.record_stream(self.fetch_stream)
```
```python
            del params
```
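Here is the reassembly logic in isolation, simulated on one process with hypothetical sizes (the `dist.all_gather` call is replaced by a plain concatenation):

```python
import torch

# Hypothetical: 2 ranks, each holding a trainable chunk of 4 and a fixed chunk of 2
world_size, chunk_size = 2, [4, 2]
shards = [torch.arange(6.) + 10 * r for r in range(world_size)]  # per-rank chunks

buffer = torch.cat(shards)  # what all_gather would leave in `buffer`

# Same reshaping as above: one row per rank, then split the columns by chunk type
params = buffer.view(-1, sum(chunk_size)).split(chunk_size, dim=1)
trainable = params[0].reshape(-1)  # all trainable values in merge order
fixed = params[1].reshape(-1)      # all fixed values in merge order
```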
```python
    def forward(self, *args, **kwargs):
```
Fetch all the parameters of the current layer. The previous layer will have already started this fetch, so this call just makes sure the parameters are being fetched.
```python
        self.fetch_params()
```
Wait for parameter fetching to complete.
```python
        torch.cuda.current_stream().wait_stream(self.fetch_stream)
```
Start fetching parameters of the succeeding layers, so that they are fetched while the current layer does its computations.
```python
        for layer in self.next_layer:
            layer.fetch_params()
```
Add backward hooks to the parameters of the current layer if autograd is enabled.
```python
        if torch.is_grad_enabled():
            self._add_backward_hooks()
```
Compute the outputs of the current layer
```python
        res = self.module(*args, **kwargs)
```
Clean up the parameters of the layer.
Skip cleaning up if autograd is enabled and this is the last layer in the network, because we will need to fetch the parameters again for the backward pass.
```python
        if not torch.is_grad_enabled() or self.next_layer:
            self._cleanup_params()

        return res
```
Add backward hooks to the trainable parameters of the layer.

```python
    def _add_backward_hooks(self):
```
Number of backward hooks added
```python
        self._backward_hook_handles = 0
```
Loop through trainable parameters of the current layer
```python
        for p in self.param_refs[self.TRAINING_PARAMS_IDX]:
```
Make sure a hook hasn't already been added
```python
            assert not hasattr(p, "_hook_handle"), 'Parameter has already been hooked'
```
Use `expand_as` to create an autograd step that we can intercept
```python
            p_tmp = p.expand_as(p)
```
Get a handle to add the backward hook; `grad_acc` is the gradient accumulation node for the parameter.
```python
            grad_acc = p_tmp.grad_fn.next_functions[0][0]
```
Add the backward hook
```python
            handle = grad_acc.register_hook(
                functools.partial(self._post_backward_hook, p))
```
Keep a reference to the handle
```python
            p._hook_handle = handle
```
Increment the number of hooks added
```python
            self._backward_hook_handles += 1
```
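The `expand_as` trick is easier to see in a standalone sketch; this hypothetical example hooks the gradient accumulation node of a single parameter:

```python
import torch
from torch import nn

p = nn.Parameter(torch.ones(3))
p_tmp = p.expand_as(p)                         # adds an autograd node without copying data
grad_acc = p_tmp.grad_fn.next_functions[0][0]  # the AccumulateGrad node for p
handle = grad_acc.register_hook(lambda *args: print('grad for p is ready'))
(p * 2).sum().backward()                       # the hook fires when p's gradient lands
handle.remove()
```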
This gets called by parameter backward hooks and the module backward hook.
```python
    def _backward_event(self):
```
Decrement the hooks counter
```python
        self._backward_hook_handles -= 1
```
If all the hooks (including the module hook) have been called, then we can back up gradients and clean up the parameters.
```python
        if self._backward_hook_handles == -1:
            self._backup_grads()
            self._cleanup_params()
```
Start fetching the parameters of the previous layer, because autograd will process its gradients next.
```python
        for layer in self.prev_layer:
            layer.fetch_params()
```
```python
    def _post_backward_hook(self, p: nn.Parameter, *args):
```
Remove the handle from the parameter
```python
        p._hook_handle.remove()  # type: ignore[attr-defined]
        delattr(p, "_hook_handle")
```
Handle a backward event
```python
        self._backward_event()
```
```python
    def _backward_hook(self, *args, **kwargs):
```
Handle a backward event
```python
        self._backward_event()
```
The previous layer will start computing gradients. We need to make sure it has finished fetching params.
```python
        torch.cuda.current_stream().wait_stream(self.fetch_stream)
```
```python
        return None
```
```python
    @torch.no_grad()
    def _backup_grads(self):
```
Skip if there are no trainable parameters
```python
        if self.chunk_size[self.TRAINING_PARAMS_IDX] == 0:
            return
```
Use the backup stream to back up the gradients
```python
        with torch.cuda.stream(self.backup_stream):
```
Buffer to store the gradients
```python
            buffer = self._empty((self.world_size * self.chunk_size[self.TRAINING_PARAMS_IDX],))
```
Split the contiguous buffer by the number of nodes. These splits are views of `buffer`.
```python
            buffers = list(buffer.split(self.chunk_size[self.TRAINING_PARAMS_IDX]))
```
Offset into the contiguous buffer
```python
            offset = 0
```
Iterate through trainable parameters
```python
            for p in self.param_refs[self.TRAINING_PARAMS_IDX]:
```
Collect gradients
```python
                shape = p._orig_shape  # type: ignore[attr-defined]
                buffer[offset:offset + shape.numel()] = p.grad.view(-1)
```
Update the offset
```python
                offset += shape.numel()
```
Clear the gradients
```python
                p.grad = None
```
Empty tensor to accumulate the gradients of the current shard
```python
            grad = self._empty((self.chunk_size[self.TRAINING_PARAMS_IDX],))
```
Accumulate the gradients of each shard. It scatters the buffers across the nodes, and each node accumulates (reduces) the tensors it receives.
```python
            dist.reduce_scatter(grad, buffers)
```
Wait for the operation to complete and then clear the references to the buffers
```python
            for b in buffers:
                b.record_stream(self.backup_stream)
            buffer.record_stream(self.backup_stream)
            del buffer
            del buffers
```
Set the chunk gradients. This is what the optimizer sees.
```python
            self.chunk[self.TRAINING_PARAMS_IDX].grad = grad
            del grad
```
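A single-process simulation of what `reduce_scatter` computes here, with hypothetical values:

```python
import torch

# Hypothetical: full gradient buffers from 2 ranks, trainable shard of 3 elements each
world_size, shard_size = 2, 3
grads = [torch.full((world_size * shard_size,), float(r + 1)) for r in range(world_size)]

# reduce_scatter sums the buffers element-wise and hands rank r the r-th slice
summed = torch.stack(grads).sum(dim=0)
per_rank = list(summed.split(shard_size))  # each slice equals tensor([3., 3., 3.])
```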
Sequential module of `Zero3Layer` layers.

```python
class Zero3Sequential(nn.Module):
```
* `modules` List of `Zero3Layer` layers

```python
    def __init__(self, modules: List[Zero3Layer]):
```
```python
        super().__init__()
```
CUDA stream to fetch parameters
```python
        self.fetch_stream = torch.cuda.Stream()
```
CUDA stream to back up (accumulate) gradients
```python
        self.backup_stream = torch.cuda.Stream()
```
Set the streams and the preceding and succeeding layers for each `Zero3Layer`
```python
        for i in range(len(modules)):
```
Set layer index
```python
            modules[i].layer_idx = i
```
Set streams
```python
            modules[i].fetch_stream = self.fetch_stream
            modules[i].backup_stream = self.backup_stream
```
Set succeeding layers
```python
            if i + 1 < len(modules):
                modules[i].next_layer.append(modules[i + 1])
```
Set preceding layers
```python
            if i - 1 >= 0:
                modules[i].prev_layer.append(modules[i - 1])
```
Store list of modules
```python
        self.module_list = nn.ModuleList(modules)
```
```python
    def get_trainable_chunk(self):
```
Return the list of trainable chunks from each layer
```python
        return sum([m.get_trainable_chunk() for m in self.module_list], [])
```
```python
    def forward(self, x: torch.Tensor):
```
Make sure the gradient backup is complete
```python
        torch.cuda.current_stream().wait_stream(self.backup_stream)
```
Forward pass
```python
        for m in self.module_list:
            x = m(x)
```
```python
        return x
```
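Finally, a minimal sketch of how these classes might be wired together on one rank. The layer sizes, the process-group setup, and the use of Adam directly on the fp16 shard are illustrative assumptions (a real setup would keep fp32 optimizer state), not part of this implementation:

```python
import torch
from torch import nn


def train_step(rank: int, world_size: int):
    # Assumes torch.distributed has already been initialized for this process
    device = torch.device('cuda', rank)

    # Wrap each layer (or group of consecutive layers) in a Zero3Layer
    layers = [Zero3Layer(nn.Linear(256, 256).to(device), rank, world_size, device, torch.float16)
              for _ in range(4)]
    model = Zero3Sequential(layers)

    # The optimizer only sees this node's trainable shard
    optimizer = torch.optim.Adam(model.get_trainable_chunk())

    x = torch.randn(8, 256, device=device, dtype=torch.float16)
    loss = model(x).float().square().mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```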