examples/sentence_transformer/training/clip/train_clip.ipynb
This notebook fine-tunes a CLIP model on sentence-transformers/unsplash-lite: ~25k Unsplash photos, each paired with descriptive keywords. We join each photo's keywords into a caption, treat every (image, caption) pair as a positive, and train with MultipleNegativesRankingLoss, which uses the other captions in each batch as negatives for each image. This is the image-to-text direction of the symmetric contrastive objective that CLIP was originally pretrained with. We hold out 100 pairs to track text-to-image retrieval quality during training.
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.sentence_transformer.evaluation import InformationRetrievalEvaluator
from sentence_transformers.sentence_transformer.losses import MultipleNegativesRankingLoss
from sentence_transformers.sentence_transformer.trainer import SentenceTransformerTrainer
from sentence_transformers.sentence_transformer.training_args import SentenceTransformerTrainingArguments
# Start from the pretrained sentence-transformers/clip-ViT-B-32 checkpoint
model = SentenceTransformer("sentence-transformers/clip-ViT-B-32")
# Unsplash Lite, "train" split: ~25k photos, each paired with descriptive keywords
dataset = load_dataset(path="sentence-transformers/unsplash-lite", split="train")
dataset
# Join the ";"-separated keywords into a comma-separated caption to use as the text side of each pair
def keywords_to_caption(batch):
    """Build a ``caption`` column from a batch's ``keywords`` column.

    Meant for ``Dataset.map(..., batched=True)``: ``batch["keywords"]`` holds one
    ";"-separated keyword string per row, and the returned dict holds one caption per
    row in which every ";" has become ", ". Other keys of ``batch`` are ignored.
    """
    separator = ", "
    captions = [separator.join(raw.split(";")) for raw in batch["keywords"]]
    return {"caption": captions}
# Turn keywords into captions, then keep only the (image, caption) columns in that order:
# MultipleNegativesRankingLoss reads columns in order, so (image, caption) becomes (anchor, positive).
# Finally hold out 100 pairs (fixed seed for reproducibility) for evaluation.
dataset = (
    dataset.map(keywords_to_caption, batched=True, remove_columns=["keywords"])
    .select_columns(["image", "caption"])
    .train_test_split(test_size=100, seed=42)
)
train_dataset = dataset["train"]
eval_dataset = dataset["test"]
dataset
# Look at one training pair: print its caption, then show its image as the cell output
sample = train_dataset[0]
preview_caption = sample["caption"]
print(preview_caption)
sample["image"]
# Evaluate text-to-image retrieval on the held-out pairs: each caption should retrieve its own image.
# Read whole columns rather than iterating rows: iterating a Dataset yields fully decoded rows,
# which would decode every image just to collect the captions. Column access decodes each image once.
eval_captions = eval_dataset["caption"]
eval_images = eval_dataset["image"]
# InformationRetrievalEvaluator documents string ids and a *set* of relevant corpus ids per query.
# Query i and corpus entry i come from the same held-out pair, so they share the id str(i).
queries = {str(idx): caption for idx, caption in enumerate(eval_captions)}
corpus = {str(idx): image for idx, image in enumerate(eval_images)}
relevant_docs = {qid: {qid} for qid in queries}
dev_evaluator = InformationRetrievalEvaluator(
    queries=queries,
    corpus=corpus,
    relevant_docs=relevant_docs,
    name="unsplash-dev",
)
# Baseline: retrieval metrics of the pretrained model before any fine-tuning
dev_evaluator(model)
# In-batch contrastive loss: each image (anchor) is scored against its own caption (positive)
# and against the captions of the other images in the same batch (negatives)
train_loss = MultipleNegativesRankingLoss(model=model)
# Fine-tune CLIP on the image/caption pairs, evaluating retrieval every 10% of training.
# Step-style settings given as floats below 1 (eval_steps, logging_steps) are fractions of the
# total number of training steps.
args = SentenceTransformerTrainingArguments(
    output_dir="output/clip-unsplash-lite",  # Required parameter
    # Optimization
    num_train_epochs=1,
    learning_rate=2e-5,
    # NOTE(review): a value < 1 is treated as a warmup *ratio* only by recent transformers
    # releases; older ones expect `warmup_ratio` instead. Confirm against the installed version.
    warmup_steps=0.1,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    # Mixed precision: these are alternatives, enable at most one
    bf16=True,  # Set to False if your GPU does not support BF16
    fp16=False,  # Set to True if your GPU does not support BF16
    # Evaluation, checkpointing and logging
    eval_strategy="steps",
    eval_steps=0.1,
    save_strategy="no",  # no intermediate checkpoints are written during training
    logging_steps=0.05,
    logging_first_step=True,
    run_name="clip-unsplash-lite",  # Used in W&B if `wandb` is installed
)
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    loss=train_loss,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    evaluator=dev_evaluator,
)
trainer.train()
# Save the fine-tuned model in a "final" subfolder of the training output directory.
# Deriving the path from args.output_dir keeps it in sync if the output directory is changed.
model.save_pretrained(f"{args.output_dir}/final")
# Optionally, share it on the Hugging Face Hub (requires being logged in, e.g. `hf auth login`,
# or `huggingface-cli login` with older huggingface_hub versions)
# model.push_to_hub("clip-unsplash-lite")
from IPython.display import clear_output, display
from sentence_transformers.util import semantic_search
# Re-embed the first 1,000 training images with the fine-tuned model to serve as the search corpus
images = train_dataset[:1000]["image"]
image_embeddings = model.encode(images, batch_size=32, convert_to_tensor=True, show_progress_bar=True)
# Retrieve the top 3 matching images for the typed keywords, until an empty line is entered.
# The input is stripped so a whitespace-only line also stops the loop, as the prompt says,
# instead of running a search for a blank query.
while query := input("Search images by keywords (leave empty to stop): ").strip():
    # Clear the previous results before showing the new ones
    clear_output(wait=True)
    query_embedding = model.encode(query, convert_to_tensor=True)
    # semantic_search returns one hit list per query; a single query was sent, so take the first
    hits = semantic_search(query_embedding, image_embeddings, top_k=3)[0]
    for hit in hits:
        print(f"Score: {hit['score']:.3f}")
        # Show each match at thumbnail size, preserving aspect ratio. Work on a copy because
        # Image.thumbnail resizes in place and would otherwise shrink the corpus image itself.
        thumbnail = images[hit["corpus_id"]].copy()
        thumbnail.thumbnail((256, 256))
        display(thumbnail)