Back to Sentence Transformers

Simple image training using CLIP model

examples/sentence_transformer/training/clip/train_clip.ipynb

5.6.0 · 5.0 KB
Original Source

Simple image training using CLIP model

This notebook fine-tunes a CLIP model on sentence-transformers/unsplash-lite: ~25k Unsplash photos, each paired with descriptive keywords. We join each photo's keywords into a caption, treat every (image, caption) pair as a positive pair, and train with MultipleNegativesRankingLoss, which uses the other captions in each batch as negatives. This is the standard contrastive objective for CLIP-style models. We hold out 100 pairs to track text-to-image retrieval quality during training.

python
from datasets import load_dataset

from sentence_transformers import SentenceTransformer
from sentence_transformers.sentence_transformer.evaluation import InformationRetrievalEvaluator
from sentence_transformers.sentence_transformer.losses import MultipleNegativesRankingLoss
from sentence_transformers.sentence_transformer.trainer import SentenceTransformerTrainer
from sentence_transformers.sentence_transformer.training_args import SentenceTransformerTrainingArguments
python
# Start from the pretrained CLIP ViT-B/32 checkpoint; the same model embeds both images and text
model_id = "sentence-transformers/clip-ViT-B-32"
model = SentenceTransformer(model_id)
python
# Fetch only the "train" split of Unsplash Lite (~25k photos with descriptive keywords),
# then echo the Dataset so the notebook shows its columns and row count
dataset = load_dataset(path="sentence-transformers/unsplash-lite", split="train")
dataset
python
# Join the ";"-separated keywords into a comma-separated caption to use as the text side of each pair
def keywords_to_caption(batch):
    """Turn each row's ";"-separated keyword string into a ", "-separated caption.

    Applied via ``Dataset.map(..., batched=True)``, so ``batch["keywords"]`` is a list of
    keyword strings; the returned ``"caption"`` list has the same length and order.

    Whitespace around each keyword is stripped and empty entries (from leading, trailing
    or doubled ";") are dropped, so "dog; park;;" becomes "dog, park". A missing (None)
    keyword string yields an empty caption instead of raising.
    """
    captions = []
    for keywords in batch["keywords"]:
        parts = (part.strip() for part in (keywords or "").split(";"))
        captions.append(", ".join(part for part in parts if part))
    return {"caption": captions}


# Build the caption column, then keep just the two columns the loss consumes. The loss reads
# columns positionally, so "image" must come before "caption" to become (anchor, positive).
dataset = dataset.map(keywords_to_caption, batched=True, remove_columns=["keywords"]).select_columns(
    ["image", "caption"]
)

# Carve off 100 pairs for evaluation; the fixed seed keeps the split reproducible
dataset = dataset.train_test_split(seed=42, test_size=100)
train_dataset = dataset["train"]
eval_dataset = dataset["test"]
dataset
python
# Show the first training pair: its caption as text, then the image as the cell's output
sample = train_dataset[0]
caption_text, preview_image = sample["caption"], sample["image"]
print(caption_text)
preview_image
python
# Evaluate text-to-image retrieval on the held-out pairs: each caption should retrieve its own image.
# Read each column once instead of iterating rows: iterating rows decodes every column (including
# the image) per row, and doing it for both dicts decoded every held-out image twice.
queries = dict(enumerate(eval_dataset["caption"]))
corpus = dict(enumerate(eval_dataset["image"]))
# Map each query id to the set of relevant corpus ids: row i's caption matches only row i's image
relevant_docs = {idx: {idx} for idx in queries}
dev_evaluator = InformationRetrievalEvaluator(
    queries=queries,
    corpus=corpus,
    relevant_docs=relevant_docs,
    name="unsplash-dev",
)

# Baseline metrics for the un-finetuned model
dev_evaluator(model)
python
# Contrastive objective: for every image, its own caption is the positive and all other
# captions in the same batch act as negatives
train_loss = MultipleNegativesRankingLoss(model=model)
python
# Fine-tune CLIP on the image/caption pairs. eval_steps/logging_steps below 1 are fractions of the
# total training steps: retrieval is evaluated every 10% of training, losses logged every 5%.
args = SentenceTransformerTrainingArguments(
    output_dir="output/clip-unsplash-lite",  # the only required argument
    # --- optimisation ---
    num_train_epochs=1,
    learning_rate=2e-5,
    # NOTE(review): 0.1 is meant as a warmup fraction; confirm the installed transformers version
    # treats a float warmup_steps as a ratio (older versions expect warmup_ratio=0.1 instead)
    warmup_steps=0.1,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    # --- precision: bf16 by default; without BF16 support use bf16=False and fp16=True ---
    bf16=True,
    fp16=False,
    # --- evaluation / checkpointing / logging cadence ---
    eval_strategy="steps",
    eval_steps=0.1,
    save_strategy="no",
    logging_first_step=True,
    logging_steps=0.05,
    run_name="clip-unsplash-lite",  # W&B run name when `wandb` is installed
)
trainer = SentenceTransformerTrainer(
    model=model,
    loss=train_loss,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    evaluator=dev_evaluator,
)
trainer.train()
python
# Write the fine-tuned model to disk alongside the training outputs
final_model_dir = "output/clip-unsplash-lite/final"
model.save_pretrained(final_model_dir)

# To also publish it on the Hugging Face Hub, run `huggingface-cli login` first and uncomment:
# model.push_to_hub("clip-unsplash-lite")
python
from IPython.display import clear_output, display

from sentence_transformers.util import semantic_search

# Embed the first 1,000 training images with the fine-tuned model to form the search corpus
images = train_dataset[:1000]["image"]
image_embeddings = model.encode(images, batch_size=32, convert_to_tensor=True, show_progress_bar=True)

# Retrieve the top 3 matching images for the typed keywords, until an empty line is entered
while True:
    # Strip surrounding whitespace so a blank or whitespace-only line stops the loop
    # instead of being embedded and searched as a meaningless query
    query = input("Search images by keywords (leave empty to stop): ").strip()
    if not query:
        break
    # Clear the previous results before showing the new ones; this also clears the prompt,
    # so echo the query to keep the results labelled
    clear_output(wait=True)
    print(f"Query: {query}")
    query_embedding = model.encode(query, convert_to_tensor=True)
    # semantic_search returns one ranked hit list per query; we passed a single query
    hits = semantic_search(query_embedding, image_embeddings, top_k=3)[0]
    for hit in hits:
        print(f"Score: {hit['score']:.3f}")
        # Show each match at thumbnail size, preserving aspect ratio; copy first because
        # thumbnail() resizes in place and would shrink the corpus image itself
        thumbnail = images[hit["corpus_id"]].copy()
        thumbnail.thumbnail((256, 256))
        display(thumbnail)