From 8a05786fbbe40fd4736464c5fc2721f6a8966666 Mon Sep 17 00:00:00 2001 From: Alexander Visheratin Date: Fri, 1 Dec 2023 19:01:45 -0500 Subject: [PATCH] Multilingual benchmark datasets (#113) * Updated XTD10 dataset. * Added Crossmodal-3600. * Added XTD200 and NLLB-CLIP. * Added Flickr30k-200. * Fixed datasets. * Fixed xm3600. * Added folders processing. * Fixed flickr30k-200. --- clip_benchmark/cli.py | 43 ++- clip_benchmark/datasets/builder.py | 92 ++++--- clip_benchmark/datasets/crossmodal3600.py | 152 +++++++++++ clip_benchmark/datasets/flickr30k_200.py | 118 +++++++++ clip_benchmark/datasets/flores_langs.py | 203 ++++++++++++++ .../datasets/multilingual_mscoco.py | 114 +++++--- clip_benchmark/datasets/xtd200.py | 119 +++++++++ clip_benchmark/models/nllb_clip.py | 248 ++++++++++++++++++ 8 files changed, 994 insertions(+), 95 deletions(-) create mode 100644 clip_benchmark/datasets/crossmodal3600.py create mode 100644 clip_benchmark/datasets/flickr30k_200.py create mode 100644 clip_benchmark/datasets/flores_langs.py create mode 100644 clip_benchmark/datasets/xtd200.py create mode 100644 clip_benchmark/models/nllb_clip.py diff --git a/clip_benchmark/cli.py b/clip_benchmark/cli.py index 8063339..d9434d7 100644 --- a/clip_benchmark/cli.py +++ b/clip_benchmark/cli.py @@ -1,17 +1,26 @@ """Console script for clip_benchmark.""" import argparse -import sys -import random -import json -import torch import csv -from copy import copy +import json import os +import random +import sys +from copy import copy from itertools import product -from clip_benchmark.datasets.builder import build_dataset, get_dataset_collate_fn, get_dataset_default_task, dataset_collection, get_dataset_collection_from_file -from clip_benchmark.metrics import image_caption_selection, zeroshot_classification, zeroshot_retrieval, linear_probe, captioning -from clip_benchmark.model_collection import get_model_collection_from_file, model_collection -from clip_benchmark.models import load_clip, MODEL_TYPES + +import torch + +from clip_benchmark.datasets.builder import (build_dataset, dataset_collection, + get_dataset_collate_fn, + get_dataset_collection_from_file, + get_dataset_default_task) +from clip_benchmark.metrics import (captioning, image_caption_selection, + linear_probe, zeroshot_classification, + zeroshot_retrieval) +from clip_benchmark.model_collection import (get_model_collection_from_file, + model_collection) +from clip_benchmark.models import MODEL_TYPES, load_clip + def get_parser_args(): parser = argparse.ArgumentParser() @@ -81,7 +90,7 @@ def main_build(base): # Build a benchmark single CSV file from a set of evaluations (JSON files) rows = [] fieldnames = set() - for path in base.files: + def process_file(path: str): data = json.load(open(path)) row = {} row.update(data["metrics"]) @@ -91,6 +100,13 @@ def main_build(base): for field in row.keys(): fieldnames.add(field) rows.append(row) + for path in base.files: + if os.path.isdir(path): + files = [os.path.join(path, f) for f in os.listdir(path) if f.endswith(".json")] + for file in files: + process_file(file) + else: + process_file(path) with open(base.output, 'w') as csvfile: writer = csv.DictWriter(csvfile, fieldnames=fieldnames) writer.writeheader() @@ -241,6 +257,11 @@ def run(args): device=args.device ) model.eval() + if args.model.count("nllb-clip") > 0: + # for NLLB-CLIP models, we need to set the language prior to running the tests + from clip_benchmark.models.nllb_clip import set_language + + set_language(tokenizer, args.language) dataset = build_dataset( 
dataset_name=args.dataset, root=dataset_root, @@ -295,7 +316,7 @@ def run(args): verbose=args.verbose, save_clf=args.save_clf, load_clfs=args.load_clfs, - ) + ) elif task == "zeroshot_retrieval": metrics = zeroshot_retrieval.evaluate( model, diff --git a/clip_benchmark/datasets/builder.py b/clip_benchmark/datasets/builder.py index 7fd3819..73d7771 100644 --- a/clip_benchmark/datasets/builder.py +++ b/clip_benchmark/datasets/builder.py @@ -1,20 +1,19 @@ +import json import os -import warnings import sys -import json +import warnings from subprocess import call -from collections import defaultdict + import torch -from torchvision.datasets import ( - VisionDataset, ImageFolder, - CIFAR10, CIFAR100, ImageNet, CocoCaptions, Flickr8k, Flickr30k, Food101, SUN397, - StanfordCars, FGVCAircraft, DTD, OxfordIIITPet, Caltech101, Flowers102, - MNIST, STL10, EuroSAT, GTSRB, Kitti, Country211, PCAM, RenderedSST2 -) - -from . import voc2007, flickr, caltech101, imagenetv2, objectnet, babel_imagenet, sugar_crepe from torch.utils.data import default_collate -from PIL import Image +from torchvision.datasets import (CIFAR10, CIFAR100, DTD, GTSRB, MNIST, PCAM, + STL10, SUN397, CocoCaptions, Country211, + EuroSAT, FGVCAircraft, Flowers102, Food101, + ImageFolder, ImageNet, OxfordIIITPet, + RenderedSST2, StanfordCars) + +from . import (babel_imagenet, caltech101, flickr, imagenetv2, objectnet, + sugar_crepe, voc2007) def build_dataset(dataset_name, root="root", transform=None, split="test", download=True, annotation_file=None, language="en", task="zeroshot_classification", wds_cache_dir=None, custom_classname_file=None, custom_template_file=None, **kwargs): @@ -108,7 +107,7 @@ def download_imagenet(r): elif dataset_name == "imagenet-w": assert split in ("train", "test"), f"Only `train` and `test` split available for {dataset_name}" from imagenet_w import AddWatermark - from torchvision.transforms import Normalize, CenterCrop + from torchvision.transforms import CenterCrop, Normalize if not os.path.exists(root): download_imagenet(root) index_normalize = None @@ -264,35 +263,44 @@ def download_imagenet(r): ds = CocoCaptions(root=root_split, annFile=annotation_file, transform=transform, **kwargs) elif dataset_name == 'multilingual_mscoco_captions': from clip_benchmark.datasets import multilingual_mscoco - if(language not in multilingual_mscoco.SUPPORTED_LANGUAGES): + if language not in multilingual_mscoco.SUPPORTED_LANGUAGES: raise ValueError("Unsupported language for multilingual_ms_coco:", language) - - def get_archive_name(target_split): - if target_split == "train": - return "train2014.zip" - elif target_split in ("val", "test"): - return "val2014.zip" - else: - raise ValueError(f"split should be `train` or `val` or `test` for `{dataset_name}`") - - def download_mscoco_split(target_split): - archive_name = get_archive_name(target_split) - root_split = os.path.join(root, archive_name.replace(".zip", "")) - if not os.path.exists(root_split): - print(f"Downloading mscoco_captions {archive_name}...") - if not os.path.exists(os.path.join(root, archive_name)): - call(f"wget http://images.cocodataset.org/zips/{archive_name} --output-document={root}/{archive_name}", shell=True) - call(f"unzip {root}/{archive_name} -d {root}", shell=True) - - # The multilingual MS-COCO uses images from various splits - for target_split in ['train', 'val', 'test']: - download_mscoco_split(target_split) - - annotation_file = os.path.join(root, multilingual_mscoco.CAPTIONS_FILE_NAME.format(language)) - if (os.path.exists(annotation_file) 
== False): + + annotation_file = os.path.join(root, multilingual_mscoco.OUTPUT_FILENAME_TEMPLATE.format(language)) + if not os.path.exists(annotation_file): multilingual_mscoco.create_annotation_file(root, language) ds = multilingual_mscoco.Multilingual_MSCOCO(root=root, ann_file=annotation_file, transform=transform, **kwargs) + elif dataset_name == 'crossmodal3600': + from clip_benchmark.datasets import crossmodal3600 + if language not in crossmodal3600.SUPPORTED_LANGUAGES: + raise ValueError("Unsupported language for Crossmodal-3600:", language) + + annotation_file = os.path.join(root, crossmodal3600.OUTPUT_FILENAME_TEMPLATE.format(language)) + if not os.path.exists(annotation_file): + crossmodal3600.create_annotation_file(root, language) + + ds = crossmodal3600.Crossmodal3600(root=root, ann_file=annotation_file, transform=transform, **kwargs) + elif dataset_name == 'xtd200': + from clip_benchmark.datasets import xtd200 + if language not in xtd200.SUPPORTED_LANGUAGES: + raise ValueError("Unsupported language for xtd200:", language) + + annotation_file = os.path.join(root, xtd200.OUTPUT_FILENAME_TEMPLATE.format(language)) + if not os.path.exists(annotation_file): + xtd200.create_annotation_file(root, language) + + ds = xtd200.XTD200(root=root, ann_file=annotation_file, transform=transform, **kwargs) + elif dataset_name == 'flickr30k-200': + from clip_benchmark.datasets import flickr30k_200 + if language not in flickr30k_200.SUPPORTED_LANGUAGES: + raise ValueError("Unsupported language for flickr30k-200:", language) + + annotation_file = os.path.join(root, flickr30k_200.OUTPUT_FILENAME_TEMPLATE.format(language)) + if not os.path.exists(annotation_file): + flickr30k_200.create_annotation_file(root, language) + + ds = flickr30k_200.Flickr30k_200(root=root, ann_file=annotation_file, transform=transform, **kwargs) elif dataset_name == "flickr30k": # downloadable from https://www.kaggle.com/datasets/adityajn105/flickr30k # https://github.com/mehdidc/retrieval_annotations/releases/tag/1.0.0(annotations) @@ -513,7 +521,7 @@ def __len__(self): return 1 def get_dataset_default_task(dataset): - if dataset in ("flickr30k", "flickr8k", "mscoco_captions", "multilingual_mscoco_captions"): + if dataset in ("flickr30k", "flickr8k", "mscoco_captions", "multilingual_mscoco_captions", "flickr30k-200", "crossmodal3600", "xtd200"): return "zeroshot_retrieval" elif dataset.startswith("sugar_crepe"): return "image_caption_selection" @@ -521,7 +529,7 @@ def get_dataset_default_task(dataset): return "zeroshot_classification" def get_dataset_collate_fn(dataset_name): - if dataset_name in ("mscoco_captions", "multilingual_mscoco_captions", "flickr30k", "flickr8k") or dataset_name.startswith("sugar_crepe"): + if dataset_name in ("mscoco_captions", "multilingual_mscoco_captions", "flickr30k", "flickr8k", "flickr30k-200", "crossmodal3600", "xtd200") or dataset_name.startswith("sugar_crepe"): return image_captions_collate_fn else: return default_collate @@ -535,7 +543,8 @@ def has_kaggle(): def build_vtab_dataset(dataset_name, transform, download=True, split="test", data_dir="root", classnames=[]): # Using VTAB splits instead of default TFDS splits - from .tfds import VTABIterableDataset, disable_gpus_on_tensorflow, download_tfds_dataset + from .tfds import (VTABIterableDataset, disable_gpus_on_tensorflow, + download_tfds_dataset) # avoid Tensorflow owning GPUs to not clash with PyTorch disable_gpus_on_tensorflow() @@ -648,6 +657,7 @@ def build_vtab_dataset(dataset_name, transform, download=True, split="test", dat 
classes = tfds_dataset._dataset_builder.info.features[task].names elif dataset_name == "sun397": from task_adaptation.data.sun397 import Sun397Data + #FIXME There is a problem in `sun397`, when TFDS tries download it # there is an image that cannot be decoded. For the time being # we will use torchvision's SUN397 instead. diff --git a/clip_benchmark/datasets/crossmodal3600.py b/clip_benchmark/datasets/crossmodal3600.py new file mode 100644 index 0000000..582d428 --- /dev/null +++ b/clip_benchmark/datasets/crossmodal3600.py @@ -0,0 +1,152 @@ +import codecs +import json +import os +from subprocess import call + +from PIL import Image +from torchvision.datasets import VisionDataset + +SUPPORTED_LANGUAGES = [ + "ar", + "bn", + "cs", + "da", + "de", + "el", + "en", + "es", + "fa", + "fi", + "fil", + "fr", + "he", + "hi", + "hr", + "hu", + "id", + "it", + "ja", + "ko", + "mi", + "nl", + "no", + "pl", + "pt", + "quz", + "ro", + "ru", + "sv", + "sw", + "te", + "th", + "tr", + "uk", + "vi", + "zh", +] + +CAPTIONS_DOWNLOAD_URL = "https://google.github.io/crossmodal-3600/web-data/captions.zip" +IMAGES_DOWNLOAD_URL = ( + "https://open-images-dataset.s3.amazonaws.com/crossmodal-3600/images.tgz" +) +OUTPUT_FILENAME_TEMPLATE = "crossmodal3600_captions-{}.json" + + +class Crossmodal3600(VisionDataset): + def __init__(self, root, ann_file, transform=None, target_transform=None): + super().__init__(root, transform=transform, target_transform=target_transform) + self.ann_file = os.path.expanduser(ann_file) + with codecs.open(ann_file, "r", encoding="utf-8") as fp: + data = json.load(fp) + self.data = [ + (img_path, txt) + for img_path, txt in zip(data["image_paths"], data["annotations"]) + ] + + def __getitem__(self, index): + img, captions = self.data[index] + + # Image + img = Image.open(img).convert("RGB") + if self.transform is not None: + img = self.transform(img) + + # Captions + target = [ + captions, + ] + if self.target_transform is not None: + target = self.target_transform(target) + + return img, target + + def __len__(self) -> int: + return len(self.data) + + +def _download_captions(out_path): + os.makedirs(out_path, exist_ok=True) + print("Downloading captions") + call(f"wget {CAPTIONS_DOWNLOAD_URL} -O captions.zip", shell=True) + call(f"unzip captions.zip -d {out_path}", shell=True) + call("rm captions.zip", shell=True) + + +def _download_images(out_path): + os.makedirs(out_path, exist_ok=True) + print("Downloading images") + call(f"wget {IMAGES_DOWNLOAD_URL} -O images.tgz", shell=True) + call(f"tar -xzf images.tgz -C {out_path}", shell=True) + call("rm images.tgz", shell=True) + + +def create_annotation_file(root, lang_code): + if lang_code not in SUPPORTED_LANGUAGES: + raise ValueError( + f"Language code {lang_code} not supported. 
Supported languages are {SUPPORTED_LANGUAGES}" + ) + data_dir = os.path.join(root, "xm3600") + images_dir = os.path.join(data_dir, "images") + if not os.path.exists(images_dir): + _download_images(images_dir) + captions_path = os.path.join(data_dir, "captions.jsonl") + if not os.path.exists(captions_path): + _download_captions(data_dir) + with open(captions_path, "r", encoding="utf-8") as f: + data = f.readlines() + data = [json.loads(line) for line in data] + + number_of_missing_images = 0 + valid_images, valid_annotations, valid_indicies = [], [], [] + for i, data_item in enumerate(data): + image_id = data_item["image/key"] + image_name = f"{image_id}.jpg" + image_path = os.path.join(images_dir, image_name) + if not os.path.exists(image_path): + print("Missing image file", image_name) + number_of_missing_images += 1 + continue + captions = data_item[lang_code]["caption"] + txt = captions[0] + + valid_images.append(image_path) + valid_annotations.append(txt) + valid_indicies.append(i) + + if number_of_missing_images > 0: + print(f"*** WARNING *** missing {number_of_missing_images} files.") + + with codecs.open( + os.path.join(root, OUTPUT_FILENAME_TEMPLATE.format(lang_code)), + "w", + encoding="utf-8", + ) as fp: + json.dump( + { + "image_paths": valid_images, + "annotations": valid_annotations, + "indicies": valid_indicies, + }, + fp, + ensure_ascii=False, + ) diff --git a/clip_benchmark/datasets/flickr30k_200.py b/clip_benchmark/datasets/flickr30k_200.py new file mode 100644 index 0000000..c603f0f --- /dev/null +++ b/clip_benchmark/datasets/flickr30k_200.py @@ -0,0 +1,118 @@ +import codecs +import json +import os +from subprocess import call + +import requests +from PIL import Image +from torchvision.datasets import VisionDataset + +from .flores_langs import flores_languages + +GITHUB_DATA_PATH = ( + "https://raw.githubusercontent.com/visheratin/nllb-clip/main/data/flickr30k-200/" +) +SUPPORTED_LANGUAGES = flores_languages + +IMAGE_INDEX_FILENAME = "filenames.txt" + +CAPTIONS_FILENAME_TEMPLATE = "{}.txt" +OUTPUT_FILENAME_TEMPLATE = "flickr30k_200-{}.json" + +IMAGES_DOWNLOAD_URL = "https://nllb-data.com/test/flickr30k/images.tar.gz" + + +class Flickr30k_200(VisionDataset): + def __init__(self, root, ann_file, transform=None, target_transform=None): + super().__init__(root, transform=transform, target_transform=target_transform) + self.ann_file = os.path.expanduser(ann_file) + with codecs.open(ann_file, "r", encoding="utf-8") as fp: + data = json.load(fp) + self.data = [ + (img_path, txt) + for img_path, txt in zip(data["image_paths"], data["annotations"]) + ] + + def __getitem__(self, index): + img, captions = self.data[index] + + # Image + img = Image.open(img).convert("RGB") + if self.transform is not None: + img = self.transform(img) + + # Captions + target = [ + captions, + ] + if self.target_transform is not None: + target = self.target_transform(target) + + return img, target + + def __len__(self) -> int: + return len(self.data) + + +def _get_lines(url): + response = requests.get(url, timeout=30) + return response.text.splitlines() + + +def _download_images(out_path): + os.makedirs(out_path, exist_ok=True) + print("Downloading images") + call(f"wget {IMAGES_DOWNLOAD_URL} -O images.tar.gz", shell=True) + call(f"tar -xzf images.tar.gz -C {out_path}", shell=True) + call("rm images.tar.gz", shell=True) + +def create_annotation_file(root, lang_code): + if lang_code not in SUPPORTED_LANGUAGES: + raise ValueError( + f"Language code {lang_code} not supported. 
Supported languages are {SUPPORTED_LANGUAGES}" + ) + data_dir = os.path.join(root, "flickr30k-200") + if not os.path.exists(data_dir): + _download_images(data_dir) + images_dir = os.path.join(root, "flickr30k-200", "images") + print("Downloading flickr30k-200 index file") + download_path = os.path.join(GITHUB_DATA_PATH, IMAGE_INDEX_FILENAME) + target_images = _get_lines(download_path) + + print("Downloading flickr30k-200 captions:", lang_code) + captions_path = GITHUB_DATA_PATH + download_path = os.path.join( + captions_path, CAPTIONS_FILENAME_TEMPLATE.format(lang_code) + ) + target_captions = _get_lines(download_path) + + number_of_missing_images = 0 + valid_images, valid_annotations, valid_indicies = [], [], [] + for i, (img, txt) in enumerate(zip(target_images, target_captions)): + image_path = os.path.join(images_dir, img) + if not os.path.exists(image_path): + print("Missing image file", img) + number_of_missing_images += 1 + continue + + valid_images.append(image_path) + valid_annotations.append(txt) + valid_indicies.append(i) + + if number_of_missing_images > 0: + print(f"*** WARNING *** missing {number_of_missing_images} files.") + + with codecs.open( + os.path.join(root, OUTPUT_FILENAME_TEMPLATE.format(lang_code)), + "w", + encoding="utf-8", + ) as fp: + json.dump( + { + "image_paths": valid_images, + "annotations": valid_annotations, + "indicies": valid_indicies, + }, + fp, + ensure_ascii=False, + ) diff --git a/clip_benchmark/datasets/flores_langs.py b/clip_benchmark/datasets/flores_langs.py new file mode 100644 index 0000000..928a3c1 --- /dev/null +++ b/clip_benchmark/datasets/flores_langs.py @@ -0,0 +1,203 @@ +flores_languages = [ + "ace_Arab", + "ace_Latn", + "acm_Arab", + "acq_Arab", + "aeb_Arab", + "afr_Latn", + "ajp_Arab", + "aka_Latn", + "amh_Ethi", + "apc_Arab", + "arb_Arab", + "ars_Arab", + "ary_Arab", + "arz_Arab", + "asm_Beng", + "ast_Latn", + "awa_Deva", + "ayr_Latn", + "azb_Arab", + "azj_Latn", + "bak_Cyrl", + "bam_Latn", + "ban_Latn", + "bel_Cyrl", + "bem_Latn", + "ben_Beng", + "bho_Deva", + "bjn_Arab", + "bjn_Latn", + "bod_Tibt", + "bos_Latn", + "bug_Latn", + "bul_Cyrl", + "cat_Latn", + "ceb_Latn", + "ces_Latn", + "cjk_Latn", + "ckb_Arab", + "crh_Latn", + "cym_Latn", + "dan_Latn", + "deu_Latn", + "dik_Latn", + "dyu_Latn", + "dzo_Tibt", + "eng_Latn", + "ell_Grek", + "epo_Latn", + "est_Latn", + "eus_Latn", + "ewe_Latn", + "fao_Latn", + "fij_Latn", + "fin_Latn", + "fon_Latn", + "fra_Latn", + "fur_Latn", + "fuv_Latn", + "gla_Latn", + "gle_Latn", + "glg_Latn", + "grn_Latn", + "guj_Gujr", + "hat_Latn", + "hau_Latn", + "heb_Hebr", + "hin_Deva", + "hne_Deva", + "hrv_Latn", + "hun_Latn", + "hye_Armn", + "ibo_Latn", + "ilo_Latn", + "ind_Latn", + "isl_Latn", + "ita_Latn", + "jav_Latn", + "jpn_Jpan", + "kab_Latn", + "kac_Latn", + "kam_Latn", + "kan_Knda", + "kas_Arab", + "kas_Deva", + "kat_Geor", + "knc_Arab", + "knc_Latn", + "kaz_Cyrl", + "kbp_Latn", + "kea_Latn", + "khm_Khmr", + "kik_Latn", + "kin_Latn", + "kir_Cyrl", + "kmb_Latn", + "kmr_Latn", + "kon_Latn", + "kor_Hang", + "lao_Laoo", + "lij_Latn", + "lim_Latn", + "lin_Latn", + "lit_Latn", + "lmo_Latn", + "ltg_Latn", + "ltz_Latn", + "lua_Latn", + "lug_Latn", + "luo_Latn", + "lus_Latn", + "lvs_Latn", + "mag_Deva", + "mai_Deva", + "mal_Mlym", + "mar_Deva", + "min_Latn", + "mkd_Cyrl", + "plt_Latn", + "mlt_Latn", + "mni_Beng", + "khk_Cyrl", + "mos_Latn", + "mri_Latn", + "mya_Mymr", + "nld_Latn", + "nno_Latn", + "nob_Latn", + "npi_Deva", + "nso_Latn", + "nus_Latn", + "nya_Latn", + "oci_Latn", + "gaz_Latn", + "ory_Orya", + 
"pag_Latn", + "pan_Guru", + "pap_Latn", + "pes_Arab", + "pol_Latn", + "por_Latn", + "prs_Arab", + "pbt_Arab", + "quy_Latn", + "ron_Latn", + "run_Latn", + "rus_Cyrl", + "sag_Latn", + "san_Deva", + "scn_Latn", + "shn_Mymr", + "sin_Sinh", + "slk_Latn", + "slv_Latn", + "smo_Latn", + "sna_Latn", + "snd_Arab", + "som_Latn", + "sot_Latn", + "spa_Latn", + "als_Latn", + "srd_Latn", + "srp_Cyrl", + "ssw_Latn", + "sun_Latn", + "swe_Latn", + "swh_Latn", + "szl_Latn", + "tam_Taml", + "tat_Cyrl", + "tel_Telu", + "tgk_Cyrl", + "tgl_Latn", + "tha_Thai", + "tir_Ethi", + "taq_Latn", + "taq_Tfng", + "tpi_Latn", + "tsn_Latn", + "tso_Latn", + "tuk_Latn", + "tum_Latn", + "tur_Latn", + "twi_Latn", + "tzm_Tfng", + "uig_Arab", + "ukr_Cyrl", + "umb_Latn", + "urd_Arab", + "uzn_Latn", + "vec_Latn", + "vie_Latn", + "war_Latn", + "wol_Latn", + "xho_Latn", + "ydd_Hebr", + "yor_Latn", + "yue_Hant", + "zho_Hans", + "zho_Hant", + "zsm_Latn", + "zul_Latn", +] diff --git a/clip_benchmark/datasets/multilingual_mscoco.py b/clip_benchmark/datasets/multilingual_mscoco.py index fd598b5..167257c 100644 --- a/clip_benchmark/datasets/multilingual_mscoco.py +++ b/clip_benchmark/datasets/multilingual_mscoco.py @@ -1,91 +1,119 @@ +import codecs +import json +import os from subprocess import call -import os, json -from torchvision.datasets import VisionDataset +import requests from PIL import Image +from torchvision.datasets import VisionDataset +GITHUB_DATA_PATH = "https://raw.githubusercontent.com/adobe-research/Cross-lingual-Test-Dataset-XTD10/main/XTD10/" +GITHUB_DATA_PATH_DE_FR = "https://raw.githubusercontent.com/adobe-research/Cross-lingual-Test-Dataset-XTD10/main/MIC/" +GITHUB_DATA_PATH_JP = "https://raw.githubusercontent.com/adobe-research/Cross-lingual-Test-Dataset-XTD10/main/STAIR/" +SUPPORTED_LANGUAGES = ["es", "it", "ko", "pl", "ru", "tr", "zh", "en", "de", "fr", "jp"] -GITHUB_MAIN_ORIGINAL_ANNOTATION_PATH = 'https://github.com/mehdidc/retrieval_annotations/releases/download/1.0.0/coco_{}_karpathy.json' -GITHUB_MAIN_PATH = 'https://raw.githubusercontent.com/adobe-research/Cross-lingual-Test-Dataset-XTD10/main/XTD10/' -SUPPORTED_LANGUAGES = ['es', 'it', 'ko', 'pl', 'ru', 'tr', 'zh', 'en'] - -IMAGE_INDEX_FILE = 'mscoco-multilingual_index.json' -IMAGE_INDEX_FILE_DOWNLOAD_NAME = 'test_image_names.txt' - -CAPTIONS_FILE_DOWNLOAD_NAME = 'test_1kcaptions_{}.txt' -CAPTIONS_FILE_NAME = 'multilingual_mscoco_captions-{}.json' +IMAGE_INDEX_FILENAME = "test_image_names.txt" -ORIGINAL_ANNOTATION_FILE_NAME = 'coco_{}_karpathy.json' +CAPTIONS_FILENAME_TEMPLATE = "test_1kcaptions_{}.txt" +OUTPUT_FILENAME_TEMPLATE = "multilingual_mscoco_captions-{}.json" +IMAGES_DOWNLOAD_URL = "https://nllb-data.com/test/xtd10/images.tar.gz" class Multilingual_MSCOCO(VisionDataset): - def __init__(self, root, ann_file, transform=None, target_transform=None): super().__init__(root, transform=transform, target_transform=target_transform) self.ann_file = os.path.expanduser(ann_file) - with open(ann_file, 'r') as fp: + with codecs.open(ann_file, "r", encoding="utf-8") as fp: data = json.load(fp) - - self.data = [(img_path, txt) for img_path, txt in zip(data['image_paths'], data['annotations'])] - + self.data = [ + (img_path, txt) + for img_path, txt in zip(data["image_paths"], data["annotations"]) + ] + def __getitem__(self, index): img, captions = self.data[index] # Image - img = Image.open(os.path.join(self.root, img)).convert("RGB") + img = Image.open(img).convert("RGB") if self.transform is not None: img = self.transform(img) # Captions - target = [captions, 
] + target = [ + captions, + ] if self.target_transform is not None: target = self.target_transform(target) return img, target - def __len__(self) -> int: return len(self.data) -def _get_downloadable_file(filename, download_url, is_json=True): - if (os.path.exists(filename) == False): - print("Downloading", download_url) - call("wget {} -O {}".format(download_url, filename), shell=True) - with open(filename, 'r') as fp: - if (is_json): - return json.load(fp) - return [line.strip() for line in fp.readlines()] +def _get_lines(url): + response = requests.get(url, timeout=30) + return response.text.splitlines() + + +def _download_images(out_path): + os.makedirs(out_path, exist_ok=True) + print("Downloading images") + call(f"wget {IMAGES_DOWNLOAD_URL} -O images.tar.gz", shell=True) + call(f"tar -xzf images.tar.gz -C {out_path}", shell=True) + call("rm images.tar.gz", shell=True) def create_annotation_file(root, lang_code): + if lang_code not in SUPPORTED_LANGUAGES: + raise ValueError( + f"Language code {lang_code} not supported. Supported languages are {SUPPORTED_LANGUAGES}" + ) + data_dir = os.path.join(root, "multilingual_mscoco") + if not os.path.exists(data_dir): + _download_images(data_dir) + images_dir = os.path.join(data_dir, "images") print("Downloading multilingual_ms_coco index file") - download_path = os.path.join(GITHUB_MAIN_PATH, IMAGE_INDEX_FILE_DOWNLOAD_NAME) - target_images = _get_downloadable_file("multilingual_coco_images.txt", download_path, False) + download_path = os.path.join(GITHUB_DATA_PATH, IMAGE_INDEX_FILENAME) + target_images = _get_lines(download_path) print("Downloading multilingual_ms_coco captions:", lang_code) - download_path = os.path.join(GITHUB_MAIN_PATH, CAPTIONS_FILE_DOWNLOAD_NAME.format(lang_code)) - target_captions = _get_downloadable_file('raw_multilingual_coco_captions_{}.txt'.format(lang_code), download_path, False) + captions_path = GITHUB_DATA_PATH + if lang_code in ["de", "fr"]: + captions_path = GITHUB_DATA_PATH_DE_FR + elif lang_code == "jp": + captions_path = GITHUB_DATA_PATH_JP + download_path = os.path.join( + captions_path, CAPTIONS_FILENAME_TEMPLATE.format(lang_code) + ) + target_captions = _get_lines(download_path) number_of_missing_images = 0 valid_images, valid_annotations, valid_indicies = [], [], [] for i, (img, txt) in enumerate(zip(target_images, target_captions)): - # Create a new file name that includes the root split - root_split = 'val2014' if 'val' in img else 'train2014' - filename_with_root_split = "{}/{}".format(root_split, img) - - if (os.path.exists(filename_with_root_split)): + image_path = os.path.join(images_dir, img) + if not os.path.exists(image_path): print("Missing image file", img) number_of_missing_images += 1 continue - valid_images.append(filename_with_root_split) + valid_images.append(image_path) valid_annotations.append(txt) valid_indicies.append(i) - if (number_of_missing_images > 0): - print("*** WARNING *** missing {} files.".format(number_of_missing_images)) - - with open(os.path.join(root, CAPTIONS_FILE_NAME.format(lang_code)), 'w') as fp: - json.dump({'image_paths': valid_images, 'annotations': valid_annotations, 'indicies': valid_indicies}, fp) + if number_of_missing_images > 0: + print(f"*** WARNING *** missing {number_of_missing_images} files.") + + with codecs.open( + os.path.join(root, OUTPUT_FILENAME_TEMPLATE.format(lang_code)), "w", encoding="utf-8" + ) as fp: + json.dump( + { + "image_paths": valid_images, + "annotations": valid_annotations, + "indicies": valid_indicies, + }, + fp, + 
ensure_ascii=False, + ) diff --git a/clip_benchmark/datasets/xtd200.py b/clip_benchmark/datasets/xtd200.py new file mode 100644 index 0000000..935ac5a --- /dev/null +++ b/clip_benchmark/datasets/xtd200.py @@ -0,0 +1,119 @@ +import codecs +import json +import os +from subprocess import call + +import requests +from PIL import Image +from torchvision.datasets import VisionDataset + +from .flores_langs import flores_languages + +GITHUB_DATA_PATH = ( + "https://raw.githubusercontent.com/visheratin/nllb-clip/main/data/xtd200/" +) +SUPPORTED_LANGUAGES = flores_languages + +IMAGE_INDEX_FILENAME = "test_image_names.txt" + +CAPTIONS_FILENAME_TEMPLATE = "{}.txt" +OUTPUT_FILENAME_TEMPLATE = "xtd200-{}.json" + +IMAGES_DOWNLOAD_URL = "https://nllb-data.com/test/xtd10/images.tar.gz" + + +class XTD200(VisionDataset): + def __init__(self, root, ann_file, transform=None, target_transform=None): + super().__init__(root, transform=transform, target_transform=target_transform) + self.ann_file = os.path.expanduser(ann_file) + with codecs.open(ann_file, "r", encoding="utf-8") as fp: + data = json.load(fp) + self.data = [ + (img_path, txt) + for img_path, txt in zip(data["image_paths"], data["annotations"]) + ] + + def __getitem__(self, index): + img, captions = self.data[index] + + # Image + img = Image.open(img).convert("RGB") + if self.transform is not None: + img = self.transform(img) + + # Captions + target = [ + captions, + ] + if self.target_transform is not None: + target = self.target_transform(target) + + return img, target + + def __len__(self) -> int: + return len(self.data) + + +def _get_lines(url): + response = requests.get(url, timeout=30) + return response.text.splitlines() + + +def _download_images(out_path): + os.makedirs(out_path, exist_ok=True) + print("Downloading images") + call(f"wget {IMAGES_DOWNLOAD_URL} -O images.tar.gz", shell=True) + call(f"tar -xzf images.tar.gz -C {out_path}", shell=True) + call("rm images.tar.gz", shell=True) + + +def create_annotation_file(root, lang_code): + if lang_code not in SUPPORTED_LANGUAGES: + raise ValueError( + f"Language code {lang_code} not supported. 
Supported languages are {SUPPORTED_LANGUAGES}" + ) + data_dir = os.path.join(root, "xtd200") + if not os.path.exists(data_dir): + _download_images(data_dir) + images_dir = os.path.join(data_dir, "images") + print("Downloading xtd200 index file") + download_path = os.path.join(GITHUB_DATA_PATH, IMAGE_INDEX_FILENAME) + target_images = _get_lines(download_path) + + print("Downloading xtd200 captions:", lang_code) + captions_path = GITHUB_DATA_PATH + download_path = os.path.join( + captions_path, CAPTIONS_FILENAME_TEMPLATE.format(lang_code) + ) + target_captions = _get_lines(download_path) + + number_of_missing_images = 0 + valid_images, valid_annotations, valid_indicies = [], [], [] + for i, (img, txt) in enumerate(zip(target_images, target_captions)): + image_path = os.path.join(images_dir, img) + if not os.path.exists(image_path): + print("Missing image file", img) + number_of_missing_images += 1 + continue + + valid_images.append(image_path) + valid_annotations.append(txt) + valid_indicies.append(i) + + if number_of_missing_images > 0: + print(f"*** WARNING *** missing {number_of_missing_images} files.") + + with codecs.open( + os.path.join(root, OUTPUT_FILENAME_TEMPLATE.format(lang_code)), + "w", + encoding="utf-8", + ) as fp: + json.dump( + { + "image_paths": valid_images, + "annotations": valid_annotations, + "indicies": valid_indicies, + }, + fp, + ensure_ascii=False, + ) diff --git a/clip_benchmark/models/nllb_clip.py b/clip_benchmark/models/nllb_clip.py new file mode 100644 index 0000000..cc6d540 --- /dev/null +++ b/clip_benchmark/models/nllb_clip.py @@ -0,0 +1,248 @@ +def set_language(tokenizer, lang_code): + lang = lang_map[lang_code] + print(f"Setting language for NLLB-CLIP: {lang}") + tokenizer.tokenizer.set_src_lang_special_tokens(lang) + + +lang_map = { + "en": "eng_Latn", + "es": "spa_Latn", + "it": "ita_Latn", + "ko": "kor_Hang", + "ru": "rus_Cyrl", + "zh": "zho_Hant", + "de": "deu_Latn", + "fr": "fra_Latn", + "jp": "jpn_Jpan", + "cn": "zho_Hant", + "zhm": "yue_Hant", + "ar": "arb_Arab", + "bn": "ben_Beng", + "cs": "ces_Latn", + "da": "dan_Latn", + "el": "ell_Grek", + "fa": "pes_Arab", + "fi": "fin_Latn", + "fil": "tgl_Latn", + "hi": "hin_Deva", + "hr": "hrv_Latn", + "hu": "hun_Latn", + "ja": "jpn_Jpan", + "id": "ind_Latn", + "he": "heb_Hebr", + "mi": "mri_Latn", + "nl": "nld_Latn", + "no": "nno_Latn", + "pl": "pol_Latn", + "pt": "por_Latn", + "quz": "quy_Latn", + "ro": "ron_Latn", + "sv": "swe_Latn", + "sw": "swh_Latn", + "te": "tel_Telu", + "th": "tha_Thai", + "tr": "tur_Latn", + "uk": "ukr_Cyrl", + "vi": "vie_Latn", + "ace_Arab": "ace_Arab", + "ace_Latn": "ace_Latn", + "acm_Arab": "acm_Arab", + "acq_Arab": "acq_Arab", + "aeb_Arab": "aeb_Arab", + "afr_Latn": "afr_Latn", + "ajp_Arab": "ajp_Arab", + "aka_Latn": "aka_Latn", + "amh_Ethi": "amh_Ethi", + "apc_Arab": "apc_Arab", + "arb_Arab": "arb_Arab", + "ars_Arab": "ars_Arab", + "ary_Arab": "ary_Arab", + "arz_Arab": "arz_Arab", + "asm_Beng": "asm_Beng", + "ast_Latn": "ast_Latn", + "awa_Deva": "awa_Deva", + "ayr_Latn": "ayr_Latn", + "azb_Arab": "azb_Arab", + "azj_Latn": "azj_Latn", + "bak_Cyrl": "bak_Cyrl", + "bam_Latn": "bam_Latn", + "ban_Latn": "ban_Latn", + "bel_Cyrl": "bel_Cyrl", + "bem_Latn": "bem_Latn", + "ben_Beng": "ben_Beng", + "bho_Deva": "bho_Deva", + "bjn_Arab": "bjn_Arab", + "bjn_Latn": "bjn_Latn", + "bod_Tibt": "bod_Tibt", + "bos_Latn": "bos_Latn", + "bug_Latn": "bug_Latn", + "bul_Cyrl": "bul_Cyrl", + "cat_Latn": "cat_Latn", + "ceb_Latn": "ceb_Latn", + "ces_Latn": "ces_Latn", + "cjk_Latn": "cjk_Latn", + "ckb_Arab": 
"ckb_Arab", + "crh_Latn": "crh_Latn", + "cym_Latn": "cym_Latn", + "dan_Latn": "dan_Latn", + "deu_Latn": "deu_Latn", + "dik_Latn": "dik_Latn", + "dyu_Latn": "dyu_Latn", + "dzo_Tibt": "dzo_Tibt", + "eng_Latn": "eng_Latn", + "ell_Grek": "ell_Grek", + "epo_Latn": "epo_Latn", + "est_Latn": "est_Latn", + "eus_Latn": "eus_Latn", + "ewe_Latn": "ewe_Latn", + "fao_Latn": "fao_Latn", + "fij_Latn": "fij_Latn", + "fin_Latn": "fin_Latn", + "fon_Latn": "fon_Latn", + "fra_Latn": "fra_Latn", + "fur_Latn": "fur_Latn", + "fuv_Latn": "fuv_Latn", + "gla_Latn": "gla_Latn", + "gle_Latn": "gle_Latn", + "glg_Latn": "glg_Latn", + "grn_Latn": "grn_Latn", + "guj_Gujr": "guj_Gujr", + "hat_Latn": "hat_Latn", + "hau_Latn": "hau_Latn", + "heb_Hebr": "heb_Hebr", + "hin_Deva": "hin_Deva", + "hne_Deva": "hne_Deva", + "hrv_Latn": "hrv_Latn", + "hun_Latn": "hun_Latn", + "hye_Armn": "hye_Armn", + "ibo_Latn": "ibo_Latn", + "ilo_Latn": "ilo_Latn", + "ind_Latn": "ind_Latn", + "isl_Latn": "isl_Latn", + "ita_Latn": "ita_Latn", + "jav_Latn": "jav_Latn", + "jpn_Jpan": "jpn_Jpan", + "kab_Latn": "kab_Latn", + "kac_Latn": "kac_Latn", + "kam_Latn": "kam_Latn", + "kan_Knda": "kan_Knda", + "kas_Arab": "kas_Arab", + "kas_Deva": "kas_Deva", + "kat_Geor": "kat_Geor", + "knc_Arab": "knc_Arab", + "knc_Latn": "knc_Latn", + "kaz_Cyrl": "kaz_Cyrl", + "kbp_Latn": "kbp_Latn", + "kea_Latn": "kea_Latn", + "khm_Khmr": "khm_Khmr", + "kik_Latn": "kik_Latn", + "kin_Latn": "kin_Latn", + "kir_Cyrl": "kir_Cyrl", + "kmb_Latn": "kmb_Latn", + "kmr_Latn": "kmr_Latn", + "kon_Latn": "kon_Latn", + "kor_Hang": "kor_Hang", + "lao_Laoo": "lao_Laoo", + "lij_Latn": "lij_Latn", + "lim_Latn": "lim_Latn", + "lin_Latn": "lin_Latn", + "lit_Latn": "lit_Latn", + "lmo_Latn": "lmo_Latn", + "ltg_Latn": "ltg_Latn", + "ltz_Latn": "ltz_Latn", + "lua_Latn": "lua_Latn", + "lug_Latn": "lug_Latn", + "luo_Latn": "luo_Latn", + "lus_Latn": "lus_Latn", + "lvs_Latn": "lvs_Latn", + "mag_Deva": "mag_Deva", + "mai_Deva": "mai_Deva", + "mal_Mlym": "mal_Mlym", + "mar_Deva": "mar_Deva", + "min_Latn": "min_Latn", + "mkd_Cyrl": "mkd_Cyrl", + "plt_Latn": "plt_Latn", + "mlt_Latn": "mlt_Latn", + "mni_Beng": "mni_Beng", + "khk_Cyrl": "khk_Cyrl", + "mos_Latn": "mos_Latn", + "mri_Latn": "mri_Latn", + "mya_Mymr": "mya_Mymr", + "nld_Latn": "nld_Latn", + "nno_Latn": "nno_Latn", + "nob_Latn": "nob_Latn", + "npi_Deva": "npi_Deva", + "nso_Latn": "nso_Latn", + "nus_Latn": "nus_Latn", + "nya_Latn": "nya_Latn", + "oci_Latn": "oci_Latn", + "gaz_Latn": "gaz_Latn", + "ory_Orya": "ory_Orya", + "pag_Latn": "pag_Latn", + "pan_Guru": "pan_Guru", + "pap_Latn": "pap_Latn", + "pes_Arab": "pes_Arab", + "pol_Latn": "pol_Latn", + "por_Latn": "por_Latn", + "prs_Arab": "prs_Arab", + "pbt_Arab": "pbt_Arab", + "quy_Latn": "quy_Latn", + "ron_Latn": "ron_Latn", + "run_Latn": "run_Latn", + "rus_Cyrl": "rus_Cyrl", + "sag_Latn": "sag_Latn", + "san_Deva": "san_Deva", + "scn_Latn": "scn_Latn", + "shn_Mymr": "shn_Mymr", + "sin_Sinh": "sin_Sinh", + "slk_Latn": "slk_Latn", + "slv_Latn": "slv_Latn", + "smo_Latn": "smo_Latn", + "sna_Latn": "sna_Latn", + "snd_Arab": "snd_Arab", + "som_Latn": "som_Latn", + "sot_Latn": "sot_Latn", + "spa_Latn": "spa_Latn", + "als_Latn": "als_Latn", + "srd_Latn": "srd_Latn", + "srp_Cyrl": "srp_Cyrl", + "ssw_Latn": "ssw_Latn", + "sun_Latn": "sun_Latn", + "swe_Latn": "swe_Latn", + "swh_Latn": "swh_Latn", + "szl_Latn": "szl_Latn", + "tam_Taml": "tam_Taml", + "tat_Cyrl": "tat_Cyrl", + "tel_Telu": "tel_Telu", + "tgk_Cyrl": "tgk_Cyrl", + "tgl_Latn": "tgl_Latn", + "tha_Thai": "tha_Thai", + "tir_Ethi": "tir_Ethi", + 
"taq_Latn": "taq_Latn", + "taq_Tfng": "taq_Tfng", + "tpi_Latn": "tpi_Latn", + "tsn_Latn": "tsn_Latn", + "tso_Latn": "tso_Latn", + "tuk_Latn": "tuk_Latn", + "tum_Latn": "tum_Latn", + "tur_Latn": "tur_Latn", + "twi_Latn": "twi_Latn", + "tzm_Tfng": "tzm_Tfng", + "uig_Arab": "uig_Arab", + "ukr_Cyrl": "ukr_Cyrl", + "umb_Latn": "umb_Latn", + "urd_Arab": "urd_Arab", + "uzn_Latn": "uzn_Latn", + "vec_Latn": "vec_Latn", + "vie_Latn": "vie_Latn", + "war_Latn": "war_Latn", + "wol_Latn": "wol_Latn", + "xho_Latn": "xho_Latn", + "ydd_Hebr": "ydd_Hebr", + "yor_Latn": "yor_Latn", + "yue_Hant": "yue_Hant", + "zho_Hans": "zho_Hans", + "zho_Hant": "zho_Hant", + "zsm_Latn": "zsm_Latn", + "zul_Latn": "zul_Latn", +}