examples/cpt_finetuning/cpt_train_and_inference.ipynb
This notebook demonstrates how to train and evaluate Context-Aware Prompt Tuning (CPT) with the Hugging Face `Trainer`. For more details, refer to the CPT paper.
!pip install datasets
!pip install git+https://github.com/huggingface/peft
from typing import Any, Dict, List, Union
import numpy as np
import torch
from datasets import load_dataset
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
DataCollatorForLanguageModeling,
Trainer,
TrainingArguments,
)
from peft import CPTConfig, TaskType, get_peft_model
MAX_INPUT_LENGTH = 1024  # Maximum tokenized sequence length
MAX_ICL_SAMPLES = 10  # Number of in-context (context) demonstrations
NUM_TRAINING_SAMPLES = 100  # Number of training examples
model_id = 'bigscience/bloom-1b7'
# Initialize the tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    model_id, # The name or path of the pre-trained tokenizer (here, the BLOOM model ID).
cache_dir='.', # Directory to cache the tokenizer files locally.
padding_side='right', # Specifies that padding should be added to the right side of sequences.
trust_remote_code=True # Allows loading tokenizer implementations from external sources.
)
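As a quick sanity check, you can encode and decode a short string; `add_special_tokens=False` is used throughout this notebook so that template, input, and label pieces can be concatenated freely (illustrative snippet):
# Illustrative check of tokenizer round-tripping without special tokens
sample_ids = tokenizer("input: hello", add_special_tokens=False)["input_ids"]
print(sample_ids)
print(tokenizer.decode(sample_ids))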
# Load the SST-2 dataset from the GLUE benchmark
dataset = load_dataset('glue', 'sst2')
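SST-2 ships with train/validation/test splits; printing the `DatasetDict` shows the split sizes and available columns:
# Inspect the available splits and columns
print(dataset)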
def add_string_labels(example):
"""
Converts numerical labels into human-readable string labels.
Args:
example (dict): A single example from the dataset with a numerical 'label'.
Returns:
dict: The example augmented with a 'label_text' field.
"""
# Map numerical label to string label
example['label_text'] = "positive" if example['label'] == 1 else "negative"
return example
# Subset and process the training dataset
context_dataset = dataset['train'].select(range(MAX_ICL_SAMPLES)).map(add_string_labels)
train_dataset = dataset['train'].select(range(MAX_ICL_SAMPLES, NUM_TRAINING_SAMPLES + MAX_ICL_SAMPLES)).map(add_string_labels)
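To verify the split and the new field, you can peek at one processed training example (output values depend on dataset ordering):
# Inspect one processed training example
print(train_dataset[0]['sentence'], '->', train_dataset[0]['label_text'])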
Note: This notebook uses small subsets of the dataset to ensure quick execution. For a thorough evaluation, use the entire validation split by setting `quick_review` to `False`.
quick_review = True # set to False for a comprehensive evaluation
num_of_test_examples = 100 if quick_review else len(dataset['validation'])
# Subset and process the validation dataset
test_dataset = dataset['validation'].select(range(num_of_test_examples)).map(add_string_labels)
class CPTDataset(Dataset):
def __init__(self, samples, tokenizer, template, max_length=MAX_INPUT_LENGTH):
"""
Initialize the CPTDataset with samples, a tokenizer, and a template.
Args:
samples (list): List of samples containing input sentences and labels.
tokenizer: Tokenizer instance for encoding text.
template (dict): Dictionary defining input/output templates and separators.
            max_length (int): Maximum input length (stored for reference; this example does not truncate).
"""
self.template = template
self.tokenizer = tokenizer
self.max_length = max_length
# Storage for tokenized inputs and masks
self.attention_mask = []
self.input_ids = []
self.input_type_mask = []
        self.inter_separator_ids = self._get_input_ids(template['inter_separator'])
# Tokenize each sample and prepare inputs
for sample_i in tqdm(samples):
input_text, label = sample_i['sentence'], sample_i['label_text']
input_ids, attention_mask, input_type_mask = self.preprocess_sentence(input_text, label)
self.input_ids.append(input_ids)
self.attention_mask.append(attention_mask)
self.input_type_mask.append(input_type_mask)
def _get_input_ids(self, text):
"""
Tokenize the given text into input IDs.
Args:
text (str): The text to tokenize.
Returns:
list: Tokenized input IDs.
"""
return self.tokenizer(text, add_special_tokens=False)["input_ids"]
def preprocess_sentence(self, input_text, label):
"""
Preprocess a sentence and its corresponding label using templates.
Args:
input_text (str): The input sentence.
label (str): The label text (e.g., "positive", "negative").
Returns:
tuple: (input_ids, attention_mask, input_type_mask)
"""
# Split input template into parts
input_template_part_1_text, input_template_part_2_text = self.template['input'].split('{}')
input_template_tokenized_part1 = self._get_input_ids(input_template_part_1_text)
input_tokenized = self._get_input_ids(input_text)
input_template_tokenized_part2 = self._get_input_ids(input_template_part_2_text)
# Separator token
        sep_tokenized = self._get_input_ids(self.template['intra_separator'])
# Process the label using the template
label_template_part_1, label_template_part_2 = self.template['output'].split('{}')
label_template_part1_tokenized = self._get_input_ids(label_template_part_1)
label_tokenized = self._get_input_ids(label)
label_template_part2_tokenized = self._get_input_ids(label_template_part_2)
# End-of-sequence token
eos = [self.tokenizer.eos_token_id] if self.tokenizer.eos_token_id is not None else []
# Concatenate all tokenized parts
input_ids = input_template_tokenized_part1 + input_tokenized + input_template_tokenized_part2 + sep_tokenized + label_template_part1_tokenized + label_tokenized + label_template_part2_tokenized + eos
# Generate attention and type masks
attention_mask = [1] * len(input_ids)
        input_type_mask = (
            [1] * len(input_template_tokenized_part1)
            + [2] * len(input_tokenized)
            + [1] * len(input_template_tokenized_part2)
            + [0] * len(sep_tokenized)
            + [3] * len(label_template_part1_tokenized)
            + [4] * len(label_tokenized)
            + [3] * len(label_template_part2_tokenized)
            + [0] * len(eos)
        )
# Ensure all masks and inputs are the same length
assert len(input_type_mask) == len(input_ids) == len(attention_mask)
return input_ids, attention_mask, input_type_mask
def __len__(self):
"""
Get the number of examples in the dataset.
Returns:
int: Number of examples.
"""
return len(self.input_ids)
def __getitem__(self, idx):
"""
Get the tokenized representation for the given index.
Args:
idx (int): Index of the example.
Returns:
dict: Tokenized inputs with attention and type masks.
"""
return {
"input_ids": self.input_ids[idx],
"attention_mask": self.attention_mask[idx],
"input_type_mask": self.input_type_mask[idx]
}
# Define templates for tokenization
templates = {
    'input': 'input: {}',  # Input template with placeholder
    'intra_separator': ' ',  # Separator between input and output
    'output': 'output: {}',  # Output template with placeholder
    'inter_separator': '\n'  # Separator between examples
}
# Initialize the dataset
cpt_train_dataset = CPTDataset(train_dataset, tokenizer, templates)
# - `templates`: Define how inputs and outputs should be formatted and separated.
# - `CPTDataset`: Converts the raw dataset into tokenized input IDs and masks.
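Each token in an example is tagged by `input_type_mask`: 1 for input-template tokens, 2 for the input sentence, 3 for output-template tokens, 4 for the label, and 0 for separators/EOS. A minimal sketch to visualize this for the first training example:
# Visualize the role (type) assigned to each token of the first example
first_example = cpt_train_dataset[0]
for tok_id, tok_type in zip(first_example['input_ids'], first_example['input_type_mask']):
    print(f"type={tok_type}\ttoken={tokenizer.decode([tok_id])!r}")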
# Initialize storage for context-level information
context_ids = [] # Concatenated input IDs for all samples
context_attention_mask = [] # Concatenated attention masks
context_input_type_mask = [] # Concatenated input type masks
first_type_mask = 0 # Initial offset for input type mask
cpt_context_dataset = CPTDataset(context_dataset, tokenizer, templates)
# Iterate through the CPT context dataset
for i in range(len(context_dataset)):
# Add input IDs to the context
context_ids += cpt_context_dataset[i]['input_ids']
# Add attention mask to the context
context_attention_mask += cpt_context_dataset[i]['attention_mask']
    # Shift this sample's non-zero type indices so each example gets unique type IDs
    context_input_type_mask += [
        mask_value + first_type_mask if mask_value > 0 else 0
        for mask_value in cpt_context_dataset[i]['input_type_mask']
    ]
# Increment the type mask offset after processing the sample
first_type_mask += 4
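Decoding the assembled context should show the in-context demonstrations concatenated back to back (sanity check):
# Sanity check: the context is the concatenation of all ICL demonstrations
print(f"Context length: {len(context_ids)} tokens")
print(tokenizer.decode(context_ids))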
# Load a pre-trained causal language model
base_model = AutoModelForCausalLM.from_pretrained(
model_id,
cache_dir='.',
dtype=torch.float16,
device_map='auto'
)
# Initialize the CPT configuration
config = CPTConfig(
task_type=TaskType.CAUSAL_LM,
cpt_token_ids=context_ids,
cpt_mask=context_attention_mask,
cpt_tokens_type_mask=context_input_type_mask,
    opt_weighted_loss_type='decay',  # Weight the loss with an exponential decay across context examples
    opt_loss_decay_factor=0.95,  # Exponential decay factor applied to the loss
    opt_projection_epsilon=0.2,  # Epsilon bounding the projection of the input tokens
    opt_projection_format_epsilon=0.1,  # Epsilon bounding the projection of the input/output template tokens
tokenizer_name_or_path=model_id,
)
# Initialize the CPT model with PEFT
model = get_peft_model(base_model, config)
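CPT keeps the base model frozen and optimizes only the context-token embeddings, so the trainable-parameter count should be a tiny fraction of the total:
# Verify that only the CPT context embeddings are trainable
model.print_trainable_parameters()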
class CPTDataCollatorForLanguageModeling(DataCollatorForLanguageModeling):
def __init__(self, tokenizer, training=True, mlm=False):
"""
Custom collator for CPT-style language modeling.
Args:
tokenizer: The tokenizer to handle tokenization and special tokens.
training (bool): If True, operates in training mode; otherwise in evaluation mode.
mlm (bool): If True, enables masked language modeling.
"""
super().__init__(tokenizer, mlm=mlm) # Initialize the parent class
self.training = training
        # Add a special padding token only if the tokenizer does not define one
        if self.tokenizer.pad_token is None:
            self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
"""
Process a batch of examples for language modeling.
Args:
examples (List): A batch of examples with tokenized inputs and optional sample masks.
Returns:
Dict: A dictionary containing padded and tensor-converted inputs, attention masks,
input type masks, and optional sample masks and labels.
"""
# Initialize a list to collect sample masks if provided
list_sample_mask = []
for i in range(len(examples)):
if "sample_mask" in examples[i].keys():
list_sample_mask.append(examples[i].pop("sample_mask"))
        # Compute the longest sequence in the batch
        max_len = max(len(ex["input_ids"]) for ex in examples)
        # Define a helper function for padding sequences to the maximum length
def pad_sequence(sequence, max_len, pad_value=0):
return sequence + [pad_value] * (max_len - len(sequence))
# Pad and convert `input_ids`, `attention_mask`, and `input_type_mask` to tensors
input_ids = torch.tensor([pad_sequence(ex["input_ids"], max_len) for ex in examples])
attention_mask = torch.tensor([pad_sequence(ex["attention_mask"], max_len) for ex in examples])
input_type_mask = torch.tensor([pad_sequence(ex["input_type_mask"], max_len) for ex in examples])
# Create the initial batch dictionary
batch = {"input_ids": input_ids, "attention_mask": attention_mask, "input_type_mask": input_type_mask}
# Create a tensor to store sample masks
tensor_sample_mask = batch["input_ids"].clone().long()
tensor_sample_mask[:, :] = 0 # Initialize with zeros
# Populate the tensor with the provided sample masks
for i in range(len(list_sample_mask)):
tensor_sample_mask[i, : len(list_sample_mask[i])] = list_sample_mask[i]
# Copy `input_ids` to use as `labels`
batch["labels"] = batch["input_ids"].clone()
# If in evaluation mode, include the `sample_mask` in the batch
if not self.training:
batch["sample_mask"] = tensor_sample_mask
return batch
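A quick way to check the collator is to batch two examples of different lengths and confirm everything is padded to a common shape (illustrative snippet):
# Illustrative check: collate two examples and inspect the padded shapes
_collator = CPTDataCollatorForLanguageModeling(tokenizer, training=True, mlm=False)
_batch = _collator([cpt_train_dataset[0], cpt_train_dataset[1]])
print({k: v.shape for k, v in _batch.items()})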
training_args = TrainingArguments(
output_dir='../.',
use_cpu=False,
auto_find_batch_size=False,
learning_rate=1e-4,
logging_steps=100,
per_device_train_batch_size=1,
save_total_limit=1,
remove_unused_columns=False,
num_train_epochs=5,
fp16=True,
save_strategy='no',
report_to="none"
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=cpt_train_dataset, # Custom CPT training dataset.
data_collator=CPTDataCollatorForLanguageModeling(tokenizer, training=True, mlm=False)
)
trainer.train()
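Optionally, the trained CPT adapter can be saved for later reuse (the directory name below is illustrative):
# Optional: save the trained adapter weights and config
model.save_pretrained('./cpt_sst2_adapter')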
model.eval()
# Select relevant columns from the test dataset
test_dataset = test_dataset.select_columns(['sentence', 'label_text'])
# Convert the test dataset to a CPT-compatible format
cpt_test_dataset = CPTDataset(test_dataset, tokenizer, templates)
# Get the device where the model is loaded (CPU, GPU or XPU)
device = model.device
list_bool_predictions = []
for i in range(len(test_dataset)):
    input_ids, input_type_mask = cpt_test_dataset[i]['input_ids'], cpt_test_dataset[i]['input_type_mask']
    input_ids_tensor = torch.tensor(input_ids, dtype=torch.long, device=device).view(1, -1)
    input_type_mask_tensor = torch.tensor(input_type_mask, dtype=torch.long, device=device).view(1, -1)

    # Pass the inputs through the model without tracking gradients
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids_tensor,
            labels=input_ids_tensor,
            input_type_mask=input_type_mask_tensor,
        )
    # Shift logits to exclude the last token and keep only positions aligned with the input
    shifted_logits = outputs.logits[..., :-1, :].contiguous().to(model.dtype)[0, -len(input_ids) + 1:]
    shift_labels = input_ids_tensor[0, 1:].contiguous()
    shifted_input_type_mask = input_type_mask_tensor[..., 1:].contiguous()

    # Create a mask for the type `4` tokens (label tokens)
    mask = shifted_input_type_mask.view(-1) == 4

    # Extract logits and labels corresponding to the mask
    logit = shifted_logits[mask]
    label = shift_labels[mask]
    # Token IDs of the candidate labels `negative` and `positive`
    all_labels = torch.tensor(
        [tokenizer(word, add_special_tokens=False)["input_ids"] for word in ['negative', 'positive']],
        dtype=torch.long, device=device,
    ).view(-1)

    # Compare the logits at the label position and infer the prediction
    prediction = logit[0, all_labels].argmax()
    prediction_text = 'negative' if prediction == 0 else 'positive'
print(f"Sentence: {tokenizer.decode(input_ids)} \n \t The prediction is: {prediction_text}\n \t The GT is {tokenizer.decode(label)}")
list_bool_predictions.append(prediction_text == tokenizer.decode(label))
print(f'The model accuracy is {100 * np.mean(list_bool_predictions):.1f}%')
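If you saved the adapter earlier, it can be restored onto a freshly loaded base model with the standard PEFT loading API (a minimal sketch; the adapter path is illustrative):
# Reload the trained CPT adapter onto a fresh base model
from peft import PeftModel
reloaded_base = AutoModelForCausalLM.from_pretrained(model_id, cache_dir='.', device_map='auto')
reloaded_model = PeftModel.from_pretrained(reloaded_base, './cpt_sst2_adapter')
reloaded_model.eval()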