PEFT with DNA Language Models

examples/dna_language_models/dna_lm.ipynb

This notebook shows how to use parameter-efficient fine-tuning (PEFT) techniques from the PEFT library to fine-tune a DNA language model (DNA-LM). The fine-tuned DNA-LM is then applied to a task from the Nucleotide Transformer benchmark dataset. PEFT techniques make it possible to adapt large pre-trained models to specific tasks with limited computational resources.

1. Import relevant libraries

We'll start by importing the required libraries, including the PEFT library and other dependencies.

python
import torch
import transformers
import peft
import tqdm
import numpy as np

2. Load models

We'll load a pre-trained DNA Language Model, "SpeciesLM", that serves as the base for fine-tuning. This is done using the transformers library from HuggingFace.

The tokenizer and the model come from the paper "Species-aware DNA language models capture regulatory elements and their evolution" (Paper Link, Code Link). The authors introduce a species-aware DNA language model trained on more than 800 species spanning over 500 million years of evolution.

python
from transformers import AutoTokenizer, AutoModelForMaskedLM
python
tokenizer = AutoTokenizer.from_pretrained("gagneurlab/SpeciesLM", revision = "downstream_species_lm")
lm = AutoModelForMaskedLM.from_pretrained("gagneurlab/SpeciesLM", revision = "downstream_species_lm")
python
lm.eval()
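# Use the available accelerator, falling back to CPU when none is found
has_accel = hasattr(torch, "accelerator") and torch.accelerator.is_available()
device = torch.accelerator.current_accelerator().type if has_accel else ("cuda" if torch.cuda.is_available() else "cpu")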
lm.to(device);

3. Prepare datasets

We'll load the nucleotide_transformer_downstream_tasks dataset, which contains 18 downstream tasks from the Nucleotide Transformer paper. This dataset provides a consistent genomics benchmark with binary classification tasks.

python
from datasets import load_dataset

raw_data_full = load_dataset("InstaDeepAI/nucleotide_transformer_downstream_tasks")
raw_data = raw_data_full.filter(lambda example: example['task'] == 'H3')

We'll use the "H3" subset of this dataset, which contains 13,468 rows in the training data and 1,497 rows in the test data.

python
raw_data

The dataset contains the columns sequence, name, and label, along with the task column used for filtering above. A row in this dataset looks like:

python
raw_data['train'][0]

We split our dataset into training, validation, and test sets.

python
from datasets import Dataset, DatasetDict

train_valid_split = raw_data['train'].train_test_split(test_size=0.15, seed=42)

train_valid_split = DatasetDict({
    'train': train_valid_split['train'],
    'validation': train_valid_split['test']
})

ds = DatasetDict({
    'train': train_valid_split['train'],
    'validation': train_valid_split['validation'],
    'test': raw_data['test']
})
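
To get a feel for the resulting split sizes, the arithmetic can be sketched as follows. This assumes the 13,468 training rows quoted above, and that train_test_split rounds the held-out fraction up to a whole row. The exact counts are best confirmed by printing ds.

```python
import math

n_rows = 13_468      # H3 training rows quoted above
test_size = 0.15     # fraction held out for validation

# Assumption: the held-out count is rounded up, the remainder is kept for training
n_validation = math.ceil(test_size * n_rows)
n_train = n_rows - n_validation

print(n_train, n_validation)
```

Note that with dataset_limit set, only the first rows of each split are actually used in the cells that follow.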

Then, we use the tokenizer and a utility function, get_kmers, to generate the final inputs and labels. The get_kmers function generates the overlapping 6-mers that the language model (LM) needs as input. With k=6 and stride=1, consecutive k-mers overlap by five bases, so the model receives continuous local context across the biological sequence.

python
def get_kmers(seq, k=6, stride=1):
    return [seq[i:i + k] for i in range(0, len(seq), stride) if i + k <= len(seq)]
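
As a quick illustration of what get_kmers produces, the sketch below runs it on a short made-up sequence (the 10-base string is purely illustrative and not taken from the dataset). With stride=1, a sequence of length n yields n - k + 1 overlapping k-mers.

```python
def get_kmers(seq, k=6, stride=1):
    # Same helper as above, repeated so this cell runs on its own
    return [seq[i:i + k] for i in range(0, len(seq), stride) if i + k <= len(seq)]

toy_sequence = "ACGTACGTAC"  # hypothetical 10-base example
kmers = get_kmers(toy_sequence)
print(kmers)
# ['ACGTAC', 'CGTACG', 'GTACGT', 'TACGTA', 'ACGTAC']

# With stride=1 there are len(seq) - k + 1 k-mers
assert len(kmers) == len(toy_sequence) - 6 + 1

# The text passed to the tokenizer is a species name followed by the
# space-separated k-mers
model_input = "candida_glabrata " + " ".join(kmers)
print(model_input)
```
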
python
test_sequences = []
train_sequences = []
val_sequences = []

dataset_limit = 200 # NOTE: This dataset limit is set to 200, so that the training runs faster. It can be set to None to use the
                    # entire dataset

for i in range(0, len(ds['train'])):

    if dataset_limit and i == dataset_limit:
        break

    sequence = ds['train'][i]['sequence']
    sequence = "candida_glabrata " + " ".join(get_kmers(sequence))
    sequence = tokenizer(sequence)["input_ids"]
    train_sequences.append(sequence)


for i in range(0, len(ds['validation'])):
    if dataset_limit and i == dataset_limit:
        break
    sequence = ds['validation'][i]['sequence']
    sequence = "candida_glabrata " + " ".join(get_kmers(sequence))
    sequence = tokenizer(sequence)["input_ids"]
    val_sequences.append(sequence)


for i in range(0, len(ds['test'])):
    if dataset_limit and i == dataset_limit:
        break
    sequence = ds['test'][i]['sequence']
    sequence = "candida_glabrata " + " ".join(get_kmers(sequence))
    sequence = tokenizer(sequence)["input_ids"]
    test_sequences.append(sequence)


train_labels = ds['train']['label']
test_labels = ds['test']['label']
val_labels = ds['validation']['label']

if dataset_limit:
    train_labels = train_labels[0:dataset_limit]
    test_labels = test_labels[0:dataset_limit]
    val_labels = val_labels[0:dataset_limit]

Finally, we create a Dataset object for each of our sets.

python
from datasets import Dataset

train_dataset = Dataset.from_dict({"input_ids": train_sequences, "labels": train_labels})
val_dataset = Dataset.from_dict({"input_ids": val_sequences, "labels": val_labels})
test_dataset = Dataset.from_dict({"input_ids": test_sequences, "labels": test_labels})

4. Train model

Now, we'll train our DNA language model on the training dataset. We'll add a linear classification layer on top of the final layer of the language model, and then train all the parameters of the model.

python
from transformers import DataCollatorWithPadding

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
python
import torch
from torch import nn

class DNA_LM(nn.Module):
    def __init__(self, model, num_labels):
        super(DNA_LM, self).__init__()
        self.model = model.bert
        self.in_features = model.config.hidden_size
        self.out_features = num_labels
        self.classifier = nn.Linear(self.in_features, self.out_features)

    def forward(self, input_ids, attention_mask=None, labels=None):
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        sequence_output = outputs.hidden_states[-1]
        # Use the [CLS] token for classification
        cls_output = sequence_output[:, 0, :]
        logits = self.classifier(cls_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.out_features), labels.view(-1))

        return (loss, logits) if loss is not None else logits

# Number of classes for your classification task
num_labels = 2
classification_model = DNA_LM(lm, num_labels)
classification_model.to(device);
python
from transformers import Trainer, TrainingArguments


# Define training arguments
training_args = TrainingArguments(
    output_dir='./results',
    eval_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=5,
    weight_decay=0.01,
    eval_steps=1,
    logging_steps=1,
)

# Initialize Trainer
trainer = Trainer(
    model=classification_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

# Train the model
trainer.train()

5. Evaluation

python
# Generate predictions

predictions = trainer.predict(test_dataset)
logits = predictions.predictions
predicted_labels = logits.argmax(axis=-1)
print(predicted_labels)

Then, we create a function to calculate the accuracy from the test and predicted labels.

python
def calculate_accuracy(true_labels, predicted_labels):

    assert len(true_labels) == len(predicted_labels), "Arrays must have the same length"
    # Convert both inputs to arrays so the comparison is element-wise
    correct_predictions = np.sum(np.asarray(true_labels) == np.asarray(predicted_labels))
    accuracy = correct_predictions / len(true_labels)

    return accuracy

accuracy = calculate_accuracy(test_labels, predicted_labels)
print(f"Accuracy: {accuracy:.2f}")

The results aren't that good, which we can attribute to the small dataset size (dataset_limit restricts each split to 200 examples).

6. Parameter-Efficient Fine-Tuning Techniques

In this section, we demonstrate how to employ parameter-efficient fine-tuning (PEFT) techniques to adapt a pre-trained model for specific genomics tasks using the PEFT library.

The LoraConfig object is instantiated to configure the PEFT parameters:

  • task_type: Specifies the type of task (for example, SEQ_CLS for sequence classification). The configuration below leaves it unset.
  • r: The rank of the LoRA matrices.
  • lora_alpha: Scaling factor for the LoRA update; the low-rank update is multiplied by lora_alpha / r.
  • target_modules: Modules within the model to which the LoRA matrices are applied (query, key, value in this example).
  • lora_dropout: Dropout rate applied to the input of the LoRA layers during fine-tuning.
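
To make r and lora_alpha concrete, here is a minimal NumPy sketch of the LoRA update for a single linear layer. It illustrates the idea rather than reproducing the PEFT library's implementation. The layer width of 768 and the random initialization are assumptions chosen for the example, and dropout is omitted.

```python
import numpy as np

rng = np.random.default_rng(0)

hidden = 768          # assumed layer width, for illustration only
r, lora_alpha = 8, 32
scaling = lora_alpha / r  # 4.0 with the values used in this notebook

W = rng.normal(size=(hidden, hidden))      # frozen pre-trained weight
A = rng.normal(size=(r, hidden)) * 0.01    # trainable, small random init
B = np.zeros((hidden, r))                  # trainable, zero init

x = rng.normal(size=(hidden,))

# Forward pass: frozen path plus the scaled low-rank path
y = W @ x + scaling * (B @ (A @ x))

# Because B starts at zero, the adapted layer initially matches the frozen one
assert np.allclose(y, W @ x)

# Trainable parameters in this one layer versus the full weight matrix
print(A.size + B.size, W.size)  # 12288 589824
```

Only A and B receive gradients during fine-tuning, which is why the trainable parameter count stays small even when several projections in every layer are adapted.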
python
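# Reload the pre-trained weights, since the baseline run above updated lm in place
lm = AutoModelForMaskedLM.from_pretrained("gagneurlab/SpeciesLM", revision = "downstream_species_lm")

# Number of classes for your classification task
num_labels = 2
classification_model = DNA_LM(lm, num_labels)
classification_model.to(device);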
python
from peft import LoraConfig, TaskType

peft_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["query", "key", "value"],
    lora_dropout=0.01,
)
python
from peft import get_peft_model

peft_model = get_peft_model(classification_model, peft_config)
peft_model.print_trainable_parameters()
python
peft_model
python
# Define training arguments
training_args = TrainingArguments(
    output_dir='./results',
    eval_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=5,
    weight_decay=0.01,
    eval_steps=1,
    logging_steps=1,
)

# Initialize Trainer
trainer = Trainer(
    model=peft_model.model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

# Train the model
trainer.train()

7. Evaluate PEFT Model

python
# Generate predictions

predictions = trainer.predict(test_dataset)
logits = predictions.predictions
predicted_labels = logits.argmax(axis=-1)
print(predicted_labels)
python
def calculate_accuracy(true_labels, predicted_labels):

    assert len(true_labels) == len(predicted_labels), "Arrays must have the same length"
    # Convert both inputs to arrays so the comparison is element-wise
    correct_predictions = np.sum(np.asarray(true_labels) == np.asarray(predicted_labels))
    accuracy = correct_predictions / len(true_labels)

    return accuracy

accuracy = calculate_accuracy(test_labels, predicted_labels)
print(f"Accuracy: {accuracy:.2f}")

As we can see, the PEFT model achieves similar performance to the baseline model, demonstrating the effectiveness of PEFT in adapting pre-trained models to specific tasks with limited computational resources.

With PEFT, we train only 442,368 parameters, which is 0.49% of the total parameters in the model. This is a significant reduction in trainable parameters, and therefore in gradient and optimizer memory, compared to fine-tuning the entire model.
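
The 442,368 figure can be reproduced by hand. The sketch below assumes a BERT-style encoder with 12 layers and a hidden size of 768. These two values are not stated in this notebook and are chosen because they are consistent with the reported count, so check lm.config before relying on them.

```python
# Assumed architecture values (not stated above); verify against lm.config
num_layers = 12
hidden_size = 768

r = 8
target_modules = ["query", "key", "value"]

# Each adapted projection gains two matrices: A (r x hidden) and B (hidden x r)
params_per_module = r * hidden_size + hidden_size * r
trainable = num_layers * len(target_modules) * params_per_module

print(f"{trainable:,}")  # 442,368
```

Under these assumptions, doubling r would double the trainable parameter count, which is one of the hyperparameters to weigh against accuracy.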

We can improve the results by using a larger dataset, fine-tuning the model for more epochs or changing the hyperparameters (rank, learning rate, etc.).