Back to Peft

CPT Training and Inference

examples/cpt_finetuning/cpt_train_and_inference.ipynb

0.19.115.5 KB
Original Source

CPT Training and Inference

This notebook demonstrates the training and evaluation process of Context-Aware Prompt Tuning (CPT) using the Hugging Face Trainer. For more details, refer to the CPT paper.

Sections Overview:

  1. Setup: Import libraries and configure the environment.
  2. Data Preparation: Load and preprocess the dataset.
  3. Model Training: Configure and train the model.
  4. Evaluation: Measure the model's accuracy on the validation set and print per-example predictions.

Setup


Installation

python
# Install the Hugging Face `datasets` library and PEFT from its GitHub
# repository (default branch). The leading `!` runs each line as a shell
# command inside the notebook.
!pip install datasets
!pip install git+https://github.com/huggingface/peft

Imports

python
from typing import Any, Dict, List, Union

import numpy as np
import torch
from datasets import load_dataset
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

from peft import CPTConfig, TaskType, get_peft_model


# Default `max_length` passed to CPTDataset.
MAX_INPUT_LENGTH = 1024
# Number of leading SST-2 training examples concatenated into the CPT context.
MAX_ICL_SAMPLES = 10
# Number of SST-2 training examples (after the context ones) used for fine-tuning.
NUM_TRAINING_SAMPLES = 100
# Hugging Face Hub identifier of the base causal language model.
model_id = "bigscience/bloom-1b7"

Data Preparation


python
# Load the tokenizer that belongs to `model_id`, caching its files in the
# current directory and padding sequences on the right-hand side.
# NOTE(review): trust_remote_code=True allows the Hub repository to execute
# custom tokenizer code; presumably unnecessary for BLOOM -- confirm before
# reusing this cell with untrusted model ids.
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    trust_remote_code=True,
    padding_side='right',
    cache_dir='.',
)
python
# Load the SST-2 sentiment dataset (GLUE benchmark) with the Hugging Face
# `datasets` library; its 'train' and 'validation' splits are used below.
dataset = load_dataset('glue', 'sst2')

def add_string_labels(example):
    """
    Attach a human-readable sentiment name to a dataset example.

    The integer 'label' is mapped to "positive" when it equals 1 and to
    "negative" for any other value. The result is written to the example
    under 'label_text'.

    Args:
        example (dict): A dataset row containing a numerical 'label'.

    Returns:
        dict: The same dict, mutated in place, so it can be used with `Dataset.map`.
    """
    sentiment = "negative"
    if example['label'] == 1:
        sentiment = "positive"
    example['label_text'] = sentiment
    return example

# Split the head of the SST-2 training split into two disjoint slices:
# the first MAX_ICL_SAMPLES rows become the in-context examples, and the
# following NUM_TRAINING_SAMPLES rows are used for fine-tuning. Both slices
# get a string 'label_text' column.
context_dataset = dataset['train'].select(
    range(0, MAX_ICL_SAMPLES)
).map(add_string_labels)
train_dataset = dataset['train'].select(
    range(MAX_ICL_SAMPLES, MAX_ICL_SAMPLES + NUM_TRAINING_SAMPLES)
).map(add_string_labels)

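Note: This notebook uses small subsets of the dataset to ensure quick execution. For a more thorough evaluation, set quick_review to False to evaluate on the entire validation set. The sizes of the context and training subsets are controlled separately by MAX_ICL_SAMPLES and NUM_TRAINING_SAMPLES.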

python
# quick_review=True evaluates on the first 100 validation rows only;
# set it to False to evaluate on the whole validation split.
quick_review = True
if quick_review:
    num_of_test_examples = 100
else:
    num_of_test_examples = len(dataset['validation'])

# Take the chosen prefix of the validation split and add 'label_text'.
test_dataset = (
    dataset['validation']
    .select(range(num_of_test_examples))
    .map(add_string_labels)
)
python
class CPTDataset(Dataset):
    """
    Tokenize sentence/label pairs into the per-example format used by CPT.

    Each sample is rendered as

        <input prefix><sentence><input suffix><intra separator>
        <output prefix><label><output suffix><eos>

    and every token is tagged in `input_type_mask` with the role it plays:

        0 - intra separator or EOS token
        1 - input template text
        2 - input sentence
        3 - output template text
        4 - label text
    """

    def __init__(self, samples, tokenizer, template, max_length=MAX_INPUT_LENGTH):
        """
        Initialize the CPTDataset with samples, a tokenizer, and a template.

        Args:
            samples (iterable): Samples providing 'sentence' and 'label_text' keys.
            tokenizer: Tokenizer instance for encoding text.
            template (dict): Must contain 'input' and 'output' strings with exactly one
                '{}' placeholder each, plus 'intra_seperator' and 'inter_seperator' strings.
            max_length (int): Stored as `self.max_length`. This class does NOT truncate
                sequences; callers that need truncation must apply it themselves.

        Raises:
            ValueError: If the 'input' or 'output' template does not contain exactly one '{}'.
        """
        self.template = template
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Per-example storage, index-aligned across the three lists.
        self.attention_mask = []
        self.input_ids = []
        self.input_type_mask = []
        # Tokenized inter-example separator; exposed as an attribute for callers.
        self.inter_seperator_ids = self._get_input_ids(template['inter_seperator'])

        for sample_i in tqdm(samples):
            input_ids, attention_mask, input_type_mask = self.preprocess_sentence(
                sample_i['sentence'], sample_i['label_text']
            )
            self.input_ids.append(input_ids)
            self.attention_mask.append(attention_mask)
            self.input_type_mask.append(input_type_mask)

    def _get_input_ids(self, text):
        """
        Tokenize the given text into input IDs without adding special tokens.

        Args:
            text (str): The text to tokenize.

        Returns:
            list: Tokenized input IDs.
        """
        return self.tokenizer(text, add_special_tokens=False)["input_ids"]

    def _split_template(self, key):
        """
        Return the text before and after the single '{}' placeholder of `self.template[key]`.

        Raises:
            ValueError: If the template does not contain exactly one '{}'.
        """
        template_text = self.template[key]
        if template_text.count('{}') != 1:
            raise ValueError(
                f"template[{key!r}] must contain exactly one '{{}}' placeholder, got {template_text!r}"
            )
        prefix, suffix = template_text.split('{}')
        return prefix, suffix

    def preprocess_sentence(self, input_text, label):
        """
        Preprocess a sentence and its corresponding label using templates.

        Args:
            input_text (str): The input sentence.
            label (str): The label text (e.g., "positive", "negative").

        Returns:
            tuple: (input_ids, attention_mask, input_type_mask), three lists of equal length.

        Raises:
            ValueError: If the 'input' or 'output' template does not contain exactly one '{}'.
        """
        input_prefix, input_suffix = self._split_template('input')
        output_prefix, output_suffix = self._split_template('output')
        eos = [self.tokenizer.eos_token_id] if self.tokenizer.eos_token_id is not None else []

        # (token ids, type id) pairs in sequence order. Deriving input_ids and
        # input_type_mask from the same list keeps them aligned by construction.
        segments = [
            (self._get_input_ids(input_prefix), 1),
            (self._get_input_ids(input_text), 2),
            (self._get_input_ids(input_suffix), 1),
            (self._get_input_ids(self.template['intra_seperator']), 0),
            (self._get_input_ids(output_prefix), 3),
            (self._get_input_ids(label), 4),
            (self._get_input_ids(output_suffix), 3),
            (eos, 0),
        ]

        input_ids = []
        input_type_mask = []
        for token_ids, type_id in segments:
            input_ids.extend(token_ids)
            input_type_mask.extend([type_id] * len(token_ids))

        # Every produced token is real (no padding at this stage).
        attention_mask = [1] * len(input_ids)

        return input_ids, attention_mask, input_type_mask

    def __len__(self):
        """
        Get the number of examples in the dataset.

        Returns:
            int: Number of examples.
        """
        return len(self.input_ids)

    def __getitem__(self, idx):
        """
        Get the tokenized representation for the given index.

        Args:
            idx (int): Index of the example.

        Returns:
            dict: 'input_ids', 'attention_mask' and 'input_type_mask' lists for the example.
        """
        return {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
            "input_type_mask": self.input_type_mask[idx]
        }

# Formatting rules shared by every CPTDataset in this notebook. An example
# is rendered as "input: <sentence>" + " " + "output: <label>".
# NOTE(review): 'inter_seperator' is tokenized by CPTDataset (stored as
# inter_seperator_ids) but never inserted between examples here; the context
# examples are concatenated back-to-back.
templates = dict(
    input='input: {}',
    intra_seperator=' ',
    output='output: {}',
    inter_seperator='\n',
)

# Tokenize the fine-tuning examples into input ids, attention masks and
# per-token type masks using the templates above.
cpt_train_dataset = CPTDataset(train_dataset, tokenizer, templates)
python
# Build the CPT context: the tokenized in-context examples concatenated
# back-to-back into a single sequence, with matching attention and type masks.
context_ids = []                # Concatenated input IDs for all context examples
context_attention_mask = []     # Concatenated attention masks
context_input_type_mask = []    # Concatenated, re-numbered input type masks
first_type_mask = 0             # Offset added to the non-zero type ids of the current example

cpt_context_dataset = CPTDataset(context_dataset, tokenizer, templates)

# Iterate through the CPT *context* dataset (not the training dataset).
for example_idx in range(len(cpt_context_dataset)):
    example = cpt_context_dataset[example_idx]

    context_ids += example['input_ids']
    context_attention_mask += example['attention_mask']

    # Shift the non-zero type ids so that every context example gets its own
    # range; type 0 (separators / EOS) stays 0.
    context_input_type_mask += [
        type_id + first_type_mask if type_id > 0 else 0
        for type_id in example['input_type_mask']
    ]

    # CPTDataset.preprocess_sentence uses the four non-zero type ids 1-4 per example.
    first_type_mask += 4

Model Training


Load model

python
# Load the base causal LM in float16; device_map='auto' lets the library
# choose the device placement. Files are cached in the current directory.
base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map='auto',
    dtype=torch.float16,
    cache_dir='.',
)

# CPT configuration built from the context sequence assembled above.
config = CPTConfig(
    task_type=TaskType.CAUSAL_LM,
    tokenizer_name_or_path=model_id,

    # Context: token ids, attention mask and per-token type ids.
    cpt_token_ids=context_ids,
    cpt_mask=context_attention_mask,
    cpt_tokens_type_mask=context_input_type_mask,

    # Loss weighting: exponential decay with factor 0.95.
    opt_weighted_loss_type='decay',
    opt_loss_decay_factor=0.95,

    # Projection epsilons: 0.2 over the input tokens, 0.1 over the
    # input/output template tokens.
    opt_projection_epsilon=0.2,
    opt_projection_format_epsilon=0.1,
)

# Wrap the base model with the CPT adapter described by `config`.
model = get_peft_model(base_model, config)

Setting Collate Function

python
class CPTDataCollatorForLanguageModeling(DataCollatorForLanguageModeling):
    """
    Collate CPTDataset items into right-padded tensors for causal LM training.

    Besides 'input_ids' and 'attention_mask', the batch also contains the padded
    'input_type_mask', 'labels' (a copy of 'input_ids') and, in evaluation mode,
    a 'sample_mask' tensor.
    """

    def __init__(self, tokenizer, training=True, mlm=False):
        """
        Custom collator for CPT-style language modeling.

        Args:
            tokenizer: The tokenizer to handle tokenization and special tokens.
            training (bool): If True, operates in training mode; otherwise in evaluation mode.
            mlm (bool): If True, enables masked language modeling.
        """

        super().__init__(tokenizer, mlm=mlm)  # Initialize the parent class
        self.training = training

        # Register a padding token only if the tokenizer has none. Calling
        # add_special_tokens unconditionally would replace an existing pad
        # token with a brand-new "[PAD]" entry and grow the vocabulary, while
        # the model's embedding matrix is never resized in this notebook.
        if self.tokenizer.pad_token is None:
            self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})

    def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        """
        Process a batch of examples for language modeling.

        Args:
            examples (List): A batch of dict examples with tokenized inputs and an optional
                'sample_mask'. Any 'sample_mask' key is removed from its example in place.

        Returns:
            Dict: Padded tensors 'input_ids', 'attention_mask', 'input_type_mask', 'labels',
                  plus 'sample_mask' when not in training mode.
        """

        # Pull out optional per-example sample masks, remembering the row each
        # belongs to so they stay aligned even if only some examples carry one.
        sample_masks = {}
        for row, example in enumerate(examples):
            if "sample_mask" in example:
                sample_masks[row] = example.pop("sample_mask")

        # Longest sequence in the batch; everything is right-padded to this length.
        max_len = max(len(ex["input_ids"]) for ex in examples)

        # Right-pad a list to `max_len`. The pad value 0 is used for every field;
        # padded positions have attention_mask 0 and input_type_mask 0.
        def pad_sequence(sequence, max_len, pad_value=0):
            return sequence + [pad_value] * (max_len - len(sequence))

        input_ids = torch.tensor([pad_sequence(ex["input_ids"], max_len) for ex in examples])
        attention_mask = torch.tensor([pad_sequence(ex["attention_mask"], max_len) for ex in examples])
        input_type_mask = torch.tensor([pad_sequence(ex["input_type_mask"], max_len) for ex in examples])

        batch = {"input_ids": input_ids, "attention_mask": attention_mask, "input_type_mask": input_type_mask}

        # Labels are a copy of the inputs (causal LM objective).
        batch["labels"] = batch["input_ids"].clone()

        # Only evaluation mode returns the sample mask, so only build it then.
        if not self.training:
            tensor_sample_mask = torch.zeros_like(batch["input_ids"], dtype=torch.long)
            for row, sample_mask in sample_masks.items():
                tensor_sample_mask[row, : len(sample_mask)] = sample_mask
            batch["sample_mask"] = tensor_sample_mask

        return batch

Training

python
# Fine-tuning hyper-parameters: 5 epochs, learning rate 1e-4, batch size 1,
# fp16 mixed precision. Checkpoint saving is disabled and no experiment
# tracker is reported to.
training_args = TrainingArguments(
    output_dir='../.',  # NOTE(review): this is the parent directory -- confirm intended
    report_to="none",
    save_strategy='no',
    save_total_limit=1,
    num_train_epochs=5,
    learning_rate=1e-4,
    per_device_train_batch_size=1,
    auto_find_batch_size=False,
    fp16=True,
    use_cpu=False,
    logging_steps=100,
    # Presumably needed so extra keys such as input_type_mask are not dropped -- confirm.
    remove_unused_columns=False,
)

# Train the CPT-wrapped model on the tokenized fine-tuning examples, using the
# custom collator in training mode (no sample_mask in the batches).
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=CPTDataCollatorForLanguageModeling(tokenizer, training=True, mlm=False),
    train_dataset=cpt_train_dataset,
)

trainer.train()

Model Evaluation


python
model.eval()

# Select relevant columns from the test dataset
test_dataset = test_dataset.select_columns(['sentence', 'label_text'])

# Convert the test dataset to a CPT-compatible format
cpt_test_dataset = CPTDataset(test_dataset, tokenizer, templates)

# Get the device where the model is loaded (CPU, GPU or XPU)
device = model.device
list_bool_predictions = []

# Token ids of the candidate label words, in the order ['negative', 'positive'],
# so argmax index 0 means 'negative'. Computed once instead of per example.
# NOTE: building the tensor requires both words to tokenize to the same number
# of tokens; the 0/1 index mapping assumes one token each.
all_labels = torch.tensor(
    [tokenizer(word, add_special_tokens=False)["input_ids"] for word in ['negative', 'positive']],
    dtype=torch.long,
    device=device,
).view(-1)

# Inference only: no_grad avoids building an autograd graph for every example.
with torch.no_grad():
    for i in range(len(test_dataset)):
        input_ids, input_type_mask = cpt_test_dataset[i]['input_ids'], cpt_test_dataset[i]['input_type_mask']

        # Build integer tensors directly (no float round-trip) with a batch dim of 1.
        input_ids_tensor = torch.tensor(input_ids, dtype=torch.long, device=device).view(1, -1)
        input_type_mask_tensor = torch.tensor(input_type_mask, dtype=torch.long, device=device).view(1, -1)

        # Pass the inputs through the model (labels get their own copy, as before).
        outputs = model(
            input_ids=input_ids_tensor,
            labels=input_ids_tensor.clone(),
            input_type_mask=input_type_mask_tensor,
        )

        # Shift logits to exclude the last position and keep only the final
        # len(input_ids) - 1 positions (presumably the logits also cover the
        # CPT context placed before the example -- confirm).
        shifted_logits = outputs.logits[..., :-1, :].contiguous().to(model.dtype)[0, -len(input_ids) + 1:]
        shift_labels = input_ids_tensor[0, 1:]
        shifted_input_type_mask = input_type_mask_tensor[..., 1:]

        # Positions whose next token is a label token (type 4).
        mask = shifted_input_type_mask.view(-1) == 4

        # Extract logits and labels corresponding to the mask
        logit = shifted_logits[mask]
        label = shift_labels[mask]

        # Score the first label position against the candidate words' token ids.
        prediction = logit[0, all_labels].argmax()
        prediction_text = 'negative' if prediction == 0 else 'positive'
        print(f"Sentence: {tokenizer.decode(input_ids)} \n \t The prediction is: {prediction_text}\n \t The GT is {tokenizer.decode(label)}")
        list_bool_predictions.append(prediction_text == tokenizer.decode(label))

print(f'The model Acc is {100 * np.mean(list_bool_predictions)}%')