examples/sequence_classification/LoRA-torchao-8bit.ipynb
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torchao.quantization import Int8WeightOnlyConfig
from peft import LoraConfig, get_peft_model
import evaluate
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer, TorchAoConfig, get_linear_schedule_with_warmup
from tqdm import tqdm
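# Hyperparameters for LoRA fine-tuning of google/gemma-2-2b on GLUE MRPC with an int8 weight-only quantized base model.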
batch_size = 16
model_name_or_path = "google/gemma-2-2b"
task = "mrpc"
device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
num_epochs = 5
lr = 2e-5
lora_rank = 16
lora_alpha = 32
lora_dropout = 0.1
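# Models whose names contain gpt/opt/bloom are padded on the left; everything else keeps the default right padding.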
if any(k in model_name_or_path for k in ("gpt", "opt", "bloom")):
padding_side = "left"
else:
padding_side = "right"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side=padding_side)
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
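# Load the MRPC task from GLUE together with its evaluation metric (accuracy and F1).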
datasets = load_dataset("glue", task)
metric = evaluate.load("glue", task)
def tokenize_function(examples):
# max_length=None => use the model max length (it's actually the default)
outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None)
return outputs
tokenized_datasets = datasets.map(
tokenize_function,
batched=True,
remove_columns=["idx", "sentence1", "sentence2"],
)
# We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the
# transformers library
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
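# Pad dynamically to the longest sequence in each batch instead of padding every example to a fixed length.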
def collate_fn(examples):
return tokenizer.pad(examples, padding="longest", return_tensors="pt")
# Instantiate dataloaders.
train_dataloader = DataLoader(
tokenized_datasets["train"],
shuffle=True,
collate_fn=collate_fn,
batch_size=batch_size,
)
eval_dataloader = DataLoader(
tokenized_datasets["validation"],
shuffle=False,
collate_fn=collate_fn,
batch_size=batch_size,
)
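# Quantize the base model weights to int8 (weight-only) with torchao; computation still runs in bfloat16.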
quant_config = TorchAoConfig(quant_type=Int8WeightOnlyConfig())
model = AutoModelForSequenceClassification.from_pretrained(
model_name_or_path, return_dict=True, device_map=0, dtype=torch.bfloat16, quantization_config=quant_config
)
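# Add LoRA adapters on the attention query and value projections; the quantized base weights stay frozen.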
peft_config = LoraConfig(
task_type="SEQ_CLS",
r=lora_rank,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
target_modules=["q_proj", "v_proj"],
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
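# model.parameters() still includes the frozen base weights, but only parameters with requires_grad=True
# (the LoRA matrices and the classification head) receive gradients and are updated by AdamW.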
optimizer = AdamW(params=model.parameters(), lr=lr)
# Instantiate scheduler
lr_scheduler = get_linear_schedule_with_warmup(
optimizer=optimizer,
num_warmup_steps=0.06 * (len(train_dataloader) * num_epochs),
num_training_steps=(len(train_dataloader) * num_epochs),
)
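# The KV cache is only useful for generation, so disable it during training.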
model.config.use_cache = False
model.to(device)
%%time
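# Training loop: one forward/backward pass per batch, followed by optimizer and LR scheduler steps;
# evaluate on the validation split at the end of every epoch.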
for epoch in range(1, num_epochs + 1):
model.train()
train_losses = []
for step, batch in enumerate(tqdm(train_dataloader)):
batch.to(device)
outputs = model(**batch)
loss = outputs.loss
if not torch.isfinite(loss):
raise ValueError("non-finite loss encountered")
loss.backward()
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
train_losses.append(loss.item())
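    # Evaluation: predict with argmax over the logits and accumulate batches into the GLUE metric.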
model.eval()
for step, batch in enumerate(tqdm(eval_dataloader)):
batch.to(device)
with torch.no_grad():
outputs = model(**batch)
predictions = outputs.logits.argmax(dim=-1)
references = batch["labels"]
metric.add_batch(
predictions=predictions,
references=references,
)
eval_metric = metric.compute()
train_loss = sum(train_losses) / len(train_losses)
print(f"epoch {epoch} | train loss {train_loss:.4f} |", eval_metric)
# memory: 18098MiB
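# A minimal sketch (not executed in the run above) of saving and reloading the trained adapter;
# the output directory name is illustrative. save_pretrained() on a PEFT model stores only the
# LoRA weights and the classification head, not the quantized base model, so the base is
# re-instantiated with the same torchao quantization config before the adapter is attached.
from peft import PeftModel

adapter_dir = "gemma-2-2b-lora-mrpc"  # hypothetical output path
model.save_pretrained(adapter_dir)
tokenizer.save_pretrained(adapter_dir)

base_model = AutoModelForSequenceClassification.from_pretrained(
    model_name_or_path, return_dict=True, device_map=0, dtype=torch.bfloat16, quantization_config=quant_config
)
reloaded = PeftModel.from_pretrained(base_model, adapter_dir)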