[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/lora/__init__.py)
This is an implementation of Low-Rank Adaptation (LoRA) in PyTorch.
Low-Rank Adaptation (LoRA) freezes pre-trained model weights and injects trainable rank decomposition matrices into each layer of the transformer. This makes it possible to efficiently fine-tune large language models by reducing trainable parameters by a large factor.
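To get a sense of the reduction, here is a small back-of-the-envelope sketch; the layer size and rank below are illustrative assumptions, not values taken from the paper or this repository.

```python
# Trainable parameters for one d x k linear layer: full fine-tuning vs. a rank-r LoRA update.
# The numbers are hypothetical, chosen only to illustrate the ratio.
d, k, r = 768, 768, 4

full_finetune = d * k   # every entry of W is trainable
lora = r * (d + k)      # only B (d x r) and A (r x k) are trainable

print(full_finetune, lora, full_finetune / lora)  # 589824 6144 96.0
```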
Here's the training code for fine-tuning a GPT2 model with LoRA on the Tiny Shakespeare dataset.
    import torch
    import torch.nn as nn
The LoRA linear layer adds a low-rank decomposition to the pre-trained weight matrix ($W_0 \in \mathbb{R}^{d \times k}$) of the linear layer:

$$W_0 + \Delta W = W_0 + BA$$

where $B \in \mathbb{R}^{d \times r}$, $A \in \mathbb{R}^{r \times k}$, and the rank $r \ll \min(d, k)$.

All parameters are frozen except $A$ and $B$.

$\Delta W$ is initialized to be zero at the beginning of training.

They multiply $x \Delta W^T$ by $\frac{\alpha}{r}$, where $\alpha$ is a hyper-parameter. Once $\alpha$ is tuned it can be kept the same when varying $r$.
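As a quick sanity check of the shapes and the scaling, here is a minimal sketch; the dimensions are made up for illustration, and `W0`, `A`, `B`, and `alpha` mirror the symbols above.

```python
import torch

d, k, r = 8, 6, 2                # illustrative dimensions with r << min(d, k)
alpha = r                        # the default used below, so alpha / r = 1

W0 = torch.randn(d, k)           # frozen pre-trained weight
A = torch.randn(r, k)            # trainable, initialized like a normal linear weight
B = torch.zeros(d, r)            # trainable, initialized to zero

delta_W = B @ A                  # (d, k) low-rank update, zero at the start of training
W = W0 + (alpha / r) * delta_W   # effective weight of the adapted layer

assert torch.allclose(W, W0)     # no change at initialization because B = 0
```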
    class Linear(nn.Module):

* `in_features` is the number of input features of the linear layer
* `out_features` is the number of output features of the linear layer
* `bias` is a flag indicating if there is a bias parameter
* `r` is the rank of the decomposition $r$
* `alpha` is the scaling factor $\alpha$

        def __init__(self, in_features: int, out_features: int, bias: bool,
                     r: int, alpha: int = None):
            super().__init__()
Set $\alpha = r$ if it is not provided, i.e. make the scaling factor $\frac{\alpha}{r} = 1$.
            if alpha is None:
                alpha = r
The pre-trained weight $W_0$

            self.weight = nn.Parameter(torch.empty((out_features, in_features)))
Freeze it
            self.weight.requires_grad = False

            if bias:

Bias parameter $b_0$ (also frozen)

                self.bias = nn.Parameter(torch.empty(out_features))
                self.bias.requires_grad = False
            else:
No bias parameter
                self.bias = None

The scaling factor $\frac{\alpha}{r}$

            self.scaling = alpha / r

Matrix $A \in \mathbb{R}^{r \times k}$

            self.lora_a = nn.Parameter(torch.empty((r, in_features)))
Matrix $B \in \mathbb{R}^{d \times r}$; following the `nn.Linear` convention, the forward pass multiplies by the transposes $A^T$ and $B^T$
            self.lora_b = nn.Parameter(torch.empty((out_features, r)))

            with torch.no_grad():
Initialize $A$ similar to a weight matrix in a normal linear layer

                nn.init.kaiming_uniform_(self.lora_a, a=5 ** 0.5)

Initialize $B$ to $0$ so that $\Delta W = BA$ is $0$ at initialization

                nn.init.zeros_(self.lora_b)
        def forward(self, x: torch.Tensor):

Compute $x W_0^T + b_0$

            result = nn.functional.linear(x, self.weight, bias=self.bias)

Add $\frac{\alpha}{r} x \Delta W^T = \frac{\alpha}{r} x (BA)^T = \frac{\alpha}{r} x A^T B^T$

            result += (x @ self.lora_a.T @ self.lora_b.T) * self.scaling

            return result
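A hypothetical usage sketch for the `Linear` layer above: the base `nn.Linear` here is randomly initialized as a stand-in for real pre-trained weights, and the assertions only check properties that follow from the code above (zero-initialized $B$, frozen $W_0$ and $b_0$).

```python
import torch
import torch.nn as nn

base = nn.Linear(16, 32)                       # stand-in for a pre-trained layer
lora = Linear(in_features=16, out_features=32, bias=True, r=4)

with torch.no_grad():
    lora.weight.copy_(base.weight)             # load the frozen pre-trained weight W0
    lora.bias.copy_(base.bias)                 # load the frozen bias b0

x = torch.randn(2, 16)
# B is zero-initialized, so the adapted layer matches the base layer before training.
assert torch.allclose(lora(x), base(x))

# Only the low-rank matrices receive gradients.
print([name for name, p in lora.named_parameters() if p.requires_grad])
# ['lora_a', 'lora_b']
```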
Similar to the LoRA linear layer, this adds a low-rank decomposition to the pre-trained embedding weight matrix ($W_0 \in \mathbb{R}^{d \times k}$).

$$W_0 + \Delta W = W_0 + BA$$
    class Embedding(nn.Module):
* `num_embeddings` is the number of embeddings
* `embedding_dim` is the number of embedding dimensions
* `r` is the rank of the decomposition $r$
* `alpha` is the scaling factor $\alpha$

        def __init__(self, num_embeddings: int, embedding_dim: int,
                     r: int, alpha: int = None):
            super().__init__()
Set $\alpha = r$ if it is not provided, i.e. make the scaling factor $\frac{\alpha}{r} = 1$.
            if alpha is None:
                alpha = r
The pre-trained embedding weights $W_0^T$ (frozen)

            self.weight = nn.Parameter(torch.empty((num_embeddings, embedding_dim)))
            self.weight.requires_grad = False

The scaling factor $\frac{\alpha}{r}$

            self.scaling = alpha / r

Matrix $A \in \mathbb{R}^{r \times k}$

            self.lora_a = nn.Parameter(torch.empty((r, num_embeddings)))

Matrix $B \in \mathbb{R}^{d \times r}$

            self.lora_b = nn.Parameter(torch.empty((embedding_dim, r)))

            with torch.no_grad():
Initialize $A$ with a normal distribution

                nn.init.normal_(self.lora_a)

Initialize $B$ to $0$ so that $\Delta W = BA$ is $0$ at initialization

                nn.init.zeros_(self.lora_b)

        def forward(self, x: torch.Tensor):
Compute the embeddings $\text{onehot}(x) W_0^T$
            result = nn.functional.embedding(x, self.weight)

Add $\frac{\alpha}{r} \text{onehot}(x) \Delta W^T = \frac{\alpha}{r} \text{onehot}(x) A^T B^T$

            result += (nn.functional.embedding(x, self.lora_a.T) @ self.lora_b.T) * self.scaling

            return result
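Similarly, a hypothetical usage sketch for the `Embedding` layer above; the base `nn.Embedding` stands in for a pre-trained embedding table.

```python
import torch
import torch.nn as nn

base = nn.Embedding(num_embeddings=100, embedding_dim=32)   # stand-in pre-trained table
lora_emb = Embedding(num_embeddings=100, embedding_dim=32, r=4)

with torch.no_grad():
    lora_emb.weight.copy_(base.weight)                      # load the frozen weights

ids = torch.tensor([[1, 5, 7], [2, 0, 99]])
# B is zero-initialized, so the outputs match the base embeddings before training.
assert torch.allclose(lora_emb(ids), base(ids))
```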