Block-Diagonal LoRA for Eliminating Communication Overhead in Tensor Parallel LoRA Serving

examples/bdlora_finetuning/bdlora_peft_demo.ipynb

Introduction

Block-Diagonal LoRA (BD-LoRA) is a LoRA variant in which some LoRA factors are constrained to be block-diagonal. This enables faster serving because it eliminates communication overhead when running inference on multiple GPUs. Despite the block-diagonal constraint, BD-LoRA performs similarly to vanilla LoRA at comparable parameter counts.

BD-LoRA is designed to be used with tensor parallelism (TP), in which the weights of a model are sharded across multiple GPUs. A popular sharding strategy is the Megatron sharding strategy. Consider two consecutive linear layers $W_1$ and $W_2$, for example the up and down projections in a transformer MLP module. The first layer is sharded column-parallel, which requires LoRA-B to be block-diagonal. The second layer is sharded row-parallel, which requires LoRA-A to be block-diagonal. The attention module is handled the same way: the Q, K and V projections together play the role of $W_1$, and the output projection plays the role of $W_2$. With this sharding, a compatible inference engine can place each block of a block-diagonal factor on a different GPU. This removes the need to communicate partial LoRA results between GPUs. The image below shows the exact sharding strategy and how it avoids this extra communication.
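The claim can be checked numerically with a small NumPy simulation of the two-layer pattern described above. This is a sketch of the math, not how PEFT or an inference engine implements it, and it rests on our own assumptions:

- toy matrix sizes;
- the convention y = x @ W + (x @ A) @ B, with W of shape (in, out);
- a tanh standing in for the activation between the layers;
- a `block_diag` helper of our own.

Each simulated "GPU" reads only its own shards and rank block. The only cross-GPU step is the final sum, which the row-parallel base layer needs anyway.

```python
import numpy as np

rng = np.random.default_rng(0)
tp, d_in, d_hidden, d_out, r = 2, 8, 12, 6, 4  # toy sizes; d_hidden and r divisible by tp
hs, rs = d_hidden // tp, r // tp              # per-GPU hidden shard and rank block sizes


def block_diag(blocks):
    """Assemble a dense block-diagonal matrix from a list of 2-D blocks."""
    dense = np.zeros((sum(b.shape[0] for b in blocks), sum(b.shape[1] for b in blocks)))
    i = j = 0
    for b in blocks:
        dense[i:i + b.shape[0], j:j + b.shape[1]] = b
        i, j = i + b.shape[0], j + b.shape[1]
    return dense


# Convention used here: y = x @ W + (x @ A) @ B with W: (in, out), A: (in, r), B: (r, out)
x = rng.standard_normal((3, d_in))
W1 = rng.standard_normal((d_in, d_hidden))
A1 = rng.standard_normal((d_in, r))                             # vanilla LoRA-A
B1_blocks = [rng.standard_normal((rs, hs)) for _ in range(tp)]  # block-diagonal LoRA-B
W2 = rng.standard_normal((d_hidden, d_out))
A2_blocks = [rng.standard_normal((hs, rs)) for _ in range(tp)]  # block-diagonal LoRA-A
B2 = rng.standard_normal((r, d_out))                            # vanilla LoRA-B

# Reference: everything on one device, with dense block-diagonal factors
h = np.tanh(x @ W1 + (x @ A1) @ block_diag(B1_blocks))
y_ref = h @ W2 + (h @ block_diag(A2_blocks)) @ B2

# Simulated TP: "GPU" k only reads its own column/row shards and rank block
partial_outputs = []
for k in range(tp):
    cols, ranks = slice(k * hs, (k + 1) * hs), slice(k * rs, (k + 1) * rs)
    # Column-parallel W1 with block-diagonal LoRA-B: local hidden shard, no communication
    h_k = np.tanh(x @ W1[:, cols] + (x @ A1[:, ranks]) @ B1_blocks[k])
    # Row-parallel W2 with block-diagonal LoRA-A: local partial output
    partial_outputs.append(h_k @ W2[cols, :] + (h_k @ A2_blocks[k]) @ B2[ranks, :])

# The only cross-GPU step: the sum (all-reduce) that row-parallel layers need anyway
y_tp = sum(partial_outputs)
print(np.allclose(y_tp, y_ref))  # True
```

With a dense (vanilla) LoRA-B in the first layer, each GPU would need all r components of x @ A1. It would have to either compute the full product itself or gather it from the other GPUs. The block-diagonal LoRA-B removes that requirement.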

Paper: https://arxiv.org/html/2510.23346v1


Performance, rank and parameter count

BD-LoRA achieves performance similar to LoRA at the same parameter count (see the image below, or the method_comparison folder in the root of the peft repository). However, in each adapted layer one of the two LoRA factors is block-diagonal. A BD-LoRA adapter therefore has fewer parameters than a LoRA adapter of the same rank, and BD-LoRA is only competitive when its rank is increased accordingly. Example code for rank matching is provided at the end of this notebook.
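The size of this gap can be estimated by hand. Below is a minimal sketch with hypothetical helper names (not PEFT functions). It assumes that only the nonzero blocks of a block-diagonal factor are counted, matching the stacked storage described later in this notebook. Under that assumption, the block-diagonal factor has 1/nblocks of its dense size.

```python
def lora_params(d_in: int, d_out: int, r: int) -> int:
    """Vanilla LoRA: A has shape (r, d_in), B has shape (d_out, r)."""
    return r * d_in + d_out * r


def bdlora_params(d_in: int, d_out: int, r: int, nblocks: int, block_diagonal: str) -> int:
    """BD-LoRA with one block-diagonal factor; only its nblocks nonzero blocks are counted."""
    if r % nblocks:
        raise ValueError("rank must split evenly into nblocks blocks")
    if block_diagonal == "A":  # row-parallel layer
        return (r // nblocks) * d_in + d_out * r
    if block_diagonal == "B":  # column-parallel layer
        return r * d_in + d_out * (r // nblocks)
    raise ValueError("block_diagonal must be 'A' or 'B'")


# A hypothetical square 2048 x 2048 projection, split into 2 blocks
print(lora_params(2048, 2048, 64))            # 262144
print(bdlora_params(2048, 2048, 64, 2, "B"))  # 196608: same rank, 25% fewer parameters
print(bdlora_params(2048, 2048, 84, 2, "B"))  # 258048: rank 84 stays just under LoRA at rank 64
```

For this square layer, rank 86 would already exceed the rank-64 LoRA count (264192 parameters), so 84 is the largest even rank below it. The exact ratio depends on the layer shapes. This is why the rank matching at the end of the notebook counts parameters over the whole model.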

python
from peft.tuners import BdLoraConfig, LoraConfig
from peft import get_peft_model
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling, AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import torch

Quick Start

To use BD-LoRA, we can follow the standard LoRA training procedure. We only need to pass a BdLoraConfig to the LoraConfig (via use_bdlora) and specify which LoRA factors should be block-diagonal. As an example, we will train a Llama model so that it can later benefit from the inference speed-up described in the BD-LoRA paper. However, BD-LoRA can be used with any other model that follows a transformer architecture.

As explained in the introduction, we want to shard each module (MLP and attention) in an alternating fashion. The first layer is column-parallel with a block-diagonal LoRA-B, and the second is row-parallel with a block-diagonal LoRA-A. Unlike a standard MLP module, the Llama MLP also has a gate projection, which we treat together with the up projection.

Therefore, we want the following block-diagonal factors (following the naming convention from the Llama architecture):

  • LoRA-A Block-Diagonal (Row-parallel sharding): Out (o_proj), Down (down_proj)
  • LoRA-B Block-Diagonal (Column-parallel sharding): QKV (q_proj, k_proj, v_proj), Up+Gate (up_proj, gate_proj)
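For other architectures, the same rule can be written down once and used to derive both lists. The sketch below uses a hypothetical dictionary and helper that are not part of PEFT. The dictionary records whether each Llama projection is column-parallel or row-parallel, and the helper splits the module names accordingly. The resulting lists are the ones passed to target_modules_bd_a and target_modules_bd_b in the config below.

```python
# Hypothetical: sharding role of each Llama projection under Megatron-style TP
LLAMA_SHARDING = {
    "q_proj": "column", "k_proj": "column", "v_proj": "column",
    "gate_proj": "column", "up_proj": "column",
    "o_proj": "row", "down_proj": "row",
}


def split_bd_targets(sharding: dict) -> tuple:
    """Row-parallel layers get a block-diagonal LoRA-A, column-parallel layers a block-diagonal LoRA-B."""
    bd_a = [name for name, kind in sharding.items() if kind == "row"]
    bd_b = [name for name, kind in sharding.items() if kind == "column"]
    return bd_a, bd_b


bd_a, bd_b = split_bd_targets(LLAMA_SHARDING)
print(bd_a)  # ['o_proj', 'down_proj']
print(bd_b)  # ['q_proj', 'k_proj', 'v_proj', 'gate_proj', 'up_proj']
```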

We also need to decide how many GPUs we want to serve on before we start training, because this determines the number of blocks in each block-diagonal factor. For this experiment, we use 2 blocks, which is equivalent to a tensor-parallelism degree of 2. Caveat: a small model such as the Llama-3.2-1B used here would normally be served on a single GPU. TP=2 or TP=8 would only be used for larger models, such as Llama-3.1-8B or Llama-3.3-70B respectively.
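We assume that every block of a block-diagonal factor has the same size. Under that assumption, the rank must be divisible by nblocks, and so must the sharded dimension of each layer: out_features for a block-diagonal LoRA-B and in_features for a block-diagonal LoRA-A. Below is a small hypothetical check. The Llama-3.2-1B projection sizes in it are assumed here rather than read from the model: hidden size 2048, intermediate size 8192, and key/value projections of size 512.

```python
# Assumed Llama-3.2-1B projection shapes as (in_features, out_features)
LAYER_SHAPES = {
    "q_proj": (2048, 2048), "k_proj": (2048, 512), "v_proj": (2048, 512), "o_proj": (2048, 2048),
    "gate_proj": (2048, 8192), "up_proj": (2048, 8192), "down_proj": (8192, 2048),
}
BD_A = {"o_proj", "down_proj"}                                # block-diagonal LoRA-A
BD_B = {"q_proj", "k_proj", "v_proj", "up_proj", "gate_proj"}  # block-diagonal LoRA-B


def divisibility_problems(r: int, nblocks: int) -> list:
    """Hypothetical check: list every dimension that does not split evenly into nblocks blocks."""
    problems = []
    if r % nblocks:
        problems.append(f"rank {r} is not divisible by nblocks={nblocks}")
    for name, (d_in, d_out) in LAYER_SHAPES.items():
        if name in BD_A and d_in % nblocks:
            problems.append(f"{name}: in_features {d_in} is not divisible by {nblocks}")
        if name in BD_B and d_out % nblocks:
            problems.append(f"{name}: out_features {d_out} is not divisible by {nblocks}")
    return problems


print(divisibility_problems(96, 2))  # []
print(divisibility_problems(90, 8))  # ['rank 90 is not divisible by nblocks=8']
```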

python
model_name = "meta-llama/Llama-3.2-1B"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
python
target_modules=["q_proj", "v_proj", "k_proj", "up_proj", "gate_proj", "o_proj", "down_proj"]
# Set this equal to the number of GPUs you want to serve the model with later
nblocks = 2

bdlora_config = BdLoraConfig(
    # Row-parallel layers: block-diagonal LoRA-A
    target_modules_bd_a=["o_proj", "down_proj"],
    # Column-parallel layers: block-diagonal LoRA-B
    target_modules_bd_b=["q_proj", "v_proj", "k_proj", "up_proj", "gate_proj"],
    nblocks=nblocks,
)

config = LoraConfig(
    r=96,
    # Adjust target_modules and the target_modules_bd_* attributes to your model architecture (e.g. module names)
    target_modules=target_modules,
    use_bdlora=bdlora_config,
    lora_bias=False
)

peft_model = get_peft_model(model, config)
peft_model.print_trainable_parameters()

Training

We train the model for only 10 steps. This training block is only meant to show how BD-LoRA integrates with other Hugging Face tools.

python
dataset = load_dataset("imdb", split="train[:1%]")

tokenizer.pad_token = tokenizer.eos_token
def tokenize(batch):
    return tokenizer(batch["text"], truncation=True, padding="max_length", max_length=128)

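# Drop the raw text and the sentiment label; causal LM training only needs the token ids
dataset = dataset.map(tokenize, batched=True, remove_columns=dataset.column_names)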
training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
    warmup_steps=2,
    max_steps=10,
    learning_rate=2e-4,
    logging_steps=1,
)

trainer = Trainer(
    model=peft_model,
    args=training_args,
    train_dataset=dataset,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

peft_model.config.use_cache = False
trainer.train()

Saving Model

python
peft_model.save_pretrained("example_bd_lora_adapter")

Example Output

python
text = "The Batman Trilogy by Christopher Nolan"
inputs = tokenizer(text, return_tensors="pt").to(model.device)  

outputs = peft_model.generate(**inputs, max_length=50)
decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(decoded)

Investigating the shapes of LoRA Adapters

We can inspect the adapter shapes to check that they follow the sharding patterns discussed above. To save memory, the block-diagonal matrices are not stored as full, mostly-zero matrices. Instead, only the diagonal blocks are stored, stacked along the non-rank dimension.

For example, if a layer is column-sharded, such as q_proj in Llama, its LoRA-B factor is block-diagonal. If the layer weight has shape (out_features, in_features), then LoRA-A has shape (rank, in_features) and LoRA-B has shape (out_features, rank / TP). This corresponds to TP blocks of shape (out_features / TP, rank / TP) each. We can verify this by inspecting the weight shapes of v_proj, which is also column-sharded:

python
state_dict = peft_model.state_dict()
prefix = "base_model.model.model.layers.0.self_attn.v_proj"
shape_base = list(state_dict[f"{prefix}.base_layer.weight"].shape)
shape_a = list(state_dict[f"{prefix}.lora_A.default.weight"].shape)
shape_b = list(state_dict[f"{prefix}.lora_B.default.weight"].shape)
print(f"Base layer has shape:    {shape_base}")
print(f"LoRA-A (vanilla):        {shape_a}")
print(f"LoRA-B (block-diagonal): {shape_b}")
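To see why the stacked layout carries the same information, here is a NumPy sketch with toy sizes. It assumes the blocks are stacked in order along out_features, which is our reading of the description above; PEFT's actual memory layout may differ. The sketch expands the stacked LoRA-B into its dense block-diagonal form. It then checks that applying the stacked blocks directly gives the same LoRA update as the dense matrix. `to_dense_block_diagonal` is a hypothetical helper, not a PEFT function.

```python
import numpy as np

rng = np.random.default_rng(0)
out_features, in_features, rank, tp = 8, 6, 4, 2
o_blk, r_blk = out_features // tp, rank // tp

lora_A = rng.standard_normal((rank, in_features))            # vanilla: (rank, in_features)
lora_B_stacked = rng.standard_normal((out_features, r_blk))  # stacked: (out_features, rank / TP)


def to_dense_block_diagonal(stacked: np.ndarray, nblocks: int) -> np.ndarray:
    """Expand blocks stacked along out_features into the dense (out_features, rank) LoRA-B."""
    o, rb = stacked.shape[0] // nblocks, stacked.shape[1]
    dense = np.zeros((stacked.shape[0], rb * nblocks))
    for k in range(nblocks):
        dense[k * o:(k + 1) * o, k * rb:(k + 1) * rb] = stacked[k * o:(k + 1) * o]
    return dense


x = rng.standard_normal((3, in_features))
z = x @ lora_A.T                                  # (batch, rank)
lora_B_dense = to_dense_block_diagonal(lora_B_stacked, tp)
delta_dense = z @ lora_B_dense.T                  # LoRA update using the dense matrix

# Same update straight from the stacked storage: block k maps rank block k to output block k
delta_stacked = np.concatenate(
    [z[:, k * r_blk:(k + 1) * r_blk] @ lora_B_stacked[k * o_blk:(k + 1) * o_blk].T for k in range(tp)],
    axis=1,
)
print(lora_B_dense.shape)                        # (8, 4)
print(np.allclose(delta_dense, delta_stacked))   # True
```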

Matching the rank

Suppose we want to match the performance of a LoRA adapter of a given rank. At which rank would we have to train BD-LoRA? We can answer this by matching the number of trainable parameters. A simple iteration over the ranks of the BD-LoRA adapter is sufficient:

python
def rank_to_params(r: int, bd_lora: bool) -> int:
    """Return the number of trainable parameters of a LoRA (or BD-LoRA) adapter of rank r."""
    # get_peft_model modifies the model in place, so load a fresh base model every time
    model = AutoModelForCausalLM.from_pretrained(model_name)
    # If you use a model different from Llama, adjust target_modules and bdlora_config
    if bd_lora:
        # bdlora_config (defined above) uses nblocks blocks per block-diagonal factor
        config = LoraConfig(
            r=r,
            target_modules=target_modules,
            use_bdlora=bdlora_config,
            lora_bias=False
        )
    else:
        config = LoraConfig(
            r=r,
            target_modules=target_modules,
            lora_bias=False
        )
    peft_model = get_peft_model(model, config)
    return peft_model.get_nb_trainable_parameters()[0]

r_orig = 64
lora_nparams = rank_to_params(r_orig, bd_lora=False)

# Increase the BD-LoRA rank in steps of nblocks (so that it splits evenly into blocks)
# until BD-LoRA has at least as many trainable parameters as vanilla LoRA
r = r_orig
bdlora_nparams = 0
while bdlora_nparams < lora_nparams:
    r += nblocks
    bdlora_nparams = rank_to_params(r, bd_lora=True)

# Step back by nblocks to stay just under the vanilla LoRA parameter count, following the paper's methodology
r_match = r - nblocks
print(f"BD-LoRA rank to match vanilla LoRA performance at rank {r_orig}: {r_match} "
      f"at {lora_nparams} vanilla LoRA params and {rank_to_params(r_match, bd_lora=True)} BD-LoRA params.")
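Loading the model once per candidate rank is slow. The same counts can be computed in closed form under two assumptions:

- only the nonzero blocks of a block-diagonal factor are counted, consistent with the stacked storage described earlier;
- the model has Llama-3.2-1B shapes, assumed here rather than read from the model: 16 layers, hidden size 2048, intermediate size 8192, and key/value projections of size 512.

The hypothetical sketch below mirrors the loop above. Bias terms are ignored because lora_bias=False.

```python
# Assumed Llama-3.2-1B projection shapes as (in_features, out_features)
LAYER_SHAPES = {
    "q_proj": (2048, 2048), "k_proj": (2048, 512), "v_proj": (2048, 512), "o_proj": (2048, 2048),
    "gate_proj": (2048, 8192), "up_proj": (2048, 8192), "down_proj": (8192, 2048),
}
NUM_LAYERS = 16
BD_A = {"o_proj", "down_proj"}                                # block-diagonal LoRA-A
BD_B = {"q_proj", "k_proj", "v_proj", "up_proj", "gate_proj"}  # block-diagonal LoRA-B


def count_params(r: int, nblocks: int = 1) -> int:
    """Trainable LoRA parameters over all layers; nblocks=1 gives vanilla LoRA."""
    total = 0
    for name, (d_in, d_out) in LAYER_SHAPES.items():
        a, b = r * d_in, d_out * r
        if name in BD_A:
            a //= nblocks
        if name in BD_B:
            b //= nblocks
        total += a + b
    return total * NUM_LAYERS


def match_rank(r_orig: int, nblocks: int) -> int:
    """Step the BD-LoRA rank up by nblocks while the next rank stays below vanilla LoRA at r_orig."""
    target = count_params(r_orig)
    r = r_orig
    while count_params(r + nblocks, nblocks) < target:
        r += nblocks
    return r


print(match_rank(64, nblocks=2))  # 96
```

With these assumed shapes the result is 96, the rank used in the configuration earlier in this notebook. For a different model, the PEFT-based loop above counts the parameters of the actual modules.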

Integration with vLLM

At the time of writing, BD-LoRA support in vLLM is available through an experimental pull request. Clone the vLLM GitHub repository and check out the commit of the pull request at https://github.com/vllm-project/vllm/pull/28136#. Then install vLLM following the usual instructions: https://docs.vllm.ai/en/stable/getting_started/installation/. We assume that your hardware setup has at least 2 available GPUs.

We have included a script, vllm_server.bash, that starts a vLLM server with two BD-LoRA modules. You may need to shut down this Jupyter server first, as it is likely still using your GPU resources. Once the server has started, you can query it via python3 chat.py "Please write your message here." --target lora1.