
Zero-DP Memory Optimization


[View code on GitHub](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/scaling/zero3/__init__.py)


This is an implementation of Zero-DP, introduced in the paper ZeRO: Memory Optimizations Toward Training Trillion Parameter Models.

It keeps shards of the optimizer state, gradients and parameters across multiple devices/nodes. This reduces the memory consumption per device to $\frac{(2 + 2 + K)\Psi}{N_d}$ bytes, where $\Psi$ is the number of parameters, $N_d$ is the number of shards, and $K$ is the number of optimizer-state bytes per parameter. The $2 + 2$ term is the parameter and gradient memory assuming 16-bit precision, i.e. 2 bytes per parameter and 2 bytes per gradient. $K = 12$ for the Adam optimizer because it maintains an fp32 copy of the parameters and two fp32 moments per parameter ($3 \times 4$ bytes).
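As a quick sanity check of the memory formula, the sketch below computes the model-state bytes per device with and without sharding. The function names are hypothetical, and the inputs (7.5B parameters over 64 devices, Adam with $K = 12$) are illustrative numbers, not measurements.

```python
def zero_dp_bytes_per_device(num_params: int, num_shards: int, k: int = 12) -> float:
    """Hypothetical helper: (2 + 2 + K) * Psi / N_d bytes of model state per device."""
    # 2 bytes fp16 parameter + 2 bytes fp16 gradient + K bytes optimizer state
    bytes_per_param = 2 + 2 + k
    return bytes_per_param * num_params / num_shards


def data_parallel_bytes_per_device(num_params: int, k: int = 12) -> int:
    """Hypothetical helper: plain data parallelism replicates all model states."""
    return (2 + 2 + k) * num_params


psi = 7_500_000_000  # illustrative parameter count
n_d = 64             # illustrative number of shards
print(f"data parallel: {data_parallel_bytes_per_device(psi) / 1e9:.1f} GB per device")
print(f"Zero-DP:       {zero_dp_bytes_per_device(psi, n_d) / 1e9:.3f} GB per device")
```

With these inputs the 120 GB of model states per device under plain data parallelism drops to 1.875 GB.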

The communication volume of Zero-DP is $\mathcal{O}(3\Psi)$. For comparison, data-parallel training has a communication volume of $\mathcal{O}(2\Psi)$.
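The $3\Psi$ versus $2\Psi$ figures come from counting collectives per training step, assuming each collective moves about $\Psi$ elements per device (the usual ring-based accounting). Data parallelism does one all-reduce of gradients, which is a reduce-scatter plus an all-gather. Zero-DP all-gathers the parameters before the forward pass, all-gathers them again before the backward pass, and reduce-scatters the gradients. A minimal sketch of that bookkeeping, with hypothetical function names:

```python
def data_parallel_volume(psi: int) -> int:
    # Hypothetical helper: all-reduce of gradients = reduce-scatter + all-gather
    phases = {"reduce_scatter_grads": psi, "all_gather_grads": psi}
    return sum(phases.values())


def zero_dp_volume(psi: int) -> int:
    # Hypothetical helper: parameters are gathered for the forward pass and
    # again for the backward pass; gradients go to their owning shard
    phases = {
        "all_gather_params_forward": psi,
        "all_gather_params_backward": psi,
        "reduce_scatter_grads": psi,
    }
    return sum(phases.values())


psi = 1_000_000
print(zero_dp_volume(psi) / data_parallel_volume(psi))  # 1.5
```

So Zero-DP trades a 1.5× increase in communication for the memory savings.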

Although this is named `Zero3`, we have only implemented the Zero-DP part of it and not the Zero-R memory optimizations, which target residual memory consumption. Our implementation supports training only a subset of the parameters.

This implementation is inspired by FairScale FSDP.

Here's a script to fine-tune GPT-NeoX using Zero-DP memory optimization.

```python
import functools
from typing import List, Optional, Tuple

import torch
import torch.distributed as dist
from torch import nn
```

Zero3 Layer

Each layer of the model (or a combination of a few consecutive layers) should be wrapped in this module.

```python
class Zero3Layer(nn.Module):
```

```python
    # Each shard keeps its parameters in the `chunk` list:
    # `chunk[0]` holds the trainable parameters and `chunk[1]` the fixed parameters
    chunk: List[nn.Parameter]
    # The sizes of the chunks in the `chunk` list
    chunk_size: List[int]
    # The first chunk is for trainable parameters
    TRAINING_PARAMS_IDX = 0
    # The parameters, split into a list of trainable and a list of fixed parameters
    param_refs: List[List[nn.Parameter]]
    # CUDA stream to fetch parameters
    fetch_stream: Optional[torch.cuda.Stream]
    # CUDA stream to back up/accumulate gradients
    backup_stream: Optional[torch.cuda.Stream]
    # List of layers right before this layer
    prev_layer: List['Zero3Layer']
    # List of layers right after this layer
    next_layer: List['Zero3Layer']
    # The position of the current layer; used for debugging logs
    layer_idx: int
    # Whether the parameters have been fetched
    is_fetched: bool
    # Device of the layer
    device: torch.device
    # Data type of the layer
    dtype: torch.dtype
    # The module to be wrapped
    module: nn.Module
    # Number of nodes/devices the data is sharded across
    world_size: int
```

  • `module`: the module to be wrapped.
  • `rank`: the rank of the current node.
  • `world_size`: the number of nodes/devices the data is sharded across.
  • `device`: the device of the layer.
  • `dtype`: the data type of the layer.

```python
    def __init__(self, module: nn.Module, rank: int, world_size: int, device: torch.device, dtype: torch.dtype):
```

```python
        super().__init__()
```

Initialize the properties.

```python
        self.device = device
        self.dtype = dtype
        self.module = module
        self.prev_layer = []
        self.next_layer = []
        self.is_fetched = False
        self.world_size = world_size
        self.layer_idx = -1
        self.fetch_stream = None
        self.backup_stream = None

        with torch.no_grad():
```

Collect all the parameters of the layer.

```python
            all_param_refs = [p for p in self.parameters()]
```

Store the shapes of the parameters, because we need them later to reconstruct the parameters.

```python
            for p in all_param_refs:
                p._orig_shape = p.shape
```

All parameters should have the same data type.

```python
            for p in all_param_refs:
                assert p.dtype == dtype, "All parameters should have same dtype"
```

Separate the parameters into trainable and fixed.

```python
            self.param_refs = [[p for p in all_param_refs if p.requires_grad],
                               [p for p in all_param_refs if not p.requires_grad]]
            del all_param_refs
```

The node with rank 0 calculates the size each device/node should store, and distributes the parameters accordingly.

```python
            if rank == 0:
```

Merge and pad the trainable (`merged_params[0]`) and fixed (`merged_params[1]`) parameters.

```python
                merged_params = [self._merge_and_pad_params(ps) for ps in self.param_refs]
```

Calculate the chunk sizes of the trainable and fixed parameters.

```python
                self.chunk_size = [(len(p) // world_size if p is not None else 0) for p in merged_params]
```

Broadcast the sizes.

```python
                dist.broadcast(torch.tensor(self.chunk_size, device=device), src=0)
            else:
```

Create an empty tensor to receive the sizes.

```python
                chunk_size = torch.tensor([0, 0], device=device)
```

Receive the sizes.

```python
                dist.broadcast(chunk_size, src=0)
                self.chunk_size = chunk_size.tolist()
```

Create parameters for the trainable (`self.chunk[0]`) and fixed (`self.chunk[1]`) parameters to be stored on the current device/node.

```python
            self.chunk = [nn.Parameter(self._empty((s,)), requires_grad=i == self.TRAINING_PARAMS_IDX)
                          for i, s in enumerate(self.chunk_size)]
```

An empty tensor to receive the trainable and fixed parameters combined.

```python
            chunk = self._empty((sum(self.chunk_size),))

            if rank == 0:
```

Concatenate both the trainable and fixed parameters. Each padded tensor is viewed as `world_size` rows, and the rows are concatenated so that row $i$ holds the trainable shard followed by the fixed shard for rank $i$.

```python
                all_params = torch.cat([p.view(world_size, -1) for p in merged_params], dim=-1).view(-1)
                del merged_params
```

Scatter them to all the nodes/devices.

```python
                dist.scatter(chunk, list(all_params.split(sum(self.chunk_size))))
                del all_params
            else:
```

Receive the parameters.

```python
                dist.scatter(chunk)
```

Collect the chunk data.

```python
            chunk = chunk.split(self.chunk_size)
            for i, c in enumerate(chunk):
                self.chunk[i].data[:] = c
            del chunk
```
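The interleaving done by `p.view(world_size, -1)` followed by `torch.cat(..., dim=-1)` is what makes each scattered piece contain one trainable shard followed by one fixed shard. The sketch below reproduces that layout with plain Python lists instead of tensors; `shard_layout` is a hypothetical helper for illustration, not part of the implementation.

```python
from typing import List, Tuple


def shard_layout(trainable: List[float], fixed: List[float],
                 world_size: int) -> List[Tuple[List[float], List[float]]]:
    """Hypothetical helper: mimic how rank 0 lays out merged (padded) params for scatter."""
    assert len(trainable) % world_size == 0 and len(fixed) % world_size == 0
    chunk_size = [len(trainable) // world_size, len(fixed) // world_size]
    # Equivalent of p.view(world_size, -1): one row per rank
    rows_t = [trainable[r * chunk_size[0]:(r + 1) * chunk_size[0]] for r in range(world_size)]
    rows_f = [fixed[r * chunk_size[1]:(r + 1) * chunk_size[1]] for r in range(world_size)]
    # Equivalent of torch.cat(..., dim=-1).view(-1): row r = trainable shard r + fixed shard r
    all_params = [x for r in range(world_size) for x in rows_t[r] + rows_f[r]]
    # Equivalent of all_params.split(sum(chunk_size)): the piece sent to each rank
    step = sum(chunk_size)
    pieces = [all_params[r * step:(r + 1) * step] for r in range(world_size)]
    # Each rank splits its piece by chunk_size into chunk[0] and chunk[1]
    return [(piece[:chunk_size[0]], piece[chunk_size[0]:]) for piece in pieces]


trainable = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]  # 6 trainable values (already padded)
fixed = [10.0, 11.0, 12.0, 13.0]            # 4 fixed values (already padded)
for rank, (t, f) in enumerate(shard_layout(trainable, fixed, world_size=2)):
    print(rank, t, f)
```

Rank 0 receives `[0.0, 1.0, 2.0]` and `[10.0, 11.0]`, rank 1 receives `[3.0, 4.0, 5.0]` and `[12.0, 13.0]`: each rank owns a contiguous slice of both the trainable and the fixed parameters.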

Clean up the normal parameters.

```python
            self._cleanup_params()
```

Add a backward hook. This gets called when the gradients with respect to the module are computed.

```python
            self._backward_hook_ref = self.register_full_backward_hook(self._backward_hook)  # type: ignore
```

Merge all the parameters and pad them so that the total size is divisible by `world_size`.

```python
    def _merge_and_pad_params(self, params: List[nn.Parameter]) -> torch.Tensor:
```

Total number of parameters

```python
        size = sum(p.shape.numel() for p in params)
```

If it is not divisible by `world_size`, pad it

```python
        if size % self.world_size != 0:
            padding_fixed = self.world_size - (size % self.world_size)
```

Otherwise, no need to pad

```python
        else:
            padding_fixed = 0
```

Create an empty padding tensor

```python
        padding = self._empty((padding_fixed,))
```

Concatenate all the parameters and pad them

```python
        return torch.cat([p.view(-1) for p in params] + [padding], dim=0)
```
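The padding rule in `_merge_and_pad_params` rounds the merged length up to the next multiple of `world_size`, and the constructor then divides by `world_size` to get each shard's chunk size. A small arithmetic sketch; `padding_size` and `chunk_size` are hypothetical helper names, and the layer shape is illustrative:

```python
def padding_size(size: int, world_size: int) -> int:
    # Hypothetical helper, same rule as `_merge_and_pad_params`
    if size % world_size != 0:
        return world_size - (size % world_size)
    return 0


def chunk_size(size: int, world_size: int) -> int:
    # Hypothetical helper: merged, padded length divided evenly across shards
    return (size + padding_size(size, world_size)) // world_size


# e.g. a layer with a 10x10 weight and a 10-element bias, sharded over 4 devices
size = 10 * 10 + 10
print(padding_size(size, 4), chunk_size(size, 4))  # 2 28
```

At most `world_size - 1` padding elements are added to each merged group (trainable or fixed), so the overhead is negligible for large layers.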

Get the trainable chunk/shard of the parameters.

This is what we pass on to the optimizer on the current node.

```python
    def get_trainable_chunk(self) -> List[nn.Parameter]:
```

Return an empty list if there are no trainable parameters

```python
        if len(self.chunk[self.TRAINING_PARAMS_IDX]) == 0:
            return []
```

Return the trainable chunk as a list

```python
        return [self.chunk[self.TRAINING_PARAMS_IDX]]
```

Create an empty tensor of the given shape.

```python
    def _empty(self, shape: Tuple[int, ...]) -> torch.Tensor:
        return torch.empty(shape, device=self.device, dtype=self.dtype)
```

Clean up the parameter data.

This will release all the memory used by the layer parameters.

```python
    @torch.no_grad()
    def _cleanup_params(self):
```

Set the flag to indicate that the parameters are not fetched

```python
        self.is_fetched = False
```

Iterate through all parameters

```python
        for ps in self.param_refs:
            for p in ps:
```

Record the current stream on the parameter's memory, so that the caching allocator does not reuse it before the work already queued on that stream has finished

```python
                p.data.record_stream(torch.cuda.current_stream())
```

Check to make sure the parameter is not sharing storage with anything else

```python
                assert p.data.storage_offset() == 0, "The tensor is not the sole occupant of the storage."
```

Resize the storage to 0. This releases the memory used by the parameter.

Setting `p.data` would not release the memory, since the autograd graph keeps a reference to it.

```python
                p.data.storage().resize_(0)  # This is what actually clears the memory
```

Make sure the parameter has no gradient data

```python
                assert p.grad is None, 'Gradients should be None'
```

Fetch the parameters from all shards.

This fetches all the parameter data from all the nodes and rebuilds the parameters on each node.

```python
    @torch.no_grad()
    def fetch_params(self):
```

Skip if already fetched

```python
        if self.is_fetched:
            return
```

Set the flag

```python
        self.is_fetched = True
```

Skip if there's nothing to fetch or share.

```python
        if sum(self.chunk_size) == 0:
            return
```

Use `fetch_stream` to fetch the parameters from all the shards

```python
        with torch.cuda.stream(self.fetch_stream):
```

Create an empty tensor to receive the parameters

```python
            buffer = self._empty((self.world_size * sum(self.chunk_size),))
```

Split the contiguous buffer into one piece per node. These splits are views of `buffer`.

```python
            buffers = list(buffer.split(sum(self.chunk_size)))
```

Concatenate both the trainable and fixed chunks

```python
            chunk = torch.cat(self.chunk, dim=0)
```

Gather the parameters from all the nodes/devices

```python
            dist.all_gather(buffers, chunk)
```

Split the gathered parameters into the trainable and fixed chunks

```python
            params = buffer.view(-1, sum(self.chunk_size)).split(self.chunk_size, dim=1)
```

Record the fetch stream on the buffers and then clear the references to them

```python
            buffer.record_stream(self.fetch_stream)
            for b in buffers:
                b.record_stream(self.fetch_stream)
            del buffer
            del buffers
```

Reshape the trainable and fixed parameters into contiguous flat tensors

```python
            params = [p.reshape(-1) for p in params]
```
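After `all_gather`, `buffer` holds every rank's `[trainable shard, fixed shard]` piece back to back. Viewing it as `(world_size, sum(chunk_size))` and splitting along `dim=1` separates the trainable and fixed columns, and flattening row by row restores the merged, padded tensors that rank 0 built. A list-based sketch of that step; `regroup_gathered` is a hypothetical helper name:

```python
from typing import List


def regroup_gathered(buffer: List[float], chunk_size: List[int]) -> List[List[float]]:
    """Hypothetical helper mimicking
    buffer.view(-1, sum(chunk_size)).split(chunk_size, dim=1) followed by reshape(-1)."""
    row_len = sum(chunk_size)
    rows = [buffer[i:i + row_len] for i in range(0, len(buffer), row_len)]
    groups = []
    start = 0
    for size in chunk_size:
        # Column slice [start, start + size) of every row, flattened row by row
        groups.append([x for row in rows for x in row[start:start + size]])
        start += size
    return groups


# Pieces gathered from 2 ranks, each laid out as [trainable shard, fixed shard]
gathered = [0.0, 1.0, 2.0, 10.0, 11.0,   # from rank 0
            3.0, 4.0, 5.0, 12.0, 13.0]   # from rank 1
print(regroup_gathered(gathered, [3, 2]))
# [[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], [10.0, 11.0, 12.0, 13.0]]
```

This is the inverse of the layout used when scattering the parameters in `__init__`.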

Collect the individual parameter tensors

```python
            for cont, ps in zip(params, self.param_refs):
```

If there are no parameters, skip

```python
                if not ps:
                    continue
```

Offset into the contiguous tensor

```python
                offset = 0
```

Iterate through the model parameters and assign the values from the contiguous tensor

```python
                for p in ps:
```

Original parameter shape

```python
                    shape = p._orig_shape  # type: ignore[attr-defined]
```

Change the storage size of the parameter. This was set to 0 when we cleaned up the parameters.

```python
                    p.data.storage().resize_(shape.numel())
```

Assign the values from the contiguous tensor

```python
                    p.data[:] = cont[offset: offset + shape.numel()].reshape(shape)
```

Record the fetch stream on the parameter's memory, so it is not reused before the work queued on that stream has finished

```python
                    p.data.record_stream(self.fetch_stream)
```

Update the offset

```python
                    offset += shape.numel()
```

Record the fetch stream on the contiguous tensor as well, then drop the references

```python
                cont.record_stream(self.fetch_stream)

            del params
```
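The offset loop in `fetch_params` is the inverse of `p.view(-1)` in `_merge_and_pad_params`: it walks the flat tensor in parameter order and takes `shape.numel()` elements for each parameter, ignoring any trailing padding. A sketch with plain lists and shapes given as tuples; `unflatten` is a hypothetical name, and unlike the real code it returns flat pieces instead of reshaping them into each parameter's original shape:

```python
from math import prod
from typing import List, Tuple


def unflatten(flat: List[float], shapes: List[Tuple[int, ...]]) -> List[List[float]]:
    """Hypothetical helper: slice a flat (possibly padded) list into one piece per shape."""
    pieces = []
    offset = 0
    for shape in shapes:
        numel = prod(shape)  # same count as torch.Size.numel()
        pieces.append(flat[offset:offset + numel])
        offset += numel
    # Whatever remains after the last parameter is padding and is dropped
    return pieces


flat = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 0.0]  # 2x3 weight, 1-element bias, 1 padding value
print(unflatten(flat, [(2, 3), (1,)]))  # [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [7.0]]
```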

Forward pass

```python
    def forward(self, *args, **kwargs):
```

Fetch all the parameters of the current node. This gets called by the previous layer, so this call is just to make sure the parameters are fetched.

```python
        self.fetch_params()
```

Wait for parameter fetching to complete.

```python
        torch.cuda.current_stream().wait_stream(self.fetch_stream)
```

Start fetching the parameters of the following layers, so that they are fetched while the current layer does its computations.

```python
        for layer in self.next_layer:
            layer.fetch_params()
```

Add backward hooks to the parameters of the current layer if autograd is enabled.

```python
        if torch.is_grad_enabled():
            self._add_backward_hooks()
```

Compute the outputs of the current layer

```python
        res = self.module(*args, **kwargs)
```

Clean up the parameters of the layer.

Skip cleaning up if autograd is enabled and this is the last layer in the network, because we would need to fetch the parameters again for the backward pass.

```python
        if not torch.is_grad_enabled() or self.next_layer:
            self._cleanup_params()

        return res
```
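To see how prefetching and cleanup in `forward` bound memory, here is a single-process simulation that only tracks which layers hold (or are fetching) their full parameters. `SimLayer` and `run` are hypothetical stand-ins: fetching and cleanup just update a set, and there are no streams or collectives.

```python
from typing import List, Set


class SimLayer:
    """Hypothetical stand-in for Zero3Layer that only tracks resident parameters."""

    def __init__(self, idx: int, resident: Set[int]):
        self.idx = idx
        self.resident = resident  # shared set of layer indices with full parameters
        self.next_layer: List["SimLayer"] = []

    def fetch_params(self):
        self.resident.add(self.idx)  # no-op if already fetched

    def cleanup_params(self):
        self.resident.discard(self.idx)

    def forward(self, grad_enabled: bool, peak: List[int]):
        self.fetch_params()
        for layer in self.next_layer:  # prefetch the following layer
            layer.fetch_params()
        peak[0] = max(peak[0], len(self.resident))  # the layer computes here
        # Keep the last layer's parameters when autograd is enabled
        if not grad_enabled or self.next_layer:
            self.cleanup_params()


def run(num_layers: int, grad_enabled: bool):
    resident: Set[int] = set()
    layers = [SimLayer(i, resident) for i in range(num_layers)]
    for a, b in zip(layers, layers[1:]):
        a.next_layer.append(b)
    peak = [0]
    for layer in layers:
        layer.forward(grad_enabled, peak)
    return peak[0], sorted(resident)


print(run(4, grad_enabled=True))   # (2, [3])
print(run(4, grad_enabled=False))  # (2, [])
```

Whatever the depth, at most two layers (the current one and the one being prefetched) hold full parameters during the forward pass, and only the last layer stays resident when a backward pass will follow.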

Add backward hooks to the parameters of the current layer.

```python
    def _add_backward_hooks(self):
```

Number of backward hooks added

```python
        self._backward_hook_handles = 0
```

Loop through the trainable parameters of the current layer

```python
        for p in self.param_refs[self.TRAINING_PARAMS_IDX]:
```

Make sure a hook hasn't already been added

```python
            assert not hasattr(p, "_hook_handle"), 'Parameter has already been hooked'
```

Use `expand_as` to create an autograd step which we can intercept

```python
            p_tmp = p.expand_as(p)
```

Get a handle to add the backward hook: `grad_acc` is the `AccumulateGrad` node that accumulates the gradient into `p.grad`. This blog post discusses `grad_acc`.

```python
            grad_acc = p_tmp.grad_fn.next_functions[0][0]
```

Add the backward hook

```python
            handle = grad_acc.register_hook(
                functools.partial(self._post_backward_hook, p))
```

Keep a reference to the handle

```python
            p._hook_handle = handle
```

Increment the number of hooks added

```python
            self._backward_hook_handles += 1
```

Handle a backward event.

This gets called by the parameter backward hooks and the module backward hook.

```python
    def _backward_event(self):
```

Decrement the hooks counter

```python
        self._backward_hook_handles -= 1
```

The counter starts at the number of parameter hooks, and the module hook contributes one more event, so it reaches -1 only after all the hooks (including the module hook) have been called. At that point we can back up the gradients and clean up the parameters.

```python
        if self._backward_hook_handles == -1:
            self._backup_grads()
            self._cleanup_params()
```

Start fetching the parameters of the previous layer, because autograd will process its gradients next.

```python
        for layer in self.prev_layer:
            layer.fetch_params()
```
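The counter works regardless of the order in which autograd fires the hooks: starting from $n$ (the number of parameter hooks), exactly $n + 1$ decrements bring it to $-1$. A minimal sketch of that bookkeeping; `HookCounter` is a hypothetical class that leaves out streams, handle removal and the actual gradient backup:

```python
import itertools
from typing import List


class HookCounter:
    """Hypothetical class mimicking the `_backward_hook_handles` bookkeeping."""

    def __init__(self, num_trainable_params: int):
        # `_add_backward_hooks` increments once per trainable parameter
        self.handles = num_trainable_params
        self.log: List[str] = []

    def backward_event(self, source: str):
        self.handles -= 1
        # Every parameter hook plus the module hook has fired
        if self.handles == -1:
            self.log.append(f"backup_grads+cleanup after {source}")


# Try every firing order of 2 parameter hooks and 1 module hook
for order in itertools.permutations(["param0", "param1", "module"]):
    counter = HookCounter(num_trainable_params=2)
    for source in order:
        counter.backward_event(source)
    assert len(counter.log) == 1  # fires exactly once, after the last event
    print(order, counter.log)
```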

Parameter backward hook

```python
    def _post_backward_hook(self, p: nn.Parameter, *args):
```

Remove the handle from the parameter

```python
        p._hook_handle.remove()  # type: ignore[attr-defined]
        delattr(p, "_hook_handle")
```

Handle a backward event

```python
        self._backward_event()
```

Module backward hook

```python
    def _backward_hook(self, *args, **kwargs):
```

Handle a backward event

```python
        self._backward_event()
```

The previous layer will start computing gradients. We need to make sure it has finished fetching its parameters.

```python
        torch.cuda.current_stream().wait_stream(self.fetch_stream)

        return None
```

Back up the gradients of the current layer

```python
    @torch.no_grad()
    def _backup_grads(self):
```

Skip if there are no trainable parameters

```python
        if self.chunk_size[self.TRAINING_PARAMS_IDX] == 0:
            return
```

Use the backup stream to back up the gradients

```python
        with torch.cuda.stream(self.backup_stream):
```

Buffer to store the gradients

```python
            buffer = self._empty((self.world_size * self.chunk_size[self.TRAINING_PARAMS_IDX],))
```

Split the contiguous buffer into one piece per node. These splits are views of `buffer`.

```python
            buffers = list(buffer.split(self.chunk_size[self.TRAINING_PARAMS_IDX]))
```

Offset into the contiguous buffer

```python
            offset = 0
```

Iterate through the trainable parameters

```python
            for p in self.param_refs[self.TRAINING_PARAMS_IDX]:
```

Collect the gradients

```python
                shape = p._orig_shape  # type: ignore[attr-defined]
                buffer[offset: offset + shape.numel()] = p.grad.view(-1)
```

Update the offset

```python
                offset += shape.numel()
```

Clear the gradients

```python
                p.grad = None
```

Empty tensor to accumulate the gradients of the current shard

```python
            grad = self._empty((self.chunk_size[self.TRAINING_PARAMS_IDX],))
```

Accumulate the gradients of each shard. This scatters the buffers across the nodes, and each node sums (reduces) the pieces it receives.

```python
            dist.reduce_scatter(grad, buffers)
```

Record the stream on the buffers and then clear the references to them

```python
            for b in buffers:
                b.record_stream(self.fetch_stream)
            buffer.record_stream(self.fetch_stream)
            del buffer
            del buffers
```

Set the chunk gradients. This is what the optimizer sees.

```python
            self.chunk[self.TRAINING_PARAMS_IDX].grad = grad
            del grad
```
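`dist.reduce_scatter(grad, buffers)` can be read as: split every rank's full gradient buffer into `world_size` pieces, and let rank $r$ receive the element-wise sum of piece $r$ from all ranks. The list-based sketch below (hypothetical helper `reduce_scatter_sim`, modelling the default sum reduction) shows that each rank ends up with the summed gradient for exactly the shard it owns. The gradient values are illustrative.

```python
from typing import List


def reduce_scatter_sim(per_rank_buffers: List[List[float]]) -> List[List[float]]:
    """Hypothetical helper: per_rank_buffers[src] is rank src's buffer of length world_size * chunk."""
    world_size = len(per_rank_buffers)
    chunk = len(per_rank_buffers[0]) // world_size
    result = []
    for dst in range(world_size):
        # Piece `dst` from every source rank, summed element-wise
        pieces = [buf[dst * chunk:(dst + 1) * chunk] for buf in per_rank_buffers]
        result.append([sum(values) for values in zip(*pieces)])
    return result


# Two ranks, each with gradients for a 4-element merged parameter (chunk size 2)
grads_rank0 = [1.0, 2.0, 3.0, 4.0]
grads_rank1 = [10.0, 20.0, 30.0, 40.0]
print(reduce_scatter_sim([grads_rank0, grads_rank1]))
# [[11.0, 22.0], [33.0, 44.0]]
```

Since `dist.reduce_scatter` is called without an `op` argument, it uses the default `ReduceOp.SUM`, so each shard's gradient is the sum across ranks rather than the mean.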

Sequential module for `Zero3Layer` layers

```python
class Zero3Sequential(nn.Module):
```

  • `modules`: list of `Zero3Layer` layers

```python
    def __init__(self, modules: List[Zero3Layer]):
        super().__init__()
```

CUDA stream to fetch parameters

```python
        self.fetch_stream = torch.cuda.Stream()
```

CUDA stream to back up (accumulate) gradients

```python
        self.backup_stream = torch.cuda.Stream()
```

Set the streams and the preceding and following layers for each `Zero3Layer` layer

```python
        for i in range(len(modules)):
```

Set the layer index

```python
            modules[i].layer_idx = i
```

Set the streams

```python
            modules[i].fetch_stream = self.fetch_stream
            modules[i].backup_stream = self.backup_stream
```

Set the following layers

```python
            if i + 1 < len(modules):
                modules[i].next_layer.append(modules[i + 1])
```

Set the preceding layers

```python
            if i - 1 >= 0:
                modules[i].prev_layer.append(modules[i - 1])
```

Store the list of modules

```python
        self.module_list = nn.ModuleList(modules)
```

```python
    def get_trainable_chunk(self):
```

Return the list of trainable chunks from each layer

```python
        return sum([m.get_trainable_chunk() for m in self.module_list], [])
```

```python
    def forward(self, x: torch.Tensor):
```

Make sure the gradient backup is complete

```python
        torch.cuda.current_stream().wait_stream(self.backup_stream)
```

Forward pass

```python
        for m in self.module_list:
            x = m(x)

        return x
```

labml.ai