
GPT-NeoX Model

Here is the code for the layers of the GPT-NeoX model and the code to load the 20B checkpoint.

The `load_state` method in each layer loads the checkpoint of that layer. The checkpoint loading helpers are in checkpoint.py.
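The 20B checkpoint stores each layer in two model-parallel partitions, and the helpers merge the pair into a single parameter tensor. As a rough sketch of the merging semantics (illustrative only; the real helpers in checkpoint.py take a target parameter and a checkpoint key, and copy the merged tensor in place):

```python
import torch


def merge_dim_0(p1: torch.Tensor, p2: torch.Tensor) -> torch.Tensor:
    # Parameters split along dim 0 (e.g. embedding rows, column-parallel weights)
    return torch.cat([p1, p2], dim=0)


def merge_dim_1(p1: torch.Tensor, p2: torch.Tensor) -> torch.Tensor:
    # Parameters split along dim 1 (e.g. row-parallel weights)
    return torch.cat([p1, p2], dim=1)


def merge_duplicate(p1: torch.Tensor, p2: torch.Tensor) -> torch.Tensor:
    # Parameters replicated in both partitions (e.g. layer norm parameters)
    return p1


def merge_sum(p1: torch.Tensor, p2: torch.Tensor) -> torch.Tensor:
    # Partial biases of row-parallel layers are summed
    return p1 + p2
```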

```python
import copy
import math
from typing import Dict, Optional, Set, Callable, Any, Generator, Tuple

import torch
from torch import nn
from torch.cuda.amp import autocast

from labml import monit, logger
from labml.logger import Text
from labml_nn.neox import checkpoint
from labml_nn.neox.utils.cache import get_cache


class NeoXModule(nn.Module):
    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        pass
```


Embedding layer

This is a standard embedding layer, with code to load the checkpoint.

```python
class Embedding(NeoXModule):
```

  • n_vocab is the size of the vocabulary
  • n_hidden is the size of the embeddings
```python
    def __init__(self, n_vocab: int = 50_432, n_hidden: int = 6_144):
        super().__init__()

        self.emb = nn.Embedding(n_vocab, n_hidden)
```

  • x are the token ids of shape [batch_size, seq_len]
```python
    def forward(self, x: torch.Tensor):
        return self.emb(x)
```

Code to load the checkpoint

```python
    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        with monit.section('Load embedding layer'):
            checkpoint.merge_params_dim_0(self.emb.weight, 'word_embeddings.weight', p1, p2)
```

Rotary Positional Embeddings

GPT-NeoX uses rotary positional embeddings (RoPE).

We have an annotated implementation of RoPE here with more notes on the theory.

```python
class RoPE(nn.Module):
```

  • d_rope is the number of features for RoPE embeddings
  • base is the base for $\theta_i = 10000^{-\frac{2(i-1)}{d}}$, which defaults to $10000$
```python
    def __init__(self, d_rope: int, base: float = 10_000.):
        super().__init__()
```

To store $\theta_i$ for the features

```python
        self.theta = None
```

Cache $\cos m\theta_i$ and $\sin m\theta_i$

```python
        self.cos_cached = None
        self.sin_cached = None
```

Base for $\theta_i = 10000^{-\frac{2(i-1)}{d}}$

```python
        self.base = base
```

Number of features for RoPE

```python
        self.d_rope = d_rope
```

Rotate the features

$[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, \dots, -x^{(d)}, x^{(1)}, x^{(2)}, \dots, x^{(\frac{d}{2})}]$

```python
    @staticmethod
    def rotate_half(x: torch.Tensor):
        x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
        return torch.cat((-x2, x1), dim=-1)
```
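For example, a quick check of rotate_half (illustrative only):

```python
x = torch.tensor([1., 2., 3., 4.])
assert torch.equal(RoPE.rotate_half(x), torch.tensor([-3., -4., 1., 2.]))
```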

  • x has shape [..., seq, n_heads, d_k]
  • offset is the starting position of x. This is > 0 when we have cached the keys and queries of previous positions
```python
    def forward(self, x: torch.Tensor, offset: int = 0):
```

Get the actual sequence length

```python
        seq_len = x.shape[-3] + offset
```

Initialize θ

```python
        if self.theta is None:
```

$\theta_i = 10000^{-\frac{2(i-1)}{d}}$

```python
            theta = 1.0 / (self.base ** (torch.arange(0, self.d_rope, 2).float() / self.d_rope))
            self.theta = theta.to(x.device).to(x.dtype)
```

Initialize $\cos m\theta_i$ and $\sin m\theta_i$ cache

```python
        if (
                self.cos_cached is None or
                seq_len > self.cos_cached.shape[0] or  # the sequence dimension of the cache is dim 0
                self.cos_cached.device != x.device or
                self.cos_cached.dtype != x.dtype
        ):
```

Get position indexes m

```python
            seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)
```

$m\theta_i$

```python
            idx_theta = torch.einsum("s,d->sd", seq_idx, self.theta)
```

Concatenate so that for row $m$ we have

$[m\theta_0, m\theta_1, \dots, m\theta_{\frac{d}{2}}, m\theta_0, m\theta_1, \dots, m\theta_{\frac{d}{2}}]$

```python
            idx_theta2 = torch.cat((idx_theta, idx_theta), dim=-1).to(x.device)
```

Calculate $\cos m\theta_i$ and $\sin m\theta_i$ in fp32

```python
            with autocast(enabled=False):
                idx_theta2 = idx_theta2.float()
```

Add head dimension

```python
                self.cos_cached = idx_theta2.cos()[:, None, :]
                self.sin_cached = idx_theta2.sin()[:, None, :]
```

Cache them

```python
            self.cos_cached = self.cos_cached.to(x.dtype)
            self.sin_cached = self.sin_cached.to(x.dtype)
```

Split the features. We apply RoPE only to the first d_rope features

```python
        x_rope, x_pass = x[..., :self.d_rope], x[..., self.d_rope:]
```

Get the sin and cos values from the cache

```python
        cos, sin = self.cos_cached[offset:seq_len], self.sin_cached[offset:seq_len]
```

RoPE embeddings

$$\begin{pmatrix} x^{(i)}_m \cos m\theta_i - x^{(i + \frac{d}{2})}_m \sin m\theta_i \\ x^{(i + \frac{d}{2})}_m \cos m\theta_i + x^{(i)}_m \sin m\theta_i \end{pmatrix}$$

for $i \in \{1, 2, \dots, \frac{d}{2}\}$

```python
        x_rope = (x_rope * cos) + (self.rotate_half(x_rope) * sin)
```

Concatenate with features that didn't get RoPE embeddings

```python
        return torch.cat((x_rope, x_pass), dim=-1)
```
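As a quick illustration of the property this gives us (a sketch, not part of the model): dot products between RoPE-rotated vectors depend only on relative position, so shifting both positions by the same amount, via offset, leaves the score unchanged.

```python
rope = RoPE(d_rope=16)
x = torch.randn(1, 8, 1, 16)   # [batch, seq, heads, d_k]

y = rope(x)                    # positions 0..7
y_shifted = rope(x, offset=2)  # same features, treated as positions 2..9

# score between positions (0, 3) equals score between positions (2, 5)
a = (y[0, 0, 0] * y[0, 3, 0]).sum()
b = (y_shifted[0, 0, 0] * y_shifted[0, 3, 0]).sum()
assert torch.allclose(a, b, atol=1e-5)
```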

Attention layer

```python
class AttentionLayer(nn.Module):
```

  • n_hidden the number of features in embeddings
  • n_heads the number of attention heads
  • rope_percentage percentage of features to add RoPE embeddings
  • mask_fill masking fill value for attention matrix
  • is_flash_attention specifies whether to use FlashAttention
```python
    def __init__(self, n_hidden: int = 6_144, n_heads: int = 64, rope_percentage: float = 0.25,
                 mask_fill: float = -10_000.0, *, is_flash_attention: bool = False):
        super().__init__()

        self.n_heads = n_heads
        self.mask_fill = mask_fill
```

Linear layer for query, key and value

```python
        self.qkv_lin = nn.Linear(n_hidden, n_hidden * 3)
```

Final linear layer

```python
        self.output = nn.Linear(n_hidden, n_hidden)
```

Number of features per head

```python
        d_k = n_hidden // n_heads
```

RoPE embedding module

```python
        self.rope = RoPE(int(d_k * rope_percentage))
```

Attention scaling factor

```python
        self.scale = 1 / math.sqrt(d_k)
```

To cache causal mask

```python
        self.causal_mask = None
```

Attention softmax module

```python
        # The attention matrix has shape [batch_size, query, key, head],
        # so dim=-2 normalizes over the key positions
        self.softmax = nn.Softmax(dim=-2)
```

FlashAttention

```python
        if is_flash_attention:
            try:
                from flash_attn.flash_attention import FlashAttention
                self.flash_attention = FlashAttention()
            except ImportError:
                logger.log('Install flash attention github.com/HazyResearch/flash-attention. '
                           'Falling back to normal attention', Text.warning)
                self.flash_attention = None
        else:
            self.flash_attention = None
```

Calculate the causal mask

```python
    def _get_mask(self, attn: torch.Tensor):
```

Query and key lengths

```python
        nq, nk = attn.shape[1:3]
```

Create mask

```python
        if (
                self.causal_mask is None or
                self.causal_mask.shape[0] != nq or
                self.causal_mask.shape[1] != nk or
                self.causal_mask.device != attn.device
        ):
            self.causal_mask = torch.triu(attn.new_ones([nq, nk], dtype=torch.bool), 1 + nk - nq)
```

Return from cache

```python
        return self.causal_mask[None, :, :, None]
```
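For example (an illustrative check): with nq = 2 queries attending over nk = 4 keys, the diagonal offset 1 + nk - nq = 3 masks only the entries a query must not see, so the last query can attend to everything:

```python
nq, nk = 2, 4
mask = torch.triu(torch.ones(nq, nk, dtype=torch.bool), diagonal=1 + nk - nq)
# tensor([[False, False, False,  True],
#         [False, False, False, False]])
```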

  • x has shape [batch_size, seq_len, n_hidden]
```python
    def forward(self, x: torch.Tensor):
```

Get query, key and value embeddings (all concatenated). The last dimension size will change from n_hidden to 3 * n_hidden.

```python
        qkv = self.qkv_lin(x)
```

Split into heads by changing the shape to [batch_size, seq_len, n_heads, 3 * d_k]

```python
        qkv = qkv.view(*qkv.shape[:-1], self.n_heads, -1)
```

Split into query, key and value, each of shape [batch_size, seq_len, n_heads, d_k]

```python
        q, k, v = torch.split(qkv, qkv.shape[-1] // 3, dim=-1)
```

If we are caching the states of previous tokens

```python
        if get_cache().get('use_cache', False):
```

Get the state ids. We use them to retrieve previous states and store the next states.

```python
            prev_state_id, next_state_id = get_cache().get('state_ids')
```

If there's cache

```python
            if prev_state_id is not None:
```

Get the past keys and values. These will have shape [batch_size, prev_seq_len, n_heads, d_k]

```python
                k_past, v_past = get_cache().pop(f'attn_kv_{prev_state_id}')
```

Offset of the current embeddings

```python
                offset = k_past.shape[1]
```

Add RoPE embeddings

```python
                q = self.rope(q, offset=offset)
                k = self.rope(k, offset=offset)
```

Concatenate the past

```python
                k = torch.cat([k_past, k], dim=1)
                v = torch.cat([v_past, v], dim=1)
            else:
```

Add RoPE embeddings

```python
                q = self.rope(q)
                k = self.rope(k)
```

Save the current state

```python
            get_cache().push(f'attn_kv_{next_state_id}', (k, v))
        else:
```

No cache - simply add RoPE embeddings

```python
            q = self.rope(q)
            k = self.rope(k)
```

Use flash attention

```python
        if self.flash_attention is not None and q.shape[1] == k.shape[1] and q.shape[-1] <= 128:
            output = self.compute_flash_attention(q, k, v)
```

Otherwise, use normal attention

```python
        else:
            output = self.compute_attention(q, k, v)
```

Reshape from [batch_size, seq_len, n_heads, d_k] to [batch_size, seq_len, n_hidden]

```python
        output = output.reshape(*x.shape)
```

Final linear layer

```python
        return self.output(output)
```
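The caching above can be driven as follows (a hypothetical sketch: model, prompt_ids and next_token_ids are stand-ins, and it assumes the cache object exposes a set method alongside the get, push and pop calls used here):

```python
cache = get_cache()
cache.set('use_cache', True)

# Prompt pass: no previous state; store this step's keys/values under id 'a'
cache.set('state_ids', (None, 'a'))
logits = model(prompt_ids)

# Incremental step: read state 'a'; store the extended state under id 'b'
cache.set('state_ids', ('a', 'b'))
logits = model(next_token_ids)
```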

```python
    def compute_flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
```

Stack them into shape [batch_size, seq_len, 3, n_heads, d_k]

```python
        qkv = torch.stack((q, k, v), dim=2)

        # Pad the head size to the nearest size supported by the FlashAttention kernel
        d_k = qkv.shape[-1]
        if d_k <= 32:
            pad = 32 - d_k
        elif d_k <= 64:
            pad = 64 - d_k
        elif d_k <= 128:
            pad = 128 - d_k
        else:
            raise ValueError(f'Head size {d_k} too large for flash attention')

        if pad > 0:
            qkv = torch.cat((qkv, qkv.new_zeros(*qkv.shape[:-1], pad)), dim=-1)

        output, _ = self.flash_attention(qkv, causal=True)
```

The output is of shape [batch_size, seq_len, n_heads, d_k + padding]

```python
        output = output[:, :, :, :d_k]

        return output
```
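For the 20B configuration, for example, $d_k = 6144 / 64 = 96$, so each head is zero-padded from 96 to 128 features before the FlashAttention call and the padding is stripped from the output afterwards.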

```python
    def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
```

Disable auto-casting to fp16 for attention computation

```python
        with autocast(enabled=False):
            if q.dtype == torch.float16:
```

Convert to fp32 if the current dtype is fp16

```python
                attn = torch.einsum('bihk,bjhk->bijh', q.float(), k.float())
            else:
```

Do not cast for bfloat

```python
                attn = torch.einsum('bihk,bjhk->bijh', q, k)
```

Scale attention

```python
            attn = attn * self.scale
```

Get causal mask

```python
            mask = self._get_mask(attn)
```

Apply mask

```python
            attn.masked_fill_(mask, self.mask_fill)
```

Attention softmax

```python
            attn = self.softmax(attn)
```

Get attention weighted values

```python
        output = torch.einsum('bijh,bjhk->bihk', attn.to(v.dtype), v)

        return output
```
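A shape sketch of the layout (illustrative): the first einsum produces attention scores of shape [batch_size, query, key, heads], which is why the softmax module was defined over dim=-2, the key dimension.

```python
batch, n_q, n_k, n_heads, d_k = 2, 5, 5, 4, 8
q = torch.randn(batch, n_q, n_heads, d_k)
k = torch.randn(batch, n_k, n_heads, d_k)
attn = torch.einsum('bihk,bjhk->bijh', q, k)
assert attn.shape == (batch, n_q, n_k, n_heads)
```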

Feedforward Network

```python
class FFNLayer(nn.Module):
```

  • n_hidden is the embedding size
  • d_ff is the size of the hidden dimension; it defaults to 4 * n_hidden when 0
```python
    def __init__(self, n_hidden: int = 6_144, d_ff: int = 0):
        super().__init__()

        if not d_ff:
            d_ff = n_hidden * 4
```

Expansion linear layer

```python
        self.dense_h_h4 = nn.Linear(n_hidden, d_ff)
```

GELU activation

```python
        self.activation = nn.GELU()
```

Contraction linear layer

```python
        self.dense_h4_h = nn.Linear(d_ff, n_hidden)
```

  • x has shape [batch_size, seq_len, n_hidden]
```python
    def forward(self, x: torch.Tensor):
        x = self.dense_h_h4(x)
        x = self.activation(x)
        x = self.dense_h4_h(x)

        return x
```

Transformer Layer

```python
class TransformerLayer(NeoXModule):
```

  • n_hidden is the embedding size
  • n_heads is the number of heads
  • is_flash_attention specifies whether to use FlashAttention

Our implementation doesn't include dropout.

```python
    def __init__(self, n_hidden: int = 6_144, n_heads: int = 64, *, is_flash_attention: bool = False):
        super().__init__()
```

Layer normalization before attention

```python
        self.pre_ln_attn = nn.LayerNorm(n_hidden)
```

Layer normalization before FFN

```python
        self.pre_ln_ffn = nn.LayerNorm(n_hidden)
```

Attention layer

```python
        self.attention = AttentionLayer(n_hidden, n_heads, is_flash_attention=is_flash_attention)
```

FFN layer

```python
        self.ffn = FFNLayer(n_hidden)
```

  • x are the embeddings of shape [batch_size, seq_len, n_hidden]
```python
    def forward(self, x: torch.Tensor):
```

Residual connection

```python
        residual = x
```

NeoX runs attention and feedforward network in parallel

```python
        attn = self.attention(self.pre_ln_attn(x))
        ffn = self.ffn(self.pre_ln_ffn(x))
```

Add them and the residual connection

```python
        return attn + ffn + residual
```
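That is, the layer computes

$$\text{out} = x + \text{Attn}(\text{LN}_1(x)) + \text{FFN}(\text{LN}_2(x))$$

rather than the sequential GPT-2-style form $h = x + \text{Attn}(\text{LN}_1(x))$, $\text{out} = h + \text{FFN}(\text{LN}_2(h))$; the two branches can therefore be computed concurrently.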

Code to load the checkpoint

```python
    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        with monit.section('Load transformer layer'):
```

Attention output transform

```python
            checkpoint.merge_params_sum(self.attention.output.bias, 'attention.dense.bias', p1, p2)
            checkpoint.merge_params_dim_1(self.attention.output.weight, 'attention.dense.weight', p1, p2)
```

Attention query, key and value transform

```python
            checkpoint.merge_params_dim_0(self.attention.qkv_lin.bias, 'attention.query_key_value.bias', p1, p2)
            checkpoint.merge_params_dim_0(self.attention.qkv_lin.weight, 'attention.query_key_value.weight', p1, p2)
```

Layer norm before attention

```python
            checkpoint.merge_params_duplicate(self.pre_ln_attn.bias, 'input_layernorm.bias', p1, p2)
            checkpoint.merge_params_duplicate(self.pre_ln_attn.weight, 'input_layernorm.weight', p1, p2)
```

FFN first transform (the expansion)

```python
            checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.bias, 'mlp.dense_h_to_4h.bias', p1, p2)
            checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.weight, 'mlp.dense_h_to_4h.weight', p1, p2)
```

FFN second transform (the contraction)

```python
            checkpoint.merge_params_sum(self.ffn.dense_h4_h.bias, 'mlp.dense_4h_to_h.bias', p1, p2)
            checkpoint.merge_params_dim_1(self.ffn.dense_h4_h.weight, 'mlp.dense_4h_to_h.weight', p1, p2)
```

Layer norm before FFN

```python
            checkpoint.merge_params_duplicate(self.pre_ln_ffn.bias, 'post_attention_layernorm.bias', p1, p2)
            checkpoint.merge_params_duplicate(self.pre_ln_ffn.weight, 'post_attention_layernorm.weight', p1, p2)
```

Final normalization layer

```python
class FinalNorm(NeoXModule):
```

  • n_hidden is the embedding size
```python
    def __init__(self, n_hidden: int = 6_144):
        super().__init__()

        self.ln = nn.LayerNorm(n_hidden)
```

  • x are the embeddings of shape [batch_size, seq_len, n_hidden]
```python
    def forward(self, x: torch.Tensor):
        return self.ln(x)
```

Code to load the checkpoint

```python
    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        with monit.section('Load final normalization layer'):
            checkpoint.merge_params_duplicate(self.ln.bias, 'norm.bias', p1, p2)
            checkpoint.merge_params_duplicate(self.ln.weight, 'norm.weight', p1, p2)
```

Readout layer

```python
class ReadoutLayer(NeoXModule):
```

  • n_hidden is the embedding size
  • n_vocab is the size of the vocabulary
```python
    def __init__(self, n_hidden: int = 6_144, n_vocab: int = 50_432):
        super().__init__()

        self.linear = nn.Linear(n_hidden, n_vocab, bias=False)
```

  • x are the embeddings of shape [batch_size, seq_len, n_hidden]
```python
    def forward(self, x: torch.Tensor):
        return self.linear(x)
```

Code to load the checkpoint

```python
    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        with monit.section('Load final linear layer'):
            checkpoint.merge_params_dim_0(self.linear.weight, 'final_linear.weight', p1, p2)
```

```python
class LayerGenerator:
    pre_created_layers: Dict[Any, Optional[NeoXModule]]
```

Generator to create layers

The layers are generated in the same order as checkpoints.

It yields None when a layer is not available; we use the same layer indices as NeoX, and there are two transformation layers we don't need in our implementation.

  • n_vocab is the number of tokens in the vocabulary
  • n_hidden is the number of features in the embeddings
  • n_layers is the number of transformer layers
  • n_heads is the number of attention heads
  • filter_layers are the set of layers to be used. All layers will be used if None. This is used to test smaller versions of the model with fewer layers
  • is_clone_layers specifies whether to clone the transformer layers (a bit faster)
  • dtype is the data type of the model
  • device is the device of the model
  • is_llm_int8 specifies whether to use int8 quantization
  • llm_int8_threshold is the threshold α used to separate outlier features
  • is_flash_attention specifies whether to use FlashAttention
```python
    def __init__(self, *, n_vocab: int = 50_432, n_hidden: int = 6_144,
                 n_layers: int = 44, n_heads: int = 64,
                 filter_layers: Optional[Set] = None,
                 is_clone_layers: bool = True,
                 dtype: torch.dtype = torch.float,
                 device: torch.device = torch.device('cpu'),
                 is_llm_int8: bool = False,
                 llm_int8_threshold: float = 6.0,
                 is_flash_attention: bool = False
                 ):
        if filter_layers is None:
            filter_layers = set(range(n_layers + 3))

        self.n_vocab = n_vocab
        self.n_hidden = n_hidden
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.filter_layers = filter_layers
        self.is_clone_layers = is_clone_layers
        self.dtype = dtype
        self.device = device
        self.is_llm_int8 = is_llm_int8
        self.llm_int8_threshold = llm_int8_threshold
        self.is_flash_attention = is_flash_attention

        self.pre_created_layers = dict(
            transformer_layer=None,
        )
```
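For example, to test with only the embedding, the first transformer layer, the final norm, and the readout (an illustrative sketch; with n_layers = 44 the final norm is index 45 and the readout is index 46):

```python
generator = LayerGenerator(
    n_layers=44,
    filter_layers={0, 1, 45, 46},
    is_clone_layers=False,
)
```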


Prepares the layer for usage

We move the layer to the device and convert it to the correct data type

  • layer is the layer to prepare

Returns the prepared layer

```python
    def _prepare_layer(self, layer: NeoXModule):
        return layer.to(self.device, self.dtype)
```

Layer transformations after loading the checkpoint

This function implements layer transformations after loading the checkpoint.

Currently, it only applies the int8 quantization.

  • layer is the layer to prepare
  • is_llm_int8 specifies whether to use int8 quantization
  • device is the device of the model
  • llm_int8_threshold is the threshold α used to separate outlier features

Returns the prepared layer

```python
    @torch.no_grad()
    def post_load_prepare(self, layer: NeoXModule, *,
                          is_llm_int8: bool = None,
                          device: torch.device = None,
                          llm_int8_threshold: float = None,
                          ):
```

Get default values if not specified

```python
        if is_llm_int8 is None:
            is_llm_int8 = self.is_llm_int8
        if device is None:
            device = self.device
        if llm_int8_threshold is None:
            llm_int8_threshold = self.llm_int8_threshold
```

Skip if not using int8 quantization

```python
        if not is_llm_int8:
            return layer
```

Only convert the linear layers in the transformer layers

```python
        if not isinstance(layer, TransformerLayer):
            return layer
```

Use make_llm_int8_linear defined in utilities.

```python
        from labml_nn.neox.utils.llm_int8 import make_llm_int8_linear
```

Convert the linear layers

```python
        with monit.section('Convert to int8'):
            layer.attention.output = make_llm_int8_linear(layer.attention.output,
                                                          device=device,
                                                          threshold=llm_int8_threshold)
            layer.attention.qkv_lin = make_llm_int8_linear(layer.attention.qkv_lin,
                                                           device=device,
                                                           threshold=llm_int8_threshold)
            layer.ffn.dense_h_h4 = make_llm_int8_linear(layer.ffn.dense_h_h4,
                                                        device=device,
                                                        threshold=llm_int8_threshold)
            layer.ffn.dense_h4_h = make_llm_int8_linear(layer.ffn.dense_h4_h,
                                                        device=device,
                                                        threshold=llm_int8_threshold)
```

```python
        return layer
```

Creates and caches a layer

Copying cached layers is faster than initializing new layers because it takes time to initialize parameters.

  • name is the name of the layer
  • creator is the function to create the layer

Returns the created layer or a copy of the cached layer

```python
    def _create_and_cache_layer(self, name: str, creator: Callable[[], NeoXModule]):
        if not self.is_clone_layers:
            return self._prepare_layer(creator())

        if self.pre_created_layers[name] is None:
            self.pre_created_layers[name] = self._prepare_layer(creator())

        layer = copy.deepcopy(self.pre_created_layers[name])
        return layer
```

```python
    def _create_transformer_layer(self):
        return self._create_and_cache_layer(
            'transformer_layer',
            lambda: TransformerLayer(self.n_hidden, self.n_heads, is_flash_attention=self.is_flash_attention)
        )

    def _create_embedding_layer(self):
        return Embedding(self.n_vocab, self.n_hidden)

    def _create_final_norm_layer(self):
        return FinalNorm(self.n_hidden)

    def _create_readout_layer(self):
        return ReadoutLayer(self.n_hidden, self.n_vocab)
```

Generator to get layers

```python
    @torch.no_grad()
    def get_layers(self) -> Generator[Tuple[NeoXModule, Tuple[str, str]], None, None]:
```

Embedding layer

```python
        if 0 in self.filter_layers:
            with monit.section('Embedding layer'):
                layer = self._prepare_layer(self._create_embedding_layer())
                yield layer, ('layer_00-model_00-model_states.pt', 'layer_00-model_01-model_states.pt')
```

Transformer layers

```python
        for i in range(self.n_layers):
```

Transformer layer

```python
            if i + 1 in self.filter_layers:
                with monit.section(f'Transformer Layer {i}'):
                    yield self._create_transformer_layer(), \
                        (f'layer_{i + 2:02d}-model_00-model_states.pt',
                         f'layer_{i + 2:02d}-model_01-model_states.pt')
```

Final normalization layer

```python
        if self.n_layers + 1 in self.filter_layers:
            with monit.section('Final norm layer'):
                layer = self._prepare_layer(self._create_final_norm_layer())
                yield layer, ('layer_47-model_00-model_states.pt', 'layer_47-model_01-model_states.pt')
```

Readout layer

```python
        if self.n_layers + 2 in self.filter_layers:
            with monit.section('Readout layer'):
                layer = self._prepare_layer(self._create_readout_layer())
                yield layer, ('layer_48-model_00-model_states.pt', 'layer_48-model_01-model_states.pt')

        # Release the cached template layers
        for k in self.pre_created_layers.keys():
            self.pre_created_layers[k] = None
```

Returns the total number of layers

```python
    @property
    def total_layers(self):
        return self.n_layers + 3
```

Generator to load layers

```python
    @torch.no_grad()
    def load(self) -> Generator[NeoXModule, None, None]:
        with monit.section("Layers"):
            for i, (layer, files) in enumerate(self.get_layers()):
                if files is not None:
                    layer.load_state(*checkpoint.load_checkpoint_files(files))

                layer = self.post_load_prepare(layer)

                monit.progress(min(0.99, (i + 1) / self.total_layers))
                yield layer
```

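Putting it together, loading the full model might look like this (an illustrative sketch; it assumes the 20B checkpoint files have already been downloaded where checkpoint.py can find them):

```python
layers = list(LayerGenerator(
    dtype=torch.float16,
    device=torch.device('cuda:0'),
).load())

# Each layer maps its input to the next layer's input,
# so the whole model can be stacked sequentially
model = nn.Sequential(*layers)
```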