# Source notebook: examples/sentence_transformer/training/clip/train_clip.ipynb
from PIL import Image
from torch.utils.data import DataLoader
from sentence_transformers import InputExample, SentenceTransformer, losses, util
# Load the pretrained CLIP (ViT-B/32) checkpoint wrapped as a SentenceTransformer;
# it embeds both images and texts into the same vector space and is fine-tuned below.
model = SentenceTransformer("sentence-transformers/clip-ViT-B-32")
import os
import zipfile
from tqdm import tqdm
# Next, we get about 25k images from Unsplash.
# (The notebook export flattened the indentation; the download/extract steps
# belong inside the "folder missing or empty" guard so reruns skip the work.)
img_folder = "photos/"
if not os.path.exists(img_folder) or len(os.listdir(img_folder)) == 0:
    os.makedirs(img_folder, exist_ok=True)

    photo_filename = "unsplash-25k-photos.zip"
    if not os.path.exists(photo_filename):  # Download dataset if it does not exist
        util.http_get("http://sbert.net/datasets/" + photo_filename, photo_filename)

    # Extract only the first 1000 images to keep the example fast
    with zipfile.ZipFile(photo_filename, "r") as zf:
        for member in tqdm(zf.infolist()[:1000], desc="Extracting"):
            zf.extract(member, img_folder)
# Open every extracted file in the photo folder as a (lazily-loaded) PIL image.
photos = [
    Image.open(os.path.join(img_folder, fname))
    for fname in os.listdir(img_folder)
]
photos[0]  # notebook cell output: preview the first image
# Build labeled training pairs. (The notebook export flattened the loop body;
# the appends must be inside the for loop.)
train_dataset = []
for idx in range(0, len(photos), 2):
    # We can use image pairs directly. Because our images aren't labeled, we use a random label as an example
    # train_dataset.append(InputExample(texts=[photos[idx], photos[idx + 1]], label=random.choice([0, 1])))
    # Or images and text together: one positive and one negative caption per image
    train_dataset.append(InputExample(texts=[photos[idx], "This is the caption"], label=1))
    train_dataset.append(InputExample(texts=[photos[idx], "This is another unrelated caption"], label=0))
    # Or just texts
    # train_dataset.append(InputExample(texts=["This is a caption", "This is a similar caption"], label=1))
    # train_dataset.append(InputExample(texts=["This is a caption", "This is an unrelated caption"], label=0))
# Batch the labeled examples and pair the batches with a contrastive objective.
loader = DataLoader(train_dataset, shuffle=True, batch_size=4)
objective = losses.ContrastiveLoss(model=model)

# Fine-tune the CLIP model on the labeled (image, caption) pairs.
model.fit([(loader, objective)], epochs=5, show_progress_bar=True)