[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/lora/__init__.py)


Low-Rank Adaptation (LoRA)

This is an implementation of Low-Rank Adaptation (LoRA) in PyTorch.

Low-Rank Adaptation (LoRA) freezes pre-trained model weights and injects trainable rank decomposition matrices into each layer of the transformer. This makes it possible to efficiently fine-tune large language models by reducing trainable parameters by a large factor.

Here's the training code for a GPT-2 model with LoRA on the Tiny Shakespeare dataset.

import torch
import torch.nn as nn

LoRA Linear Layer

The LoRA linear layer adds a low-rank decomposition to the pre-trained weight matrix ($W_0 \in \mathbb{R}^{d \times k}$) of the linear layer.

$$W_0 + \Delta W = W_0 + BA$$

where $B \in \mathbb{R}^{d \times r}$, $A \in \mathbb{R}^{r \times k}$, and the rank $r \ll \min(d, k)$.

All parameters are frozen except $A$ and $B$.

$\Delta W$ is initialized to zero at the beginning of training.

They multiply $x \Delta W^T$ by $\frac{\alpha}{r}$, where $\alpha$ is a hyper-parameter. Once $\alpha$ is tuned it can be kept the same when varying $r$.
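To get a feel for the savings, here is a rough illustration with made-up numbers (they are assumptions for this sketch, not figures from the paper): for a single square projection, the rank decomposition trains roughly two orders of magnitude fewer parameters.

```python
# Hypothetical illustration of the parameter savings for one linear layer.
# d, k and r are made-up numbers, not values from the paper.
d, k, r = 768, 768, 4

full_finetune = d * k          # updating W_0 directly: 589,824 parameters
lora_update = r * k + d * r    # A is r x k and B is d x r: 6,144 parameters

print(full_finetune / lora_update)  # ~96x fewer trainable parameters
```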

class Linear(nn.Module):

  • in_features is the number of input features of the linear layer
  • out_features is the number of output features of the linear layer
  • bias is a flag indicating if there is a bias parameter
  • r is the rank of the decomposition r
  • alpha is the scaling factor α
def __init__(self, in_features: int, out_features: int, bias: bool,
             r: int, alpha: int = None):

super().__init__()

Set $\alpha = r$ if it is not provided, i.e. make the scaling factor $\frac{\alpha}{r} = 1$.

if alpha is None:
    alpha = r

The pre-trained weight $W_0$

self.weight = nn.Parameter(torch.empty((out_features, in_features)))

Freeze it

self.weight.requires_grad = False

if bias:

Bias parameter $b_0$ (also frozen)

    self.bias = nn.Parameter(torch.empty(out_features))
    self.bias.requires_grad = False
else:

No bias parameter

    self.bias = None

Scaling factor $\frac{\alpha}{r}$

self.scaling = alpha / r

Matrix $A \in \mathbb{R}^{r \times k}$

self.lora_a = nn.Parameter(torch.empty((r, in_features)))

Matrix $B \in \mathbb{R}^{d \times r}$; we keep $A$ and $B$ transposed

self.lora_b = nn.Parameter(torch.empty((out_features, r)))

with torch.no_grad():

Initialize $A$ similar to a weight matrix in a normal linear layer

    nn.init.kaiming_uniform_(self.lora_a, a=5 ** 0.5)

Initialize $B$ to $0$ so that $\Delta W = BA$ is zero at initialization

    nn.init.zeros_(self.lora_b)

def forward(self, x: torch.Tensor):

Compute $x W_0^T + b_0$

result = nn.functional.linear(x, self.weight, bias=self.bias)

Add $\frac{\alpha}{r} x \Delta W^T = \frac{\alpha}{r} x (BA)^T = \frac{\alpha}{r} x A^T B^T$

result += (x @ self.lora_a.T @ self.lora_b.T) * self.scaling

return result
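A minimal usage sketch of the layer above (the base layer, shapes, and rank here are made-up assumptions): copy a pre-trained weight into the frozen `weight`, and check that at initialization the output matches the original layer, since $B = 0$, while only `lora_a` and `lora_b` remain trainable.

```python
# Minimal usage sketch; `base`, the shapes and the rank are hypothetical.
base = nn.Linear(16, 32)                   # stands in for a pre-trained layer
lora = Linear(16, 32, bias=True, r=4)

with torch.no_grad():
    lora.weight.copy_(base.weight)         # frozen W_0
    lora.bias.copy_(base.bias)             # frozen b_0

x = torch.randn(8, 16)
assert torch.allclose(lora(x), base(x))    # B = 0, so ΔW = 0 at initialization

print([n for n, p in lora.named_parameters() if p.requires_grad])  # ['lora_a', 'lora_b']
```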

LoRA Embedding Layer

Similar to the LoRA linear layer, this adds a low-rank decomposition to the pre-trained embedding weights matrix ($W_0 \in \mathbb{R}^{d \times k}$).

$$W_0 + \Delta W = W_0 + BA$$

class Embedding(nn.Module):

  • num_embeddings is the number of embeddings
  • embedding_dim is the number of embedding dimensions
  • r is the rank of the decomposition r
  • alpha is the scaling factor α
def __init__(self, num_embeddings: int, embedding_dim: int,
             r: int, alpha: int = None):

super().__init__()

Set $\alpha = r$ if it is not provided, i.e. make the scaling factor $\frac{\alpha}{r} = 1$.

if alpha is None:
    alpha = r

The pre-trained embedding weights $W_0^T$ (frozen)

self.weight = nn.Parameter(torch.empty((num_embeddings, embedding_dim)))
self.weight.requires_grad = False

Scaling factor $\frac{\alpha}{r}$

self.scaling = alpha / r

Matrix $A \in \mathbb{R}^{r \times k}$

self.lora_a = nn.Parameter(torch.empty((r, num_embeddings)))

Matrix $B \in \mathbb{R}^{d \times r}$

self.lora_b = nn.Parameter(torch.empty((embedding_dim, r)))

with torch.no_grad():

Initialize $A$ with a normal distribution

    nn.init.normal_(self.lora_a)

Initialize $B$ to $0$ so that $\Delta W = BA$ is zero at initialization

    nn.init.zeros_(self.lora_b)

def forward(self, x: torch.Tensor):

Compute the embeddings $\text{onehot}(x) W_0^T$

result = nn.functional.embedding(x, self.weight)

Add $\frac{\alpha}{r} \text{onehot}(x) \Delta W^T = \frac{\alpha}{r} \text{onehot}(x) A^T B^T$

result += (nn.functional.embedding(x, self.lora_a.T) @ self.lora_b.T) * self.scaling

return result
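To fine-tune a pre-trained model such as GPT-2 with these layers, one would typically replace its `nn.Linear` modules with the LoRA `Linear` above and copy the frozen weights over. Below is a minimal sketch of such a helper; the function name and the recursive replacement strategy are assumptions for illustration, not part of this implementation.

```python
# Hypothetical helper: swap every nn.Linear in a model for the LoRA Linear above,
# copying the pre-trained weights, which stay frozen.
def add_lora_to_linears(module: nn.Module, r: int = 4, alpha: int = None) -> nn.Module:
    for name, child in list(module.named_children()):
        if isinstance(child, nn.Linear):
            lora = Linear(child.in_features, child.out_features,
                          bias=child.bias is not None, r=r, alpha=alpha)
            with torch.no_grad():
                lora.weight.copy_(child.weight)           # frozen W_0
                if child.bias is not None:
                    lora.bias.copy_(child.bias)           # frozen b_0
            setattr(module, name, lora)
        else:
            add_lora_to_linears(child, r=r, alpha=alpha)  # recurse into sub-modules
    return module
```

After the swap, only the `lora_a` and `lora_b` matrices require gradients, so the optimizer can be given just those parameters.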
