Back to Peft

Peft Poly Seq2seq With Generate

examples/poly/peft_poly_seq2seq_with_generate.ipynb

0.19.17.4 KB
Original Source
python
%env CUDA_VISIBLE_DEVICES=0  # force using CUDA GPU device 0
%env ZE_AFFINITY_MASK=0  # force using Intel XPU device 0
%env TOKENIZERS_PARALLELISM=false

Initialize PolyModel

python
import torch
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    default_data_collator,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)
from datasets import load_dataset, concatenate_datasets
from peft import PolyConfig, get_peft_model, TaskType, PeftModel, PeftConfig

# Checkpoint used for both the tokenizer and the base seq2seq model.
model_name_or_path = "google/flan-t5-xl"

# Poly settings, in order: LoRA rank, number of tasks, number of skills (LoRAs),
# number of heads.
r, n_tasks, n_skills, n_splits = 8, 4, 2, 4

# Training settings.
batch_size, num_epochs = 8, 8
lr = 5e-5
python
# Tokenizer and base seq2seq model are loaded from the same checkpoint with the
# same loading options.
# NOTE(review): trust_remote_code=True permits running code shipped with the
# checkpoint - confirm it is actually required for this model.
_pretrained_kwargs = {"trust_remote_code": True}
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, **_pretrained_kwargs)
base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, **_pretrained_kwargs)
python
# Collect the Poly settings defined earlier in the notebook, then build the config.
_poly_kwargs = {
    "task_type": TaskType.SEQ_2_SEQ_LM,
    "poly_type": "poly",
    "r": r,
    "n_tasks": n_tasks,
    "n_skills": n_skills,
    "n_splits": n_splits,
}
peft_config = PolyConfig(**_poly_kwargs)

# Wrap the base model with the Poly adapter and report the trainable parameters.
model = get_peft_model(base_model, peft_config)
model.print_trainable_parameters()

Prepare datasets

For this example, we selected four SuperGLUE benchmark datasets: boolq, multirc, rte, and wic. From each of them we sample 1,000 examples for training and 100 examples for evaluation.

python
# boolq
def _format_boolq(example):
    """Turn one BoolQ record into a multiple-choice prompt plus its answer letter."""
    prompt = f"{example['passage']}\nQuestion: {example['question']}\nA. Yes\nB. No\nAnswer:"
    # label 0 (False) -> "B", label 1 (True) -> "A"
    letters = ("B", "A")
    return {
        "input": prompt,
        "output": letters[int(example["label"])],
        "task_name": "boolq",
    }


boolq_dataset = load_dataset("super_glue", "boolq").map(_format_boolq)
boolq_dataset = boolq_dataset.select_columns(["input", "output", "task_name"])
print("boolq example: ")
print(boolq_dataset["train"][0])

# multirc
def _format_multirc(example):
    """Turn one MultiRC record into a yes/no prompt plus its answer letter."""
    prompt = (
        f"{example['paragraph']}\n"
        f"Question: {example['question']}\n"
        f"Answer: {example['answer']}\n"
        "Is it true?\nA. Yes\nB. No\nAnswer:"
    )
    # label 0 (False) -> "B", label 1 (True) -> "A"
    letters = ("B", "A")
    return {
        "input": prompt,
        "output": letters[int(example["label"])],
        "task_name": "multirc",
    }


multirc_dataset = load_dataset("super_glue", "multirc").map(_format_multirc)
multirc_dataset = multirc_dataset.select_columns(["input", "output", "task_name"])
print("multirc example: ")
print(multirc_dataset["train"][0])

# rte
def _format_rte(example):
    """Turn one RTE record into an entailment prompt plus its answer letter."""
    prompt = (
        f"{example['premise']}\n"
        f"{example['hypothesis']}\n"
        "Is the sentence below entailed by the sentence above?\nA. Yes\nB. No\nAnswer:"
    )
    # label 0 (entailment) -> "A", label 1 (not_entailment) -> "B"
    letters = ("A", "B")
    return {
        "input": prompt,
        "output": letters[int(example["label"])],
        "task_name": "rte",
    }


rte_dataset = load_dataset("super_glue", "rte").map(_format_rte)
rte_dataset = rte_dataset.select_columns(["input", "output", "task_name"])
print("rte example: ")
print(rte_dataset["train"][0])

# wic
def _format_wic(example):
    """Turn one WiC record into a same-meaning prompt plus its answer letter."""
    prompt = (
        f"Sentence 1: {example['sentence1']}\n"
        f"Sentence 2: {example['sentence2']}\n"
        f"Are '{example['word']}' in the above two sentences the same?\nA. Yes\nB. No\nAnswer:"
    )
    # label 0 (False) -> "B", label 1 (True) -> "A"
    letters = ("B", "A")
    return {
        "input": prompt,
        "output": letters[int(example["label"])],
        "task_name": "wic",
    }


wic_dataset = load_dataset("super_glue", "wic").map(_format_wic)
wic_dataset = wic_dataset.select_columns(["input", "output", "task_name"])
print("wic example: ")
print(wic_dataset["train"][0])
python
# Map each task name to an integer id; ids follow the order the tasks are listed in.
TASK2ID = {name: idx for idx, name in enumerate(("boolq", "multirc", "rte", "wic"))}


