Using PEFT with timm

PEFT allows us to train any model with LoRA as long as its layer types are supported. Since Conv2d is one of the supported layer types, it makes sense to test LoRA on image models.
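As a quick recap of the core idea, here is a minimal, framework-free sketch of a LoRA update with made-up dimensions: instead of updating the frozen weight matrix W directly, LoRA learns two small factors A and B and adds the low-rank product B @ A, scaled by alpha / r, on top of W. For a Conv2d layer the same factorization is applied to the (flattened) kernel.

```python
# Toy sketch of the LoRA idea (pure Python, hypothetical dimensions):
# the frozen weight W is augmented with a low-rank update B @ A.
def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

d_out, d_in, r, alpha = 4, 6, 2, 4
W = [[0.0] * d_in for _ in range(d_out)]  # frozen base weight, d_out x d_in
A = [[0.1] * d_in for _ in range(r)]      # trainable LoRA factor, r x d_in
B = [[0.1] * r for _ in range(d_out)]     # trainable LoRA factor, d_out x r

delta = matmul(B, A)                      # low-rank update, d_out x d_in
scale = alpha / r
W_adapted = [[W[i][j] + scale * delta[i][j] for j in range(d_in)]
             for i in range(d_out)]
print(len(W_adapted), len(W_adapted[0]))  # shape is unchanged: d_out x d_in
```

Only A and B are trained, so the number of trainable parameters is r * (d_in + d_out) instead of d_in * d_out.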

In this short notebook, we will demonstrate this with an image classification task using timm.

Imports

Make sure that you have the latest version of peft installed. To ensure that, run this in your Python environment:

```bash
python -m pip install --upgrade peft
```

Also, ensure that timm is installed:

```bash
python -m pip install --upgrade timm
```
```python
import timm
import torch
from PIL import Image
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

import peft
from datasets import load_dataset

torch.manual_seed(0)
```

Loading the pre-trained base model

We use a small pretrained timm model, PoolFormer. Find more info on its model card.

```python
model_id_timm = "timm/poolformer_m36.sail_in1k"
```

We tell timm that we're dealing with 3 classes, to ensure that the classification layer has the correct size.

```python
model = timm.create_model(model_id_timm, pretrained=True, num_classes=3)
```

These are the transformation steps necessary to process the image:

```python
transform = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))
```
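For reference, resolve_data_config returns a plain dict describing the preprocessing the model expects. An illustrative sketch of what such a config typically contains is shown below; the values are the common ImageNet defaults, not necessarily this model's exact settings, which come from its pretrained_cfg.

```python
# Illustrative example of a resolved timm data config (hypothetical
# values based on standard ImageNet preprocessing):
example_cfg = {
    "input_size": (3, 224, 224),
    "interpolation": "bicubic",
    "mean": (0.485, 0.456, 0.406),
    "std": (0.229, 0.224, 0.225),
    "crop_pct": 0.9,
}
print(sorted(example_cfg))
```

create_transform turns such a dict into a torchvision-style transform pipeline (resize, crop, to-tensor, normalize).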

Data

For this exercise, we use the "beans" dataset. More details on the dataset can be found on its datasets page. For our purposes, what's important is that we have image inputs and the target we're trying to predict is one of three classes for each image.

```python
ds = load_dataset("beans")

ds_train = ds["train"]
ds_valid = ds["validation"]

ds_train[0]["image"]
```

We define a small processing function which is responsible for loading and transforming the images, as well as extracting the labels.

```python
def process(batch):
    x = torch.cat([transform(img).unsqueeze(0) for img in batch["image"]])
    y = torch.tensor(batch["labels"])
    return {"x": x, "y": y}

ds_train.set_transform(process)
ds_valid.set_transform(process)

train_loader = torch.utils.data.DataLoader(ds_train, batch_size=32)
valid_loader = torch.utils.data.DataLoader(ds_valid, batch_size=32)
```

Training

This is just a function that performs the training loop; nothing fancy is happening here.

```python
def train(model, optimizer, criterion, train_dataloader, valid_dataloader, epochs):
    for epoch in range(epochs):
        model.train()
        train_loss = 0
        for batch in train_dataloader:
            xb, yb = batch["x"], batch["y"]
            xb, yb = xb.to(device), yb.to(device)
            outputs = model(xb)
            # CrossEntropyLoss expects raw logits, so we pass the model
            # outputs directly (no explicit log-softmax needed).
            loss = criterion(outputs, yb)
            train_loss += loss.detach().float()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        model.eval()
        valid_loss = 0
        correct = 0
        n_total = 0
        for batch in valid_dataloader:
            xb, yb = batch["x"], batch["y"]
            xb, yb = xb.to(device), yb.to(device)
            with torch.no_grad():
                outputs = model(xb)
            loss = criterion(outputs, yb)
            valid_loss += loss.detach().float()
            correct += (outputs.argmax(-1) == yb).sum().item()
            n_total += len(yb)

        train_loss_total = (train_loss / len(train_dataloader)).item()
        valid_loss_total = (valid_loss / len(valid_dataloader)).item()
        valid_acc_total = correct / n_total
        print(f"{epoch=:<2}  {train_loss_total=:.4f}  {valid_loss_total=:.4f}  {valid_acc_total=:.4f}")
```

Selecting which layers to fine-tune with LoRA

Let's take a look at the layers of our model. We only print the first 30, since there are quite a few:

```python
[(n, type(m)) for n, m in model.named_modules()][:30]
```

Most of these layers are not good targets for LoRA, but we see a couple that should interest us. Their names are 'stages.0.blocks.0.mlp.fc1', etc. With a bit of regex, we can match them easily.
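When target_modules is given as a string, PEFT treats it as a regex matched against the full module name. We can sanity-check our pattern offline against a few hand-picked names, without loading any model:

```python
import re

# The pattern we will pass as target_modules, checked against a few
# example module names taken from the listing above.
pattern = r".*\.mlp\.fc\d"
names = [
    "stages.0.blocks.0.mlp.fc1",
    "stages.0.blocks.0.mlp.fc2",
    "stages.0.blocks.0.mlp.drop",
    "head.fc",
]
matched = [n for n in names if re.fullmatch(pattern, n)]
print(matched)  # only the mlp.fc1 / mlp.fc2 layers match
```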

Also, we should inspect the name of the classification layer, since we want to train that one too!

```python
[(n, type(m)) for n, m in model.named_modules()][-5:]
```

Okay, this gives us all the information we need to fine-tune this model. With a bit of regex, we match the convolutional layers that should be targeted for LoRA. We also want to train the classification layer 'head.fc' (without LoRA), so we add it to the modules_to_save.

```python
config = peft.LoraConfig(r=8, target_modules=r".*\.mlp\.fc\d", modules_to_save=["head.fc"])
```

Finally, let's create the peft model, the optimizer and criterion, and we can get started. As shown below, less than 2% of the model's total parameters are updated thanks to peft.

```python
# Fall back gracefully on older PyTorch versions and CPU-only machines.
if hasattr(torch, "accelerator") and torch.accelerator.is_available():
    device = torch.accelerator.current_accelerator().type
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"

peft_model = peft.get_peft_model(model, config).to(device)
optimizer = torch.optim.Adam(peft_model.parameters(), lr=2e-4)
criterion = torch.nn.CrossEntropyLoss()
peft_model.print_trainable_parameters()
```

```python
%time train(peft_model, optimizer, criterion, train_loader, valid_dataloader=valid_loader, epochs=10)
```

We get an accuracy of ~0.97 despite training only a tiny fraction of the parameters. That's a really nice result.
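As a back-of-the-envelope check of why so few parameters are trainable (using hypothetical layer dimensions, not PoolFormer's actual ones): for a single fc layer, full fine-tuning updates d_in * d_out weights, while LoRA with rank r only updates r * (d_in + d_out).

```python
# Hypothetical fc layer dimensions, for illustration only.
d_in, d_out, r = 768, 3072, 8
full = d_in * d_out            # weights touched by a full fine-tune
lora = r * (d_in + d_out)      # LoRA factors: A (r x d_in) + B (d_out x r)
print(full, lora, f"{100 * lora / full:.2f}%")
```

The ratio shrinks further as the layers get larger, since LoRA's cost grows with d_in + d_out rather than d_in * d_out.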

Sharing the model through Hugging Face Hub

Pushing the model to Hugging Face Hub

If we want to share the fine-tuned weights with the world, we can upload them to Hugging Face Hub like this:

```python
user = "BenjaminB"  # put your user name here
model_name = "peft-lora-with-timm-model"
model_id = f"{user}/{model_name}"

peft_model.push_to_hub(model_id);
```

As we can see, the adapter is only 4.3 MB, while the original model is 225 MB. That's a very big saving.
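A quick sanity check on that claim, using the sizes reported above:

```python
adapter_mb, base_mb = 4.3, 225
ratio = adapter_mb / base_mb
print(f"The adapter is only {100 * ratio:.1f}% of the base model's size.")
```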

Loading the model from HF Hub

Now, it only takes one step to load the model from HF Hub. To do this, we can use PeftModel.from_pretrained, passing our base model and the model ID:

```python
base_model = timm.create_model(model_id_timm, pretrained=True, num_classes=3)
loaded = peft.PeftModel.from_pretrained(base_model, model_id)

x = ds_train[:1]["x"]
y_peft = peft_model(x.to(device))
y_loaded = loaded(x)
torch.allclose(y_peft.cpu(), y_loaded)
```

Clean up

Finally, as a cleanup step, you may want to delete the repo:

```python
from huggingface_hub import delete_repo

delete_repo(model_id)
```