
GPT-2 with LoRA modules

Here's the training code for training a GPT-2 model with LoRA on the Tiny Shakespeare dataset.

import torch
import torch.nn as nn

from labml_nn.lora import Linear, Embedding
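For reference, here is a minimal sketch of what a LoRA linear layer computes, following the standard low-rank adaptation formulation: a frozen pre-trained weight is combined with a trainable low-rank update B A scaled by alpha / r. This is only an illustration of the idea; the actual labml_nn.lora.Linear and Embedding modules may differ in details such as initialization and scaling.

class LoRALinearSketch(nn.Module):
    def __init__(self, in_features: int, out_features: int, r: int, alpha: int = 1):
        super().__init__()
        # Frozen pre-trained weight and bias (not updated during fine-tuning)
        self.weight = nn.Parameter(torch.randn(out_features, in_features), requires_grad=False)
        self.bias = nn.Parameter(torch.zeros(out_features), requires_grad=False)
        # Trainable low-rank factors: A is [r, in_features], B is [out_features, r]
        self.lora_a = nn.Parameter(torch.randn(r, in_features) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(out_features, r))
        self.scaling = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Frozen path plus the scaled low-rank update
        frozen = nn.functional.linear(x, self.weight, self.bias)
        update = nn.functional.linear(nn.functional.linear(x, self.lora_a), self.lora_b)
        return frozen + self.scaling * update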

Feedforward Network

class FFN(nn.Module):

  • d_model is the number of dimensions
  • d_ff is the size of the hidden dimension
  • r is the lora rank

    def __init__(self, d_model: int, d_ff: int, r: int):

        super().__init__()

The linear layers and the activation

        self.linear_in = Linear(d_model, d_ff, r=r, bias=True)
        self.linear_out = Linear(d_ff, d_model, r=r, bias=True)
        self.act = nn.GELU()

  • x is the embeddings tensor with shape [batch_size, seq_len, d_model]

    def forward(self, x: torch.Tensor) -> torch.Tensor:

        x = self.linear_in(x)
        x = self.act(x)
        x = self.linear_out(x)
        return x
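As a quick sanity check with hypothetical sizes (assuming the labml_nn package is available; r = 32 is an arbitrary rank), the FFN expands to d_ff internally but preserves the input shape:

ffn = FFN(d_model=768, d_ff=3072, r=32)
x = torch.randn(2, 10, 768)
# Hidden size is 3072 inside the FFN, but the output is back to [2, 10, 768]
assert ffn(x).shape == (2, 10, 768)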

Multi-Head Attention

class MultiHeadAttention(nn.Module):

  • d_model is the number of dimensions in the embeddings
  • n_heads is the number of heads
  • r is the lora rank

    def __init__(self, d_model: int, n_heads: int, r: int):

        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads

Linear transformation for QKV

        self.qkv_projection = Linear(d_model, d_model * 3, r=r, bias=True)

Output projection

        self.output_projection = Linear(d_model, d_model, r=r, bias=True)

  • x is the tensor with shape [batch_size, seq_len, d_model]

    def _split_heads(self, x: torch.Tensor):

Split last dimension to [n_heads, d_head]

        x = x.view(x.shape[:-1] + (self.n_heads, self.d_head))

Reorder to [batch_size, head, seq_length, d_head]

        return x.permute(0, 2, 1, 3)
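For example, with hypothetical sizes d_model = 768 and n_heads = 12 (so d_head = 64), an input of shape [2, 10, 768] is transformed as follows:

x = torch.randn(2, 10, 768)   # [batch_size, seq_len, d_model]
x = x.view(2, 10, 12, 64)     # split d_model into [n_heads, d_head]
x = x.permute(0, 2, 1, 3)     # [batch_size, n_heads, seq_len, d_head]
assert x.shape == (2, 12, 10, 64)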

  • x is the embeddings tensor with shape [batch_size, seq_len, d_model]

    def forward(self, x: torch.Tensor) -> torch.Tensor:

        batch_size, seq_length, _ = x.shape

Get query, key and value

        q, k, v = self.qkv_projection(x).split(self.d_model, dim=-1)

Transform them from shape [batch_size, seq_len, d_model] to [batch_size, head, seq_length, d_head]

        q = self._split_heads(q)
        k = self._split_heads(k)
        v = self._split_heads(v)

Apply causal attention

        attn_output = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
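With is_causal=True each position attends only to itself and earlier positions. As an aside, a rough unoptimized equivalent of the call above (ignoring dropout and numerical details) would be:

import math

# scores: [batch_size, n_heads, seq_length, seq_length]
scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
# Mask out entries above the diagonal, i.e. attention to future positions
causal_mask = torch.triu(torch.ones(scores.shape[-2:], dtype=torch.bool, device=scores.device), diagonal=1)
scores = scores.masked_fill(causal_mask, float('-inf'))
attn_output = scores.softmax(dim=-1) @ v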

Transform them from shape [batch_size, head, seq_length, d_head] to [batch_size, seq_len, d_model]

        attn_output = attn_output.permute(0, 2, 1, 3).reshape(batch_size, seq_length, self.d_model)

Final projection

        return self.output_projection(attn_output)

Decoder block

class Block(nn.Module):

  • d_model is the number of dimensions in the embeddings
  • n_heads is the number of heads
  • layer_norm_epsilon is the layer norm epsilon
  • r is the lora rank

    def __init__(self, d_model: int, n_heads: int, layer_norm_epsilon: float, r: int):

        super().__init__()

Attention pre-normalization layer

        self.attn_norm = nn.LayerNorm(d_model, eps=layer_norm_epsilon)

Attention layer

        self.attn = MultiHeadAttention(d_model, n_heads, r)

FFN pre-normalization layer

        self.ffn_norm = nn.LayerNorm(d_model, eps=layer_norm_epsilon)

Feed-forward network

        self.ffn = FFN(d_model, d_model * 4, r)

  • x is the embeddings tensor with shape [batch_size, seq_len, d_model]

    def forward(self, x: torch.Tensor) -> torch.Tensor:

Attention

        x = x + self.attn(self.attn_norm(x))

FFN

        x = x + self.ffn(self.ffn_norm(x))

        return x

GPT-2 Model

class GPTModel(nn.Module):

  • d_model is the number of dimensions in the embeddings
  • n_heads is the number of attention heads
  • n_layers is the number of decoder layers
  • n_positions is the number of positional embeddings
  • layer_norm_epsilon is the layer norm epsilon
  • vocab_size is the vocabulary size
  • r is the lora rank

    def __init__(self, *, d_model: int,
                 n_heads: int, n_layers: int,
                 n_positions: int,
                 layer_norm_epsilon: float,
                 vocab_size: int, r: int):

        super().__init__()

Token and absolute positional embeddings

        self.token_embedding = Embedding(vocab_size, d_model, r=r)
        self.position_embedding = Embedding(n_positions, d_model, r=r)

Decoder blocks

        self.blocks = nn.ModuleList([Block(d_model, n_heads, layer_norm_epsilon, r=r)
                                     for _ in range(n_layers)])

Final layer norm

        self.final_norm = nn.LayerNorm(d_model, eps=layer_norm_epsilon)

Projection layer to logit space

        self.lm_head = Linear(d_model, vocab_size, r=r, bias=False)

  • input_ids has shape [batch_size, seq_len]

    def forward(self, input_ids: torch.Tensor):

        batch_size, seq_len = input_ids.shape

Get token embeddings

        token_embeddings = self.token_embedding(input_ids)

Get position ids

        position_ids = torch.arange(seq_len, device=input_ids.device)[None, :]

Get position embeddings

        position_embeddings = self.position_embedding(position_ids)

Add position embeddings

        # position_embeddings has shape [1, seq_len, d_model] and broadcasts over the batch dimension
        x = token_embeddings + position_embeddings

Run through transformer blocks

        for block in self.blocks:
            x = block(x)

Final normalization

        x = self.final_norm(x)

Get logits from projection layer

        return self.lm_head(x)
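A usage sketch, instantiating the model with GPT-2 small hyperparameters (the LoRA rank r = 32 here is an arbitrary choice, and in practice the pre-trained GPT-2 weights would be loaded into the frozen parameters before fine-tuning the LoRA modules):

model = GPTModel(d_model=768,
                 n_heads=12,
                 n_layers=12,
                 n_positions=1024,
                 layer_norm_epsilon=1e-5,
                 vocab_size=50257,
                 r=32)

# A batch of 4 sequences of 128 token ids
input_ids = torch.randint(0, 50257, (4, 128))
logits = model(input_ids)
# Logits over the vocabulary for every position
assert logits.shape == (4, 128, 50257)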