def tokenize(examples):
    """Tokenize a batch of prompt/answer pairs and attach labels and task ids.

    Args:
        examples: A batch with "input", "output" and "task_name" columns.

    Returns:
        The tokenizer output for the prompts, extended with a "labels" entry
        (answer token ids, padding replaced by -100) and a "task_ids" entry
        (one single-element row per example).
    """
    prompts, answers = examples["input"], examples["output"]
    model_inputs = tokenizer(
        prompts,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    answer_ids = tokenizer(
        answers,
        max_length=2,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )["input_ids"]
    # Replace padding positions in the labels with -100.
    answer_ids[answer_ids == tokenizer.pad_token_id] = -100
    model_inputs["labels"] = answer_ids

    task_rows = [[TASK2ID[name]] for name in examples["task_name"]]
    model_inputs["task_ids"] = torch.tensor(task_rows).long()
    return model_inputs
python
def get_superglue_dataset(
    split="train",
    n_samples=500,
):
    """Build a tokenized mix of the four prepared SuperGLUE tasks.

    Args:
        split: Name of the split to draw from in every task dataset.
        n_samples: Number of examples taken from each task after shuffling.

    Returns:
        The concatenation of the four per-task subsets (boolq, multirc, rte,
        wic, in that order), passed through ``tokenize`` with the raw text
        columns removed.
    """
    task_datasets = (boolq_dataset, multirc_dataset, rte_dataset, wic_dataset)
    subsets = [task[split].shuffle().select(range(n_samples)) for task in task_datasets]
    mixed = concatenate_datasets(subsets)
    return mixed.map(
        tokenize,
        batched=True,
        remove_columns=["input", "output", "task_name"],
        load_from_cache_file=False,
    )

As a toy example, we select only 1,000 examples from each sub-dataset for training and 100 examples from each for evaluation.

python
# Evaluate on the labelled "validation" split. The public SuperGLUE "test" splits
# have hidden labels (label == -1), which the label-to-letter lookups in the dataset
# cells would silently map to one constant answer per task, making accuracy meaningless.
superglue_train_dataset = get_superglue_dataset(split="train", n_samples=1000)
superglue_eval_dataset = get_superglue_dataset(split="validation", n_samples=100)

Train and evaluate

python
# training and evaluation
def compute_metrics(eval_preds):
    """Compute exact-match accuracy between decoded predictions and labels.

    Args:
        eval_preds: A ``(predictions, labels)`` pair of token-id sequences.

    Returns:
        A dict with one "accuracy" entry: the fraction of pairs whose decoded,
        whitespace-stripped texts are equal.
    """
    pred_ids, label_ids = eval_preds

    def _decode(sequences):
        # Drop -100 entries before handing the ids to the tokenizer.
        kept = [[tok for tok in seq if tok != -100] for seq in sequences]
        return tokenizer.batch_decode(kept, skip_special_tokens=True)

    pred_texts = _decode(pred_ids)
    label_texts = _decode(label_ids)
    hits = [p.strip() == t.strip() for p, t in zip(pred_texts, label_texts)]
    return {"accuracy": sum(hits) / len(hits)}


# Trainer configuration, grouped by purpose.
training_args = Seq2SeqTrainingArguments(
    output_dir="output",
    # optimisation
    learning_rate=lr,
    num_train_epochs=num_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # evaluation / logging / checkpointing cadence
    eval_strategy="epoch",
    logging_strategy="epoch",
    save_strategy="no",
    report_to=[],
    # evaluate by generating answers
    predict_with_generate=True,
    generation_max_length=2,
    # keep dataset columns that are not removed explicitly (the dataset carries task_ids)
    remove_unused_columns=False,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=superglue_train_dataset,
    eval_dataset=superglue_eval_dataset,
    processing_class=tokenizer,
    data_collator=default_data_collator,
    compute_metrics=compute_metrics,
)
trainer.train()
python
# Save the PEFT model under a path built from the base checkpoint name, the PEFT
# type and the task type.
model_name_or_path = "google/flan-t5-xl"
peft_model_id = "{}_{}_{}".format(model_name_or_path, peft_config.peft_type, peft_config.task_type)
model.save_pretrained(peft_model_id)
python
# List the contents of the directory the model was saved to.
!ls -lh $peft_model_id

Load and infer

python
# Pick the device for inference.
# torch.accelerator.current_accelerator() may return None, and it describes the
# accelerator torch was built for, not one that is present at runtime, so check
# availability first and fall back to CUDA and then to CPU.
if hasattr(torch, "accelerator") and torch.accelerator.is_available():
    device_type = torch.accelerator.current_accelerator().type
elif torch.cuda.is_available():
    # Older torch releases have no torch.accelerator namespace.
    device_type = "cuda"
else:
    device_type = "cpu"
device = f"{device_type}:0" if device_type != "cpu" else "cpu"
python
# Rebuild the save path, then reload the base model and the saved PEFT weights.
peft_model_id = "{}_{}_{}".format(model_name_or_path, peft_config.peft_type, peft_config.task_type)

# The saved config records which base checkpoint to load.
config = PeftConfig.from_pretrained(peft_model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path)
# Attach the saved PEFT weights, move to the chosen device and switch to eval mode.
model = PeftModel.from_pretrained(model, peft_model_id).to(device).eval()
python
# Run the reloaded model on a single RTE validation example.
i = 5
# Fetch the row once instead of reading the whole "input"/"output" columns
# every time a single value is needed.
example = rte_dataset["validation"][i]
inputs = tokenizer(example["input"], return_tensors="pt")
# The model also takes the id of the task this example belongs to.
inputs["task_ids"] = torch.LongTensor([TASK2ID["rte"]])
inputs = {k: v.to(device) for k, v in inputs.items()}
print(example["input"])
print(example["output"])
print(inputs)

# Gradients are not needed for generation.
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=2)
    print(outputs[0])
    print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])