examples/poly/peft_poly_seq2seq_with_generate.ipynb
# force using CUDA GPU device 0 / Intel XPU device 0 (%env treats an inline comment as part of the value, so the comment lives here)
%env CUDA_VISIBLE_DEVICES=0
%env ZE_AFFINITY_MASK=0
%env TOKENIZERS_PARALLELISM=false
import torch
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    default_data_collator,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)
from datasets import load_dataset, concatenate_datasets
from peft import PolyConfig, get_peft_model, TaskType, PeftModel, PeftConfig
model_name_or_path = "google/flan-t5-xl"
r = 8  # rank of each LoRA in Poly
n_tasks = 4  # number of tasks
n_skills = 2  # number of skills (LoRA modules)
n_splits = 4  # number of splits per LoRA (> 1 enables Multi-Head Routing, MHR)
batch_size = 8
lr = 5e-5
num_epochs = 8
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, trust_remote_code=True)
peft_config = PolyConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    poly_type="poly",
    r=r,
    n_tasks=n_tasks,
    n_skills=n_skills,
    n_splits=n_splits,
)
model = get_peft_model(base_model, peft_config)
model.print_trainable_parameters()
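Beyond the parameter count, it can help to see which tensors Poly actually made trainable. The short check below is illustrative and not part of the original notebook; it only uses standard `torch.nn.Module` APIs.
# illustrative: list a few of the trainable (Poly adapter) parameter names
trainable = [name for name, param in model.named_parameters() if param.requires_grad]
print(f"{len(trainable)} trainable tensors, e.g.:")
print("\n".join(trainable[:5]))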
For this example, we select four SuperGLUE benchmark tasks: boolq, multirc, rte, and wic. Each task is reformatted as an A/B multiple-choice prompt; later we subsample 1,000 training examples and 100 evaluation examples per task.
# boolq
boolq_dataset = (
    load_dataset("super_glue", "boolq")
    .map(
        lambda x: {
            "input": f"{x['passage']}\nQuestion: {x['question']}\nA. Yes\nB. No\nAnswer:",
            # 0 - False
            # 1 - True
            "output": ["B", "A"][int(x["label"])],
            "task_name": "boolq",
        }
    )
    .select_columns(["input", "output", "task_name"])
)
print("boolq example: ")
print(boolq_dataset["train"][0])
# multirc
multirc_dataset = (
    load_dataset("super_glue", "multirc")
    .map(
        lambda x: {
            "input": (
                f"{x['paragraph']}\nQuestion: {x['question']}\nAnswer: {x['answer']}\nIs it"
                " true?\nA. Yes\nB. No\nAnswer:"
            ),
            # 0 - False
            # 1 - True
            "output": ["B", "A"][int(x["label"])],
            "task_name": "multirc",
        }
    )
    .select_columns(["input", "output", "task_name"])
)
print("multirc example: ")
print(multirc_dataset["train"][0])
# rte
rte_dataset = (
    load_dataset("super_glue", "rte")
    .map(
        lambda x: {
            "input": (
                f"{x['premise']}\n{x['hypothesis']}\nIs the sentence below entailed by the"
                " sentence above?\nA. Yes\nB. No\nAnswer:"
            ),
            # 0 - entailment
            # 1 - not_entailment
            "output": ["A", "B"][int(x["label"])],
            "task_name": "rte",
        }
    )
    .select_columns(["input", "output", "task_name"])
)
print("rte example: ")
print(rte_dataset["train"][0])
# wic
wic_dataset = (
    load_dataset("super_glue", "wic")
    .map(
        lambda x: {
            "input": (
                f"Sentence 1: {x['sentence1']}\nSentence 2: {x['sentence2']}\nAre '{x['word']}'"
                " in the above two sentences the same?\nA. Yes\nB. No\nAnswer:"
            ),
            # 0 - False
            # 1 - True
            "output": ["B", "A"][int(x["label"])],
            "task_name": "wic",
        }
    )
    .select_columns(["input", "output", "task_name"])
)
print("wic example: ")
print(wic_dataset["train"][0])
# define a task2id map
TASK2ID = {
    "boolq": 0,
    "multirc": 1,
    "rte": 2,
    "wic": 3,
}
def tokenize(examples):
    inputs, targets = examples["input"], examples["output"]
    features = tokenizer(inputs, max_length=512, padding="max_length", truncation=True, return_tensors="pt")
    labels = tokenizer(targets, max_length=2, padding="max_length", truncation=True, return_tensors="pt")
    labels = labels["input_ids"]
    # replace padding token ids in the labels with -100 so they are ignored by the loss
    labels[labels == tokenizer.pad_token_id] = -100
    features["labels"] = labels
    # Poly routes per example, so every row carries its task id
    features["task_ids"] = torch.tensor([[TASK2ID[t]] for t in examples["task_name"]]).long()
    return features
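As a quick, purely illustrative sanity check (not in the original notebook), tokenizing a couple of boolq rows shows the tensor shapes the model will receive, including the per-example task_ids.
# illustrative: inspect the shapes produced by tokenize() on two boolq rows
sample = boolq_dataset["train"][:2]  # dict of lists with "input", "output", "task_name"
batch = tokenize(sample)
print({k: tuple(v.shape) for k, v in batch.items()})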
def get_superglue_dataset(
    split="train",
    n_samples=500,
):
    ds = concatenate_datasets(
        [
            boolq_dataset[split].shuffle().select(range(n_samples)),
            multirc_dataset[split].shuffle().select(range(n_samples)),
            rte_dataset[split].shuffle().select(range(n_samples)),
            wic_dataset[split].shuffle().select(range(n_samples)),
        ]
    )
    ds = ds.map(
        tokenize,
        batched=True,
        remove_columns=["input", "output", "task_name"],
        load_from_cache_file=False,
    )
    return ds
As a toy example, we select only 1,000 examples from each subdataset for training and 100 each for evaluation.
superglue_train_dataset = get_superglue_dataset(split="train", n_samples=1000)
# use the validation split for eval: SuperGLUE test labels are hidden (all -1)
superglue_eval_dataset = get_superglue_dataset(split="validation", n_samples=100)
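A small, illustrative check (not in the original notebook) confirms the mixed datasets have the expected sizes and columns.
print(len(superglue_train_dataset), superglue_train_dataset.column_names)
print(len(superglue_eval_dataset), superglue_eval_dataset.column_names)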
# training and evaluation
def compute_metrics(eval_preds):
    preds, labels = eval_preds
    # drop the -100 padding inserted by the trainer before decoding
    preds = [[i for i in seq if i != -100] for seq in preds]
    labels = [[i for i in seq if i != -100] for seq in labels]
    preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    # exact-match accuracy on the generated answer letter
    correct = 0
    total = 0
    for pred, true in zip(preds, labels):
        if pred.strip() == true.strip():
            correct += 1
        total += 1
    accuracy = correct / total
    return {"accuracy": accuracy}
training_args = Seq2SeqTrainingArguments(
    "output",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    learning_rate=lr,
    num_train_epochs=num_epochs,
    eval_strategy="epoch",
    logging_strategy="epoch",
    save_strategy="no",
    report_to=[],
    predict_with_generate=True,
    generation_max_length=2,
    # keep task_ids in the batches handed to the model
    remove_unused_columns=False,
)
trainer = Seq2SeqTrainer(
    model=model,
    processing_class=tokenizer,
    args=training_args,
    train_dataset=superglue_train_dataset,
    eval_dataset=superglue_eval_dataset,
    data_collator=default_data_collator,
    compute_metrics=compute_metrics,
)
trainer.train()
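If a final aggregate number is wanted after training, the trainer can be evaluated once more on the mixed evaluation set (illustrative, not part of the original notebook).
# illustrative: one last evaluation pass over the mixed SuperGLUE eval set
metrics = trainer.evaluate()
print(metrics)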
# saving model
model_name_or_path = "google/flan-t5-xl"
peft_model_id = f"{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}"
model.save_pretrained(peft_model_id)
!ls -lh $peft_model_id
# pick the best available accelerator (CUDA, XPU, ...) and fall back to CPU otherwise
has_accelerator = hasattr(torch, "accelerator") and torch.accelerator.is_available()
device_type = torch.accelerator.current_accelerator().type if has_accelerator else ("cuda" if torch.cuda.is_available() else "cpu")
device = f"{device_type}:0" if device_type != "cpu" else "cpu"
peft_model_id = f"{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}"
config = PeftConfig.from_pretrained(peft_model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path)
model = PeftModel.from_pretrained(model, peft_model_id)
model = model.to(device)
model = model.eval()
i = 5
inputs = tokenizer(rte_dataset["validation"]["input"][i], return_tensors="pt")
inputs["task_ids"] = torch.LongTensor([TASK2ID["rte"]])
inputs = {k: v.to(device) for k, v in inputs.items()}
print(rte_dataset["validation"]["input"][i])
print(rte_dataset["validation"]["output"][i])
print(inputs)
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=2)
print(outputs[0])
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
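To go a little beyond a single example, the sketch below (not part of the original notebook; the number of rows checked is an arbitrary choice) scores the loaded Poly adapter on a handful of rte validation rows using the same prompt format.
# illustrative only: exact-match accuracy of the loaded adapter on a few rte validation rows
n_check = 20
correct = 0
for j in range(n_check):
    ex = tokenizer(rte_dataset["validation"]["input"][j], return_tensors="pt")
    ex["task_ids"] = torch.LongTensor([TASK2ID["rte"]])
    ex = {k: v.to(device) for k, v in ex.items()}
    with torch.no_grad():
        out = model.generate(**ex, max_new_tokens=2)
    pred = tokenizer.decode(out[0], skip_special_tokens=True).strip()
    correct += int(pred == rte_dataset["validation"]["output"][j].strip())
print(f"accuracy on {n_check} rte validation examples: {correct / n_check:.2f}")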