docs/lora/gpt2.html
Here is the code for training a GPT-2 model with LoRA on the Tiny Shakespeare dataset.
import torch
import torch.nn as nn

from labml_nn.lora import Linear, Embedding
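Linear and Embedding here are the LoRA variants of the standard layers: the pretrained weight is kept frozen and a trainable low-rank update is added on top. The sketch below is only an illustration of that idea, not the actual labml_nn.lora implementation; the names LoRALinear, lora_a, lora_b and the scaling are assumptions made for the example.

import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    # y = x W^T + b + (alpha / r) * x A^T B^T, with W and b frozen and A, B trainable
    def __init__(self, in_features: int, out_features: int, r: int, alpha: int = 1, bias: bool = True):
        super().__init__()
        # Frozen pretrained weight and bias
        self.weight = nn.Parameter(torch.empty(out_features, in_features), requires_grad=False)
        self.bias = nn.Parameter(torch.zeros(out_features), requires_grad=False) if bias else None
        # Trainable low-rank factors; B starts at zero so the update is initially a no-op
        self.lora_a = nn.Parameter(torch.randn(r, in_features) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(out_features, r))
        self.scaling = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        frozen = nn.functional.linear(x, self.weight, self.bias)
        return frozen + self.scaling * (x @ self.lora_a.T @ self.lora_b.T)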
class FFN(nn.Module):

* d_model is the number of dimensions
* d_ff is the size of the hidden dimension
* r is the LoRA rank

    def __init__(self, d_model: int, d_ff: int, r: int):
        super().__init__()
The linear layers and the activation
        self.linear_in = Linear(d_model, d_ff, r=r, bias=True)
        self.linear_out = Linear(d_ff, d_model, r=r, bias=True)
        self.act = nn.GELU()

* x is the embeddings tensor with shape [batch_size, seq_len, d_model]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear_in(x)
        x = self.act(x)
        x = self.linear_out(x)
        return x
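A quick shape check, assuming the LoRA Linear layers can be constructed with fresh (untrained) weights. The sizes are only for illustration: d_ff is 4 × d_model in GPT-2, and r = 32 is an arbitrary LoRA rank.

ffn = FFN(d_model=768, d_ff=3072, r=32)
x = torch.randn(2, 10, 768)          # [batch_size, seq_len, d_model]
assert ffn(x).shape == (2, 10, 768)  # the output keeps the input shape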
class MultiHeadAttention(nn.Module):

* d_model is the number of dimensions in the embeddings
* n_heads is the number of heads
* r is the LoRA rank

    def __init__(self, d_model: int, n_heads: int, r: int):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
Linear transformation for QKV
        self.qkv_projection = Linear(d_model, d_model * 3, r=r, bias=True)
Output projection
        self.output_projection = Linear(d_model, d_model, r=r, bias=True)

* x is the tensor with shape [batch_size, seq_len, d_model]

    def _split_heads(self, x: torch.Tensor):
Split last dimension to [n_heads, d_head]
        x = x.view(x.shape[:-1] + (self.n_heads, self.d_head))
Reorder to [batch_size, head, seq_length, d_head]
        return x.permute(0, 2, 1, 3)
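For example, with d_model = 768 and n_heads = 12 (so d_head = 64), the two steps do the following; the batch and sequence sizes below are just example values.

x = torch.randn(2, 10, 768)    # [batch_size=2, seq_len=10, d_model=768]
x = x.view(2, 10, 12, 64)      # [batch_size, seq_len, n_heads, d_head]
x = x.permute(0, 2, 1, 3)      # [batch_size, n_heads, seq_len, d_head]
assert x.shape == (2, 12, 10, 64)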
* x is the embeddings tensor with shape [batch_size, seq_len, d_model]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, seq_length, _ = x.shape
Get query, key and value
        q, k, v = self.qkv_projection(x).split(self.d_model, dim=-1)
Transform them from shape [batch_size, seq_len, d_model] to [batch_size, head, seq_length, d_head]
        q = self._split_heads(q)
        k = self._split_heads(k)
        v = self._split_heads(v)
Apply causal attention
        attn_output = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
Transform them from shape [batch_size, head, seq_length, d_head] to [batch_size, seq_len, d_model]
        attn_output = attn_output.permute(0, 2, 1, 3).reshape(batch_size, seq_length, self.d_model)
Final projection
        return self.output_projection(attn_output)
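scaled_dot_product_attention with is_causal=True computes masked softmax attention with an efficient fused kernel. As a reference, a plain PyTorch sketch of the same computation (without the kernel optimizations) would look roughly like this:

import math

def causal_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # q, k, v have shape [batch_size, n_heads, seq_len, d_head]
    seq_len, d_head = q.shape[-2], q.shape[-1]
    # Scaled dot-product scores between every pair of positions
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_head)
    # Causal mask: position i can only attend to positions j <= i
    mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device), diagonal=1)
    scores = scores.masked_fill(mask, float('-inf'))
    # Attention weights and the weighted sum of values
    return torch.softmax(scores, dim=-1) @ v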
class Block(nn.Module):

* d_model is the number of dimensions in the embeddings
* n_heads is the number of heads
* layer_norm_epsilon is the layer norm epsilon
* r is the LoRA rank

    def __init__(self, d_model: int, n_heads: int, layer_norm_epsilon: float, r: int):
        super().__init__()
Attention pre-normalization layer
        self.attn_norm = nn.LayerNorm(d_model, eps=layer_norm_epsilon)
Attention layer
        self.attn = MultiHeadAttention(d_model, n_heads, r)
FFN pre-normalization layer
        self.ffn_norm = nn.LayerNorm(d_model, eps=layer_norm_epsilon)
Feed-forward network
        self.ffn = FFN(d_model, d_model * 4, r)

* x is the embeddings tensor with shape [batch_size, seq_len, d_model]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
Attention
        x = x + self.attn(self.attn_norm(x))
FFN
        x = x + self.ffn(self.ffn_norm(x))

        return x
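Since both sub-layers are residual, a block maps a [batch_size, seq_len, d_model] tensor to the same shape. A minimal check with example sizes (r = 32 is an arbitrary rank, and the weights here are untrained):

block = Block(d_model=768, n_heads=12, layer_norm_epsilon=1e-5, r=32)
x = torch.randn(2, 10, 768)
assert block(x).shape == x.shape  # residual connections keep the shape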
class GPTModel(nn.Module):

* d_model is the number of dimensions in the embeddings
* n_heads is the number of attention heads
* n_layers is the number of decoder layers
* n_positions is the number of positional embeddings
* layer_norm_epsilon is the layer norm epsilon
* vocab_size is the vocabulary size
* r is the LoRA rank

    def __init__(self, *, d_model: int,
                 n_heads: int, n_layers: int,
                 n_positions: int,
                 layer_norm_epsilon: float,
                 vocab_size: int, r: int):
        super().__init__()
Token and absolute positional embeddings
        self.token_embedding = Embedding(vocab_size, d_model, r=r)
        self.position_embedding = Embedding(n_positions, d_model, r=r)
Decoder blocks
        self.blocks = nn.ModuleList([Block(d_model, n_heads, layer_norm_epsilon, r=r)
                                     for _ in range(n_layers)])
Final layer norm
        self.final_norm = nn.LayerNorm(d_model, eps=layer_norm_epsilon)
Projection layer to logit space
        self.lm_head = Linear(d_model, vocab_size, r=r, bias=False)

* input_ids has shape [batch_size, seq_len]

    def forward(self, input_ids: torch.Tensor):
        batch_size, seq_len = input_ids.shape
Get token embeddings
        token_embeddings = self.token_embedding(input_ids)
Get position ids
        position_ids = torch.arange(seq_len, device=input_ids.device)[None, :]
Get position embeddings
        position_embeddings = self.position_embedding(position_ids)
Add position embeddings
        x = token_embeddings + position_embeddings
Run through transformer blocks
        for block in self.blocks:
            x = block(x)
Final normalization
        x = self.final_norm(x)
Get logits from projection layer
        return self.lm_head(x)
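As a rough sketch of how the model might be instantiated: the sizes below are the standard GPT-2 (small) hyper-parameters (768 dimensions, 12 heads, 12 layers, 1024 positions, 50257-token vocabulary), while r = 32 is an arbitrary LoRA rank chosen only for illustration; the weights here are untrained.

model = GPTModel(d_model=768,
                 n_heads=12,
                 n_layers=12,
                 n_positions=1024,
                 layer_norm_epsilon=1e-5,
                 vocab_size=50257,
                 r=32)

input_ids = torch.randint(0, 50257, (1, 64))  # [batch_size=1, seq_len=64]
logits = model(input_ids)                     # [batch_size, seq_len, vocab_size]
assert logits.shape == (1, 64, 50257)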