docs/neox/model.html
Here is the code for the layers of the GPT-NeoX model and the code to load the 20B checkpoint.
The load_state method in each layer loads the checkpoint of that layer. The checkpoint loading helpers are in checkpoint.py.
import copy
import math
from typing import Dict, Optional, Set, Callable, Any, Generator, Tuple

import torch
from torch import nn
from torch.cuda.amp import autocast

from labml import monit, logger
from labml.logger import Text
from labml_nn.neox import checkpoint
from labml_nn.neox.utils.cache import get_cache


class NeoXModule(nn.Module):
    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        pass
This is a standard embedding layer with code to load the checkpoint.
class Embedding(NeoXModule):

n_vocab is the size of the vocabulary
n_hidden is the size of the embeddings

    def __init__(self, n_vocab: int = 50_432, n_hidden: int = 6_144):
        super().__init__()

        self.emb = nn.Embedding(n_vocab, n_hidden)

x are the token ids of shape [batch_size, seq_len]

    def forward(self, x: torch.Tensor):
        return self.emb(x)
Code to load the checkpoint
    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        with monit.section('Load embedding layer'):
            checkpoint.merge_params_dim_0(self.emb.weight, 'word_embeddings.weight', p1, p2)
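As a quick, illustrative sanity check (ours, not part of the original source), the embedding layer simply maps token ids to vectors of size n_hidden:

# Hypothetical usage sketch with a tiny vocabulary and hidden size
emb = Embedding(n_vocab=100, n_hidden=8)
tokens = torch.tensor([[1, 5, 7]])   # [batch_size=1, seq_len=3]
out = emb(tokens)                    # shape [1, 3, 8]
assert out.shape == (1, 3, 8)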
GPT-NeoX uses rotary positional embeddings (RoPE).
We have an annotated implementation of RoPE here with more notes on the theory.
class RoPE(nn.Module):

d_rope is the number of features for RoPE embeddings
base is the base for $\theta_i = 10000^{-\frac{2(i-1)}{d}}$, which defaults to 10000

    def __init__(self, d_rope: int, base: float = 10_000.):
        super().__init__()

To store $\theta_i$ for the features

        self.theta = None

Cache $\cos m\theta_i$ and $\sin m\theta_i$

        self.cos_cached = None
        self.sin_cached = None

Base for $\theta_i = 10000^{-\frac{2(i-1)}{d}}$

        self.base = base

Number of features for RoPE

        self.d_rope = d_rope

$\big[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}\big]$

    @staticmethod
    def rotate_half(x: torch.Tensor):
        x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
        return torch.cat((-x2, x1), dim=-1)
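For intuition, a small worked example (ours, not from the source): rotate_half moves the second half of the last dimension to the front and negates it.

x = torch.tensor([1., 2., 3., 4.])
print(RoPE.rotate_half(x))   # tensor([-3., -4.,  1.,  2.])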
x has shape [..., seq, n_heads, d_k]
offset is the starting position of x. This is > 0 when we have cached the keys and queries of previous positions

    def forward(self, x: torch.Tensor, offset: int = 0):

Get the actual sequence length

        seq_len = x.shape[-3] + offset

Initialize θ

        if self.theta is None:

$\theta_i = 10000^{-\frac{2(i-1)}{d}}$

            theta = 1.0 / (self.base ** (torch.arange(0, self.d_rope, 2).float() / self.d_rope))
            self.theta = theta.to(x.device).to(x.dtype)

Initialize $\cos m\theta_i$ and $\sin m\theta_i$ cache

        if (
                self.cos_cached is None or
                seq_len > self.cos_cached.shape[1] or
                self.cos_cached.device != x.device or
                self.cos_cached.dtype != x.dtype
        ):

Get position indexes m

            seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)

$m \theta_i$

            idx_theta = torch.einsum("s,d->sd", seq_idx, self.theta)

Concatenate so that for row m we have
$[m\theta_0, m\theta_1, ..., m\theta_{\frac{d}{2}}, m\theta_0, m\theta_1, ..., m\theta_{\frac{d}{2}}]$

            idx_theta2 = torch.cat((idx_theta, idx_theta), dim=-1).to(x.device)

Calculate $\cos m\theta_i$ and $\sin m\theta_i$ in fp32

            with autocast(enabled=False):
                idx_theta2 = idx_theta2.float()

Add head dimension

                self.cos_cached = idx_theta2.cos()[:, None, :]
                self.sin_cached = idx_theta2.sin()[:, None, :]

Cache them

            self.cos_cached = self.cos_cached.to(x.dtype)
            self.sin_cached = self.sin_cached.to(x.dtype)

Split the features. We apply RoPE to only d_rope features

        x_rope, x_pass = x[..., :self.d_rope], x[..., self.d_rope:]

Get the sin and cos values from the cache

        cos, sin = self.cos_cached[offset: seq_len], self.sin_cached[offset: seq_len]

RoPE embeddings
$\begin{pmatrix} x^{(i)}_m \cos m\theta_i - x^{(i + \frac{d}{2})}_m \sin m\theta_i \\ x^{(i + \frac{d}{2})}_m \cos m\theta_i + x^{(i)}_m \sin m\theta_i \end{pmatrix}$
for $i \in \{1, 2, ..., \frac{d}{2}\}$

        x_rope = (x_rope * cos) + (self.rotate_half(x_rope) * sin)

Concatenate with features that didn't get RoPE embeddings

        return torch.cat((x_rope, x_pass), dim=-1)
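A minimal shape check for the module above (illustrative sizes, ours): only the first d_rope features get rotated, and the output shape matches the input. The offset argument shifts the position indexes, e.g. when 5 tokens are already cached.

rope = RoPE(d_rope=16)
q = torch.randn(2, 10, 4, 64)      # [batch_size, seq_len, n_heads, d_k]
assert rope(q).shape == q.shape
out = rope(q, offset=5)            # positions 5..14 instead of 0..9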
class AttentionLayer(nn.Module):

n_hidden the number of features in embeddings
n_heads the number of attention heads
rope_percentage percentage of features to add RoPE embeddings
mask_fill masking fill value for attention matrix
is_flash_attention specifies whether to use FlashAttention

    def __init__(self, n_hidden: int = 6_144, n_heads: int = 64, rope_percentage: float = 0.25,
                 mask_fill: float = -10_000.0, *, is_flash_attention: bool = False):
        super().__init__()

        self.n_heads = n_heads
        self.mask_fill = mask_fill

Linear layer for query, key and value

        self.qkv_lin = nn.Linear(n_hidden, n_hidden * 3)

Final linear layer

        self.output = nn.Linear(n_hidden, n_hidden)

Number of features per head

        d_k = n_hidden // n_heads

RoPE embedding module

        self.rope = RoPE(int(d_k * rope_percentage))

Attention scaling factor

        self.scale = 1 / math.sqrt(d_k)

To cache causal mask

        self.causal_mask = None

Attention softmax module

        self.softmax = nn.Softmax(dim=-2)

        if is_flash_attention:
            try:
                from flash_attn.flash_attention import FlashAttention
                self.flash_attention = FlashAttention()
            except ImportError:
                logger.log('Install flash attention github.com/HazyResearch/flash-attention. '
                           'Falling back to normal attention', Text.warning)
                self.flash_attention = None
        else:
            self.flash_attention = None
attn has shape [batch_size, query_seq_len, key_seq_len, n_heads]

    def _get_mask(self, attn: torch.Tensor):

Query and key lengths

        nq, nk = attn.shape[1:3]

Create mask

        if (
                self.causal_mask is None or
                self.causal_mask.shape[0] != nq or
                self.causal_mask.shape[1] != nk or
                self.causal_mask.device != attn.device
        ):
            self.causal_mask = torch.triu(attn.new_ones([nq, nk], dtype=torch.bool), 1 + nk - nq)

Return from cache

        return self.causal_mask[None, :, :, None]
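To see what the cached mask looks like, here is a standalone sketch (assumed values, not from the source) for 2 new queries attending over 4 keys (2 cached + 2 new); True marks positions that get filled with mask_fill:

nq, nk = 2, 4
mask = torch.triu(torch.ones(nq, nk, dtype=torch.bool), 1 + nk - nq)
print(mask)
# tensor([[False, False, False,  True],
#         [False, False, False, False]])
# the first new query cannot attend to the last key; the second sees everything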
x has shape [batch_size, seq_len, n_hidden]

    def forward(self, x: torch.Tensor):

Get query, key and value embeddings (all concatenated). The last dimension size will change from n_hidden -> 3 x n_hidden

        qkv = self.qkv_lin(x)

Split into heads by changing the shape to [batch_size, seq_len, n_heads, 3 * d_k]

        qkv = qkv.view(*qkv.shape[:-1], self.n_heads, -1)

Split into query, key and value, each of shape [batch_size, seq_len, n_heads, d_k]

        q, k, v = torch.split(qkv, qkv.shape[-1] // 3, dim=-1)

If we are caching the states of previous tokens

        if get_cache().get('use_cache', False):

Get the state ids. We use them to retrieve previous states and store the next states

            prev_state_id, next_state_id = get_cache().get('state_ids')

If there's a cached state

            if prev_state_id is not None:

Get the past keys and values. These will have shape [batch_size, prev_seq_len, n_heads, d_k]

                k_past, v_past = get_cache().pop(f'attn_kv_{prev_state_id}')

Offset of the current embeddings

                offset = k_past.shape[1]

Add RoPE embeddings

                q = self.rope(q, offset=offset)
                k = self.rope(k, offset=offset)

Concatenate the past

                k = torch.cat([k_past, k], dim=1)
                v = torch.cat([v_past, v], dim=1)
            else:

Add RoPE embeddings

                q = self.rope(q)
                k = self.rope(k)

Save the current state

            get_cache().push(f'attn_kv_{next_state_id}', (k, v))
        else:

No cache - simply add RoPE embeddings

            q = self.rope(q)
            k = self.rope(k)
Use flash attention
        if self.flash_attention is not None and q.shape[1] == k.shape[1] and q.shape[-1] <= 128:
            output = self.compute_flash_attention(q, k, v)

Otherwise, use normal attention

        else:
            output = self.compute_attention(q, k, v)

Reshape from [batch_size, seq_len, n_heads, d_k] to [batch_size, seq_len, n_hidden]

        output = output.reshape(*x.shape)

Final linear layer

        return self.output(output)
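An illustrative forward pass through the attention layer with small sizes (the defaults are 6,144 features and 64 heads); this sketch assumes the default cache state, i.e. no key-value caching:

attn_layer = AttentionLayer(n_hidden=64, n_heads=4)
x = torch.randn(2, 10, 64)         # [batch_size, seq_len, n_hidden]
assert attn_layer(x).shape == x.shape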
    def compute_flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):

Stack them into shape [batch_size, seq_len, 3, n_heads, d_k]

        qkv = torch.stack((q, k, v), dim=2)
        d_k = qkv.shape[-1]
        if d_k <= 32:
            pad = 32 - d_k
        elif d_k <= 64:
            pad = 64 - d_k
        elif d_k <= 128:
            pad = 128 - d_k
        else:
            raise ValueError(f'Head size {d_k} too large for flash attention')

        if pad > 0:
            qkv = torch.cat((qkv, qkv.new_zeros(*qkv.shape[:-1], pad)), dim=-1)

        output, _ = self.flash_attention(qkv, causal=True)

The output is of shape [batch_size, seq_len, n_heads, d_k + padding]

        output = output[:, :, :, :d_k]

        return output
    def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):

Disable auto-casting to fp16 for attention computation

        with autocast(enabled=False):
            if q.dtype == torch.float16:

Convert to fp32 if the current dtype is fp16

                attn = torch.einsum('bihk,bjhk->bijh', q.float(), k.float())
            else:

Do not cast for bfloat

                attn = torch.einsum('bihk,bjhk->bijh', q, k)

Scale attention

            attn = attn * self.scale

Get causal mask

            mask = self._get_mask(attn)

Apply mask

            attn.masked_fill_(mask, self.mask_fill)

Attention softmax

            attn = self.softmax(attn)

Get attention weighted values

        output = torch.einsum('bijh,bjhk->bihk', attn.to(v.dtype), v)

        return output
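The einsum strings above are batched dot products over the head dimension; a standalone shape sketch with assumed sizes:

b, i, j, h, k = 2, 5, 5, 4, 16
q = torch.randn(b, i, h, k)
key = torch.randn(b, j, h, k)
v = torch.randn(b, j, h, k)
scores = torch.einsum('bihk,bjhk->bijh', q, key)                 # [b, i, j, h]
out = torch.einsum('bijh,bjhk->bihk', scores.softmax(dim=2), v)  # [b, i, h, k]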
class FFNLayer(nn.Module):

n_hidden is the embedding size

    def __init__(self, n_hidden: int = 6_144, d_ff: int = 0):
        super().__init__()

        if not d_ff:
            d_ff = n_hidden * 4

Expansion linear layer

        self.dense_h_h4 = nn.Linear(n_hidden, d_ff)

GELU activation

        self.activation = nn.GELU()

Contraction linear layer

        self.dense_h4_h = nn.Linear(d_ff, n_hidden)

x has shape [batch_size, seq_len, n_hidden]

    def forward(self, x: torch.Tensor):
        x = self.dense_h_h4(x)
        x = self.activation(x)
        x = self.dense_h4_h(x)

        return x
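A quick shape sketch of the feedforward block (illustrative sizes, ours); d_ff defaults to four times n_hidden:

ffn = FFNLayer(n_hidden=32)        # d_ff = 4 * 32 = 128
x = torch.randn(2, 10, 32)
assert ffn(x).shape == (2, 10, 32)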
class TransformerLayer(NeoXModule):

n_hidden is the embedding size
n_heads is the number of heads
is_flash_attention specifies whether to use FlashAttention

Our implementation doesn't include dropout.

    def __init__(self, n_hidden: int = 6_144, n_heads: int = 64, *, is_flash_attention: bool = False):
        super().__init__()
Layer normalization before attention
        self.pre_ln_attn = nn.LayerNorm(n_hidden)

Layer normalization before FFN

        self.pre_ln_ffn = nn.LayerNorm(n_hidden)

Attention layer

        self.attention = AttentionLayer(n_hidden, n_heads, is_flash_attention=is_flash_attention)

FFN layer

        self.ffn = FFNLayer(n_hidden)

x are the embeddings of shape [batch_size, seq_len, n_hidden]

    def forward(self, x: torch.Tensor):

Residual connection

        residual = x

NeoX runs attention and feedforward network in parallel

        attn = self.attention(self.pre_ln_attn(x))
        ffn = self.ffn(self.pre_ln_ffn(x))

Add them and the residual connection

        return attn + ffn + residual
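So a layer computes x + attention(pre_ln_attn(x)) + ffn(pre_ln_ffn(x)), rather than the usual sequential residual formulation. A small sketch (assumed sizes, ours) showing the layer preserves the embedding shape:

layer = TransformerLayer(n_hidden=64, n_heads=4)
x = torch.randn(1, 8, 64)
assert layer(x).shape == x.shape   # x + attention(pre_ln_attn(x)) + ffn(pre_ln_ffn(x))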
Code to load the checkpoint
    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        with monit.section('Load transformer layer'):

Attention output transform

            checkpoint.merge_params_sum(self.attention.output.bias, 'attention.dense.bias', p1, p2)
            checkpoint.merge_params_dim_1(self.attention.output.weight, 'attention.dense.weight', p1, p2)

Attention query, key and value transform

            checkpoint.merge_params_dim_0(self.attention.qkv_lin.bias, 'attention.query_key_value.bias', p1, p2)
            checkpoint.merge_params_dim_0(self.attention.qkv_lin.weight, 'attention.query_key_value.weight', p1, p2)

Layer norm before attention

            checkpoint.merge_params_duplicate(self.pre_ln_attn.bias, 'input_layernorm.bias', p1, p2)
            checkpoint.merge_params_duplicate(self.pre_ln_attn.weight, 'input_layernorm.weight', p1, p2)

FFN first transform (expansion)

            checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.bias, 'mlp.dense_h_to_4h.bias', p1, p2)
            checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.weight, 'mlp.dense_h_to_4h.weight', p1, p2)

FFN second transform (contraction)

            checkpoint.merge_params_sum(self.ffn.dense_h4_h.bias, 'mlp.dense_4h_to_h.bias', p1, p2)
            checkpoint.merge_params_dim_1(self.ffn.dense_h4_h.weight, 'mlp.dense_4h_to_h.weight', p1, p2)

Layer norm before FFN

            checkpoint.merge_params_duplicate(self.pre_ln_ffn.bias, 'post_attention_layernorm.bias', p1, p2)
            checkpoint.merge_params_duplicate(self.pre_ln_ffn.weight, 'post_attention_layernorm.weight', p1, p2)
class FinalNorm(NeoXModule):

n_hidden is the embedding size

    def __init__(self, n_hidden: int = 6_144):
        super().__init__()

        self.ln = nn.LayerNorm(n_hidden)

x are the embeddings of shape [batch_size, seq_len, n_hidden]

    def forward(self, x: torch.Tensor):
        return self.ln(x)

Code to load the checkpoint

    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        with monit.section('Load final normalization layer'):
            checkpoint.merge_params_duplicate(self.ln.bias, 'norm.bias', p1, p2)
            checkpoint.merge_params_duplicate(self.ln.weight, 'norm.weight', p1, p2)
Readout layer
class ReadoutLayer(NeoXModule):

n_hidden is the embedding size
n_vocab is the size of the vocabulary

    def __init__(self, n_hidden: int = 6_144, n_vocab: int = 50_432):
        super().__init__()

        self.linear = nn.Linear(n_hidden, n_vocab, bias=False)

x are the embeddings of shape [batch_size, seq_len, n_hidden]

    def forward(self, x: torch.Tensor):
        return self.linear(x)

Code to load the checkpoint

    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        with monit.section('Load final linear layer'):
            checkpoint.merge_params_dim_0(self.linear.weight, 'final_linear.weight', p1, p2)
class LayerGenerator:
    pre_created_layers: Dict[Any, Optional[NeoXModule]]
The layers are generated in the same order as checkpoints.
It gives None when a layer is not available; we use the same layer indices as NeoX, and there are two transformation layers we don't need in our implementation.
n_vocab is the number of tokens in the vocabulary
n_hidden is the number of features in the embeddings
n_layers is the number of transformer layers
n_heads is the number of attention heads
filter_layers are the set of layers to be used. All layers will be used if None. This is used to test smaller versions of the model with fewer layers
is_clone_layers specifies whether to clone the transformer layers (a bit faster)
dtype is the data type of the model
device is the device of the model
is_llm_int8 specifies whether to use int8 quantization
llm_int8_threshold is the threshold α used to separate outlier features
is_flash_attention specifies whether to use FlashAttention

    def __init__(self, *, n_vocab: int = 50_432, n_hidden: int = 6_144,
                 n_layers: int = 44, n_heads: int = 64,
                 filter_layers: Optional[Set] = None,
                 is_clone_layers: bool = True,
                 dtype: torch.dtype = torch.float,
                 device: torch.device = torch.device('cpu'),
                 is_llm_int8: bool = False,
                 llm_int8_threshold: float = 6.0,
                 is_flash_attention: bool = False
                 ):
        if filter_layers is None:
            filter_layers = set(range(n_layers + 3))

        self.n_vocab = n_vocab
        self.n_hidden = n_hidden
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.filter_layers = filter_layers
        self.is_clone_layers = is_clone_layers
        self.dtype = dtype
        self.device = device
        self.is_llm_int8 = is_llm_int8
        self.llm_int8_threshold = llm_int8_threshold
        self.is_flash_attention = is_flash_attention

        self.pre_created_layers = dict(
            transformer_layer=None,
        )
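For example, a hypothetical configuration (ours, not from the source) that keeps only the embedding layer (index 0) and the first three transformer layers, for quick experiments on CPU:

generator = LayerGenerator(
    filter_layers={0, 1, 2, 3},    # embedding + transformer layers 1-3
    dtype=torch.float,
    device=torch.device('cpu'),
)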
We move the layer to the device and convert it to the correct data type
layer is the layer to prepare
Returns the prepared layer

    def _prepare_layer(self, layer: NeoXModule):
        return layer.to(self.device, self.dtype)
This function implements layer transformations after loading the checkpoint.
Currently, it only applies the int8 quantization.
layer is the layer to prepare
is_llm_int8 specifies whether to use int8 quantization
device is the device of the model
llm_int8_threshold is the threshold α used to separate outlier features
Returns the prepared layer

    @torch.no_grad()
    def post_load_prepare(self, layer: NeoXModule, *,
                          is_llm_int8: bool = None,
                          device: torch.device = None,
                          llm_int8_threshold: float = None,
                          ):

Get default values if not specified

        if is_llm_int8 is None:
            is_llm_int8 = self.is_llm_int8
        if device is None:
            device = self.device
        if llm_int8_threshold is None:
            llm_int8_threshold = self.llm_int8_threshold

Skip if not using int8 quantization

        if not is_llm_int8:
            return layer

Only convert the linear layers in the transformer layers

        if not isinstance(layer, TransformerLayer):
            return layer

Use make_llm_int8_linear defined in utilities.

        from labml_nn.neox.utils.llm_int8 import make_llm_int8_linear

Convert the linear layers

        with monit.section('Convert to int8'):
            layer.attention.output = make_llm_int8_linear(layer.attention.output,
                                                          device=device,
                                                          threshold=llm_int8_threshold)
            layer.attention.qkv_lin = make_llm_int8_linear(layer.attention.qkv_lin,
                                                           device=device,
                                                           threshold=llm_int8_threshold)
            layer.ffn.dense_h_h4 = make_llm_int8_linear(layer.ffn.dense_h_h4,
                                                        device=device,
                                                        threshold=llm_int8_threshold)
            layer.ffn.dense_h4_h = make_llm_int8_linear(layer.ffn.dense_h4_h,
                                                        device=device,
                                                        threshold=llm_int8_threshold)

        return layer
Copying cached layers is faster than initializing new layers because it takes time to initialize parameters.
name is the name of the layer
creator is the function to create the layer
Returns the created layer or a copy of the cached layer

    def _create_and_cache_layer(self, name: str, creator: Callable[[], NeoXModule]):
        if not self.is_clone_layers:
            return self._prepare_layer(creator())

        if self.pre_created_layers[name] is None:
            self.pre_created_layers[name] = self._prepare_layer(creator())

        layer = copy.deepcopy(self.pre_created_layers[name])
        return layer

    def _create_transformer_layer(self):
        return self._create_and_cache_layer(
            'transformer_layer',
            lambda: TransformerLayer(self.n_hidden, self.n_heads, is_flash_attention=self.is_flash_attention)
        )

    def _create_embedding_layer(self):
        return Embedding(self.n_vocab, self.n_hidden)

    def _create_final_norm_layer(self):
        return FinalNorm(self.n_hidden)

    def _create_readout_layer(self):
        return ReadoutLayer(self.n_hidden, self.n_vocab)
    @torch.no_grad()
    def get_layers(self) -> Generator[Tuple[NeoXModule, Tuple[str, str]], None, None]:

Embedding layer

        if 0 in self.filter_layers:
            with monit.section('Embedding layer'):
                layer = self._prepare_layer(self._create_embedding_layer())
                yield layer, ('layer_00-model_00-model_states.pt', 'layer_00-model_01-model_states.pt')

Transformer layers

        for i in range(self.n_layers):

Transformer layer

            if i + 1 in self.filter_layers:
                with monit.section(f'Transformer Layer {i}'):
                    yield self._create_transformer_layer(), \
                        (f'layer_{i + 2 :02d}-model_00-model_states.pt',
                         f'layer_{i + 2 :02d}-model_01-model_states.pt')

Final normalization layer

        if self.n_layers + 1 in self.filter_layers:
            with monit.section('Final norm layer'):
                layer = self._prepare_layer(self._create_final_norm_layer())
                yield layer, ('layer_47-model_00-model_states.pt', 'layer_47-model_01-model_states.pt')

Readout layer

        if self.n_layers + 2 in self.filter_layers:
            with monit.section('Readout layer'):
                layer = self._prepare_layer(self._create_readout_layer())
                yield layer, ('layer_48-model_00-model_states.pt', 'layer_48-model_01-model_states.pt')

        for k in self.pre_created_layers.keys():
            self.pre_created_layers[k] = None

    @property
    def total_layers(self):
        return self.n_layers + 3
    @torch.no_grad()
    def load(self) -> Generator[NeoXModule, None, None]:
        with monit.section("Layers"):
            for i, (layer, files) in enumerate(self.get_layers()):
                if files is not None:
                    layer.load_state(*checkpoint.load_checkpoint_files(files))

                layer = self.post_load_prepare(layer)

                monit.progress(min(0.99, (i + 1) / self.total_layers))
                yield layer
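Putting it together, a hedged usage sketch: this assumes the 20B checkpoint files have already been downloaded (see the helpers in checkpoint.py) and that there is enough memory to hold the model in fp16. Each yielded module takes a single tensor and returns one, so the layers can be chained with nn.Sequential.

# Sketch only; checkpoint availability, device, and memory are assumptions
device = torch.device('cuda:0')
layers = list(LayerGenerator(is_clone_layers=True,
                             dtype=torch.float16,
                             device=device).load())
model = nn.Sequential(*layers)

with torch.no_grad():
    ids = torch.tensor([[1, 2, 3]], device=device)   # token ids, [batch_size, seq_len]
    logits = model(ids)                              # [1, 3, n_vocab]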