Initializing weights with LoftQ by replacing LoRA weights in-place

examples/loftq_finetuning/LoftQ_weight_replacement.ipynb


This notebook shows how to apply LoftQ initialization on our QLoRA model.

In short, the idea behind LoftQ is the following. When we use QLoRA, i.e. we quantize the base model with bitsandbytes to save memory and then train LoRA weights on top of this base model, we expect a certain performance gap. This is partly due to the fact that quantization is only an approximation of the "real" weights and thus introduces a quantization error. By default, LoRA weights are initialized such that they are a no-op at the start of training. However, we can instead initialize them so that they minimize the quantization error. This is the idea behind LoftQ.

Note that this only influences the initialization of the model. Everything that follows stays the same as always.
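To build some intuition, here is a self-contained toy sketch (not the actual LoftQ algorithm, which operates on NF4-quantized weights inside PEFT) showing how fitting a low-rank correction to the quantization error, via SVD, can shrink that error:

```python
import torch

torch.manual_seed(0)
W = torch.randn(64, 64)

# Crude stand-in for 4-bit quantization: round every weight to the
# nearest of 16 evenly spaced values (real NF4 uses a non-uniform grid).
levels = torch.linspace(W.min().item(), W.max().item(), 16)
Q = levels[(W.unsqueeze(-1) - levels).abs().argmin(dim=-1)]

# Error introduced by quantization alone.
err_plain = (W - Q).abs().mean()

# LoftQ-style idea: choose LoRA factors B and A such that Q + B @ A
# approximates W, i.e. fit a rank-r correction to the error via SVD.
r = 8
U, S, Vh = torch.linalg.svd(W - Q)
B = U[:, :r] * S[:r]
A = Vh[:r, :]
err_corrected = (W - (Q + B @ A)).abs().mean()

print(f"quantization error without correction: {err_plain:.5f}")
print(f"quantization error with rank-{r} correction: {err_corrected:.5f}")
```

Since the truncated SVD is the best rank-r approximation of the error matrix, the corrected error is guaranteed to be no larger than the plain quantization error.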

Imports

```python
import os
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

from peft import get_peft_model, LoraConfig, replace_lora_weights_loftq
```

Functions

```python
def get_mae(x, y):
    return (x - y).abs().mean()


def get_mse(x, y):
    return torch.pow(x - y, 2).mean()


def error_report(x, y):
    mae = get_mae(x, y)
    mse = get_mse(x, y)
    print(
        f"Mean absolute error: {mae:>8.5f}\n"
        f"Mean squared error:  {mse:>8.5f}"
    )
```
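As a quick sanity check, the helpers can be exercised on toy tensors (redefined here so the snippet is self-contained):

```python
import torch

def get_mae(x, y):
    return (x - y).abs().mean()

def get_mse(x, y):
    return torch.pow(x - y, 2).mean()

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
y = torch.tensor([1.0, 2.0, 3.0, 2.0])  # one element off by 2

print(get_mae(x, y))  # mean of [0, 0, 0, 2] -> 0.5
print(get_mse(x, y))  # mean of [0, 0, 0, 4] -> 1.0
```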

Base model

First, let's load a base model and calculate some logits. These logits are the baseline, i.e. we try to match their values as closely as possible. We only need these logits for demonstration purposes. In practice, it is not necessary to load the non-quantized weights to apply LoftQ initialization.

Note: We have to choose a model with a model.safetensors file. As PyTorch checkpoints (pickle) cannot be loaded lazily, we have to use safetensors. If those don't exist for your model, save the pretrained model as a safetensors file using save_pretrained and pass the model path to replace_lora_weights_loftq.

```python
model_id = "bigscience/bloomz-560m"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

s = """Beautiful is better than ugly.
Explicit is better than implicit.
Simple is better than complex.
Complex is better than complicated.
Flat is better than nested.
Sparse is better than dense.
Readability counts.
Special cases aren't special enough to break the rules.
Although practicality beats purity.
Errors should never pass silently.
Unless explicitly silenced.
In the face of ambiguity, refuse the temptation to guess.
There should be one-- and preferably only one --obvious way to do it.
Although that way may not be obvious at first unless you're Dutch.
Now is better than never.
Although never is often better than *right* now.
If the implementation is hard to explain, it's a bad idea.
If the implementation is easy to explain, it may be a good idea.
Namespaces are one honking great idea -- let's do more of those!"""

inputs = tokenizer(s.splitlines(), return_tensors="pt", padding=True)
```

Our baseline logits:

```python
logits_base = model(**inputs).logits
```

Normal LoRA model

Now we load the model quantized with bitsandbytes. For now, only 4-bit quantization is supported.

```python
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config)
```

Next we create a LoRA model using PEFT and compute the logits of that model.

```python
lora_config = LoraConfig(task_type="CAUSAL_LM", target_modules="all-linear")
peft_model = get_peft_model(model, lora_config)
logits_lora = peft_model(**inputs).logits
```

Let's check the influence of the quantization error on our logits:

```python
error_report(logits_base, logits_lora)
```

LoftQ

Next, let's use LoftQ initialization and see if it helps reduce the error.

```python
replace_lora_weights_loftq(peft_model)
logits_loftq = peft_model(**inputs).logits
error_report(logits_base, logits_loftq)
```

We can see that LoftQ initialization helped a little bit, but the difference is not huge.

LoftQ with callback

To improve on this, let's write a small callback function and pass it to replace_lora_weights_loftq. Each time one weight is replaced with LoftQ-initialized weights, this callback tests whether the quantization error is actually reduced. If it is not, we roll back the replacement. This way, we keep only those replacements that improve the results.
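The accept-or-roll-back pattern that this callback enables can be sketched in isolation (a toy illustration of the greedy loop, not PEFT's actual internals):

```python
import torch

torch.manual_seed(0)

# Toy stand-in: per-module weights that we try to move toward a target.
target = {name: torch.randn(4) for name in ["lin1", "lin2", "lin3"]}
weights = {name: torch.zeros(4) for name in target}

def loss():
    return sum(torch.pow(target[n] - weights[n], 2).mean() for n in target)

initial_loss = loss()
current_best = initial_loss
for name in weights:
    previous = weights[name].clone()                     # remember the old weights
    weights[name] = target[name] + 0.1 * torch.randn(4)  # candidate replacement
    if loss() < current_best:   # "callback" returns True: keep the replacement
        current_best = loss()
    else:                       # "callback" returns False: roll back
        weights[name] = previous

print(f"loss went from {initial_loss:.4f} to {loss():.4f}")
```

Because rejected candidates are undone, the loss can only stay equal or decrease with each module processed.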

```python
# Since PEFT has modified the base model, we should reload it
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config)
peft_model = get_peft_model(model, lora_config)

current_mse = float("inf")


def my_callback(model, module_name):
    """Callable to replace weights with LoftQ if the MSE is lower than the current best one."""
    global current_mse

    logits = model(**inputs).logits
    mse = get_mse(logits_base, logits)
    if mse < current_mse:
        current_mse = mse
        print(f"MSE improved for module {module_name}")
        return True
    print(f"MSE did not improve for module {module_name}")
    return False


replace_lora_weights_loftq(peft_model, callback=my_callback)
logits_loftq_callback = peft_model(**inputs).logits
error_report(logits_base, logits_loftq_callback)
```

We can see that applying LoftQ with the help of the callback reduced the error quite significantly.

Applying LoftQ multiple times

It is possible to run replace_lora_weights_loftq multiple times on the same model when using the callback.

```python
replace_lora_weights_loftq(peft_model, callback=my_callback)
logits_loftq_callback_twice = peft_model(**inputs).logits
error_report(logits_base, logits_loftq_callback_twice)
```

There are further gains, but they are not very big.