From 5f23a76d4c7302da90643219996b2b7099c926dd Mon Sep 17 00:00:00 2001
From: Mehdi Cherti
Date: Fri, 12 Jan 2024 13:36:23 +0100
Subject: [PATCH] Support winoground (#116)

* support winoground. re-use image_caption_selection for both sugar crepe and winoground.

* minor
---
 README.md                                     |  6 ++
 clip_benchmark/datasets/builder.py            |  8 ++-
 clip_benchmark/datasets/winoground.py         | 30 +++++++++
 .../metrics/image_caption_selection.py        | 65 ++++++++++++-------
 4 files changed, 84 insertions(+), 25 deletions(-)
 create mode 100644 clip_benchmark/datasets/winoground.py

diff --git a/README.md b/README.md
index 8ab90f7..4175571 100644
--- a/README.md
+++ b/README.md
@@ -240,6 +240,12 @@ To evaluate on all the tasks together, you can do:
 
 `clip_benchmark eval --model ViT-B-32 --pretrained laion400m_e32 --dataset=sugar_crepe --output=result.json`
 
+For [winoground](https://huggingface.co/datasets/facebook/winoground/):
+
+`clip_benchmark eval --model ViT-B-32 --pretrained laion400m_e32 --dataset=winoground --output=result.json`
+
+NB: `pip install datasets` is required for winoground.
+
 ### Webdataset example
 
 Here is an example on how to run it on [webdatasets](https://github.com/webdataset/webdataset).
diff --git a/clip_benchmark/datasets/builder.py b/clip_benchmark/datasets/builder.py
index 14c1e56..96d5115 100644
--- a/clip_benchmark/datasets/builder.py
+++ b/clip_benchmark/datasets/builder.py
@@ -13,7 +13,7 @@
                                    RenderedSST2, StanfordCars)
 
 from . import (babel_imagenet, caltech101, flickr, imagenetv2, objectnet,
-               sugar_crepe, voc2007)
+               sugar_crepe, voc2007, winoground)
 
 
 def build_dataset(dataset_name, root="root", transform=None, split="test", download=True, annotation_file=None, language="en", task="zeroshot_classification", wds_cache_dir=None, custom_classname_file=None, custom_template_file=None, **kwargs):
@@ -242,6 +242,8 @@ def download_imagenet(r):
         url = f"https://raw.githubusercontent.com/RAIVNLab/sugar-crepe/main/data/{task}.json"
         call(f"wget {url} --output-document={ann}", shell=True)
         ds = sugar_crepe.SugarCrepe(root=os.path.join(root, "val2017"), ann_file=ann, transform=transform, **kwargs)
+    elif dataset_name == "winoground":
+        ds = winoground.WinoGround(root=root, transform=transform)
     elif dataset_name == "mscoco_captions":
         # https://github.com/mehdidc/retrieval_annotations/releases/tag/1.0.0(annotations)
         if split == "train":
@@ -523,13 +525,13 @@ def __len__(self):
 def get_dataset_default_task(dataset):
     if dataset in ("flickr30k", "flickr8k", "mscoco_captions", "multilingual_mscoco_captions", "flickr30k-200", "crossmodal3600", "xtd200"):
         return "zeroshot_retrieval"
-    elif dataset.startswith("sugar_crepe"):
+    elif dataset.startswith("sugar_crepe") or dataset == "winoground":
         return "image_caption_selection"
     else:
         return "zeroshot_classification"
 
 def get_dataset_collate_fn(dataset_name):
-    if dataset_name in ("mscoco_captions", "multilingual_mscoco_captions", "flickr30k", "flickr8k", "flickr30k-200", "crossmodal3600", "xtd200") or dataset_name.startswith("sugar_crepe"):
+    if dataset_name in ("mscoco_captions", "multilingual_mscoco_captions", "flickr30k", "flickr8k", "flickr30k-200", "crossmodal3600", "xtd200", "winoground") or dataset_name.startswith("sugar_crepe"):
         return image_captions_collate_fn
     else:
         return default_collate
diff --git a/clip_benchmark/datasets/winoground.py b/clip_benchmark/datasets/winoground.py
new file mode 100644
index 0000000..b09661b
--- /dev/null
+++ b/clip_benchmark/datasets/winoground.py
@@ -0,0 +1,30 @@
+import os
+from torch.utils.data import Dataset
+from PIL import Image
+import torch
+import json
+
+class WinoGround(Dataset):
+
+    def __init__(self, root=".", transform=None):
+        from datasets import load_dataset
+        self.ds = load_dataset("facebook/winoground", cache_dir=root)["test"]
+        self.transform = transform
+
+    def __getitem__(self, idx):
+        data = self.ds[idx]
+        img0 = data["image_0"]
+        img1 = data["image_1"]
+        cap0 = data["caption_0"]
+        cap1 = data["caption_1"]
+        if self.transform is not None:
+            img0 = self.transform(img0)
+            img1 = self.transform(img1)
+            imgs = torch.stack([img0, img1])
+        else:
+            imgs = [img0, img1]
+        caps = [cap0, cap1]
+        return imgs, caps
+
+    def __len__(self):
+        return len(self.ds)
\ No newline at end of file
diff --git a/clip_benchmark/metrics/image_caption_selection.py b/clip_benchmark/metrics/image_caption_selection.py
index 61f6b5a..85b2eb3 100644
--- a/clip_benchmark/metrics/image_caption_selection.py
+++ b/clip_benchmark/metrics/image_caption_selection.py
@@ -5,9 +5,13 @@
 import torch.nn.functional as F
 from tqdm import tqdm
 
-def evaluate(model, dataloader, tokenizer, device, amp=True, recall_k_list=[5]):
+def evaluate(model, dataloader, tokenizer, device, amp=True):
     """
-    Evaluate the model on the given dataset
+    Evaluate the model on the given dataset.
+    The task has N instances, each instance has I images and C captions.
+    For each instance, the goal is to find the correct image for each caption and the correct caption for each image.
+    This is done by computing the similarities between each image and each caption.
+    This procedure is used to evaluate the models on Winoground and SugarCrepe.
 
     Parameters
     ----------
@@ -28,32 +32,49 @@ def evaluate(model, dataloader, tokenizer, device, amp=True, recall_k_list=[5])
 
     Returns
     -------
-    dict of accuracy metric
+    dict of accuracy metrics
     """
     autocast = torch.cuda.amp.autocast if amp else suppress
-    preds = []
+    image_score = []
+    text_score = []
+    score = []
     for batch_images, batch_texts in tqdm(dataloader):
+        if len(batch_images.shape) == 4:
+            B, C, H, W = batch_images.shape
+            batch_images = batch_images.view(B, 1, C, H, W)
+        # batch_images: B, nb_images_per_instance, C, H, W
+        # batch_texts: B, nb_captions_per_instance
+
+        B, nim, C, H, W = batch_images.shape
+        nt = len(batch_texts[0])
         batch_images = batch_images.to(device)
+        batch_images_ = batch_images.view(B*nim, C, H, W) # B*nim, C, H, W
         # tokenize all texts in the batch
-        batch_texts_tok = tokenizer([text for i, texts in enumerate(batch_texts) for text in texts]).to(device)
-        nb_texts_for_each_image = [len(texts) for texts in batch_texts]
-
+        batch_texts_tok_ = tokenizer([text for i, texts in enumerate(batch_texts) for text in texts]).to(device)
         # compute the embedding of images and texts
         with torch.no_grad(), autocast():
-            batch_images_emb = F.normalize(model.encode_image(batch_images), dim=-1).cpu()
-            batch_texts_emb = F.normalize(model.encode_text(batch_texts_tok), dim=-1).cpu()
-        start = 0
-        for i, nb in enumerate(nb_texts_for_each_image):
-            end = start + nb
-            image_emb = batch_images_emb[i:i+1]
-            texts_emb = batch_texts_emb[start:end]
-            scores = image_emb @ texts_emb.t()
-            scores = scores[0]
-            pred = scores.argmax().item()
-            start = end
-            preds.append(pred)
-    pred = torch.Tensor(preds).long()
-    acc = (pred==0).float().mean().item() # 0 is the index of the caption, the rest (>0) are considered negative captions
+            batch_images_emb = F.normalize(model.encode_image(batch_images_), dim=-1).view(B, nim, -1)
+            batch_texts_emb = F.normalize(model.encode_text(batch_texts_tok_), dim=-1).view(B, nt, -1)
+        gt = torch.arange(min(nim, nt)).to(device)
+        for i in range(B):
+            # iterate over instances
+
+            # compute similarities between each image and each text
+            images_emb = batch_images_emb[i]
+            texts_emb = batch_texts_emb[i]
+            scores = images_emb @ texts_emb.t()
+
+            # i-th image should be matched to the i-th text
+            image_closest_text = scores.argmax(dim=1)[:len(gt)]
+            text_closest_image = scores.argmax(dim=0)[:len(gt)]
+            pred_text_is_correct = (image_closest_text==gt).all().item()
+            pred_image_is_correct = (text_closest_image==gt).all().item()
+            all_correct = pred_text_is_correct and pred_image_is_correct
+            image_score.append(pred_image_is_correct)
+            text_score.append(pred_text_is_correct)
+            score.append(all_correct)
     metrics = {}
-    metrics[f"acc"] = acc
+    metrics["image_acc"] = torch.Tensor(image_score).float().mean().item()
+    metrics["text_acc"] = torch.Tensor(text_score).float().mean().item()
+    metrics["acc"] = torch.Tensor(score).float().mean().item()
     return metrics
\ No newline at end of file
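
A minimal standalone sketch (not part of the patch) of the matching rule the new evaluate loop applies to a single Winoground-style instance with 2 images and 2 captions; the similarity values below are made up for illustration:

    import torch

    # Toy 2x2 similarity matrix for one instance: rows are images, columns are captions.
    # Entry [i, j] is the similarity between image i and caption j (made-up values).
    scores = torch.tensor([[0.31, 0.12],
                           [0.05, 0.27]])

    gt = torch.arange(2)  # image i is paired with caption i

    image_closest_text = scores.argmax(dim=1)  # best caption for each image
    text_closest_image = scores.argmax(dim=0)  # best image for each caption

    text_correct = (image_closest_text == gt).all().item()   # contributes to "text_acc"
    image_correct = (text_closest_image == gt).all().item()  # contributes to "image_acc"
    group_correct = text_correct and image_correct           # contributes to "acc"

    print(text_correct, image_correct, group_correct)  # True True True

An instance only counts toward "acc" when every caption retrieves its own image and every image retrieves its own caption, which is why the reported "acc" is at most the minimum of "image_acc" and "text_acc".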