diff --git a/clip_benchmark/datasets/builder.py b/clip_benchmark/datasets/builder.py
index 0f6685e..0e659d9 100644
--- a/clip_benchmark/datasets/builder.py
+++ b/clip_benchmark/datasets/builder.py
@@ -97,15 +97,19 @@ def download_imagenet(r):
     train = (split == "train")
 
     if dataset_name == "cifar10":
+        assert split in ("train", "test"), f"Only `train` and `test` split available for {dataset_name}"
         ds = CIFAR10(root=root, train=train, transform=transform, download=download, **kwargs)
     elif dataset_name == "cifar100":
+        assert split in ("train", "test"), f"Only `train` and `test` split available for {dataset_name}"
         ds = CIFAR100(root=root, train=train, transform=transform, download=download, **kwargs)
     elif dataset_name == "imagenet1k":
+        assert split in ("train", "test"), f"Only `train` and `test` split available for {dataset_name}"
         if not os.path.exists(root):
             download_imagenet(root)
         ds = ImageNet(root=root, split="train" if train else "val", transform=transform, **kwargs)
         ds.classes = classnames["imagenet1k"]
     elif dataset_name == "imagenet-w":
+        assert split in ("train", "test"), f"Only `train` and `test` split available for {dataset_name}"
         from imagenet_w import AddWatermark
         from torchvision.transforms import Normalize, CenterCrop
         if not os.path.exists(root):
@@ -123,6 +127,7 @@ def download_imagenet(r):
         ds = ImageNet(root=root, split="train" if train else "val", transform=transform, **kwargs)
         ds.classes = classnames["imagenet1k"]
     elif dataset_name == "babel_imagenet":
+        assert split in ("train", "test"), f"Only `train` and `test` split available for {dataset_name}"
         # babel ImageNet from https://github.com/gregor-ge/Babel-ImageNet
         if not os.path.exists(root):
             download_imagenet(root)
@@ -130,17 +135,18 @@ def download_imagenet(r):
         ds = babel_imagenet.BabelImageNet(root=root, idxs=idxs, split="train" if train else "val", transform=transform, **kwargs)
         ds.classes = classnames
     elif dataset_name == "imagenet1k-unverified":
+        assert split in ("train", "test"), f"Only `train` and `test` split available for {dataset_name}"
         split = "train" if train else "val"
         ds = ImageFolder(root=os.path.join(root, split), transform=transform, **kwargs)
         # use classnames from OpenAI
         ds.classes = classnames["imagenet1k"]
     elif dataset_name == "imagenetv2":
-        assert split == "test", f"Only test split available for {dataset_name}"
+        assert split == "test", f"Only `test` split available for {dataset_name}"
         os.makedirs(root, exist_ok=True)
         ds = imagenetv2.ImageNetV2Dataset(variant="matched-frequency", transform=transform, location=root)
         ds.classes = classnames["imagenet1k"]
     elif dataset_name == "imagenet_sketch":
-        assert split == "test", f"Only test split available for {dataset_name}"
+        assert split == "test", f"Only `test` split available for {dataset_name}"
         # Downloadable from https://drive.google.com/open?id=1Mj0i5HBthqH1p_yeXzsg22gZduvgoNeA
         if not os.path.exists(root):
             # Automatic download
@@ -157,7 +163,7 @@ def download_imagenet(r):
         ds = ImageFolder(root=root, transform=transform, **kwargs)
         ds.classes = classnames["imagenet1k"]
     elif dataset_name == "imagenet-a":
-        assert split == "test", f"Only test split available for {dataset_name}"
+        assert split == "test", f"Only `test` split available for {dataset_name}"
         # Downloadable from https://people.eecs.berkeley.edu/~hendrycks/imagenet-a.tar
         if not os.path.exists(root):
             print("Downloading imagenet-a...")
@@ -171,7 +177,7 @@ def download_imagenet(r):
         imagenet_a_mask = [wnid in set(imagenet_a_wnids) for wnid in all_imagenet_wordnet_ids]
         ds.classes = [cl for cl, mask in zip(ds.classes, imagenet_a_mask) if mask]
     elif dataset_name == "imagenet-r":
-        assert split == "test", f"Only test split available for {dataset_name}"
+        assert split == "test", f"Only `test` split available for {dataset_name}"
         # downloadable from https://people.eecs.berkeley.edu/~hendrycks/imagenet-r.tar
         if not os.path.exists(root):
             print("Downloading imagenet-r...")
@@ -185,7 +191,7 @@ def download_imagenet(r):
         ds.classes = classnames["imagenet1k"]
         ds.classes = [cl for cl, mask in zip(ds.classes, imagenet_r_mask) if mask]
     elif dataset_name == "imagenet-o":
-        assert split == "test", f"Only test split available for {dataset_name}"
+        assert split == "test", f"Only `test` split available for {dataset_name}"
         # downloadable from https://people.eecs.berkeley.edu/~hendrycks/imagenet-o.tar
         if not os.path.exists(root):
             print("Downloading imagenet-o...")
@@ -199,7 +205,7 @@ def download_imagenet(r):
         imagenet_o_mask = [wnid in set(imagenet_o_wnids) for wnid in all_imagenet_wordnet_ids]
         ds.classes = [cl for cl, mask in zip(ds.classes, imagenet_o_mask) if mask]
     elif dataset_name == "objectnet":
-        assert split == "test", f"Only test split available for {dataset_name}"
+        assert split == "test", f"Only `test` split available for {dataset_name}"
         # downloadable from https://objectnet.dev/downloads/objectnet-1.0.zip or https://www.dropbox.com/s/raw/cxeztdtm16nzvuw/objectnet-1.0.zip
         if not os.path.exists(root):
             print("Downloading objectnet...")
@@ -211,14 +217,16 @@ def download_imagenet(r):
         call(f"cp {root}/objectnet-1.0/mappings/* {root}", shell=True)
         ds = objectnet.ObjectNetDataset(root=root, transform=transform)
     elif dataset_name == "voc2007":
-        ds = voc2007.PASCALVoc2007Cropped(root=root, set="train" if train else "test", transform=transform, download=download, **kwargs)
+        assert split in ("train", "test"), f"Only `train` and `test` split available for {dataset_name}"
+        ds = voc2007.PASCALVoc2007Cropped(root=root, set=split, transform=transform, download=download, **kwargs)
     elif dataset_name == "voc2007_multilabel":
-        ds = voc2007.PASCALVoc2007(root=root, set="train" if train else "test", transform=transform, download=download, **kwargs)
+        assert split in ("train", "test"), f"Only `train` and `test` split available for {dataset_name}"
+        ds = voc2007.PASCALVoc2007(root=root, set=split, transform=transform, download=download, **kwargs)
     elif dataset_name.startswith("sugar_crepe"):
         # https://github.com/RAIVNLab/sugar-crepe/tree/main
         _, task = dataset_name.split("/")
         assert task in ("add_att", "add_obj", "replace_att", "replace_obj", "replace_rel", "swap_att", "swap_obj"), f"Unknown task {task} for {dataset_name}"
-        assert split == "test", f"Only test split available for {dataset_name}"
+        assert split == "test", f"Only `test` split available for {dataset_name}"
         archive_name = "val2017.zip"
         root_split = os.path.join(root, archive_name.replace(".zip", ""))
         if not os.path.exists(root_split):
@@ -238,7 +246,7 @@ def download_imagenet(r):
         elif split in ("val", "test"):
             archive_name = "val2014.zip"
         else:
-            raise ValueError(f"split should be train or val or test for `{dataset_name}`")
+            raise ValueError(f"split should be `train` or `val` or `test` for `{dataset_name}`")
         root_split = os.path.join(root, archive_name.replace(".zip", ""))
         if not os.path.exists(root_split):
             print(f"Downloading mscoco_captions {archive_name}...")
@@ -261,7 +269,7 @@ def get_archive_name(target_split):
             elif target_split in ("val", "test"):
                 return "val2014.zip"
             else:
-                raise ValueError(f"split should be train or val or test for `{dataset_name}`")
+                raise ValueError(f"split should be `train` or `val` or `test` for `{dataset_name}`")
 
         def download_mscoco_split(target_split):
             archive_name = get_archive_name(target_split)
@@ -285,6 +293,7 @@ def download_mscoco_split(target_split):
         # downloadable from https://www.kaggle.com/datasets/adityajn105/flickr30k
         # https://github.com/mehdidc/retrieval_annotations/releases/tag/1.0.0(annotations)
         # `kaggle datasets download -d adityajn105/flickr30k`
+        assert split in ("train", "val", "test"), f"Only `train` and `val` and `test` split available for {dataset_name}"
         if not os.path.exists(root):
             # Automatic download
             print("Downloading flickr30k...")
@@ -311,6 +320,7 @@ def download_mscoco_split(target_split):
             raise ValueError(f"Unsupported language {language} for `{dataset_name}`")
         ds = flickr.Flickr(root=root, ann_file=annotation_file, transform=transform, **kwargs)
     elif dataset_name == "flickr8k":
+        assert split in ("train", "val", "test"), f"Only `train` and `val` and `test` split available for {dataset_name}"
         # downloadable from https://www.kaggle.com/datasets/adityajn105/flickr8k
         # `kaggle datasets download -d adityajn105/flickr8k`
         # https://github.com/mehdidc/retrieval_annotations/releases/tag/1.0.0(annotations)
@@ -341,7 +351,8 @@ def download_mscoco_split(target_split):
             raise ValueError(f"Unsupported language {language} for `{dataset_name}`")
         ds = flickr.Flickr(root=root, ann_file=annotation_file, transform=transform, **kwargs)
     elif dataset_name == "food101":
-        ds = Food101(root=root, split="train" if train else "test", transform=transform, download=download, **kwargs)
+        assert split in ("train", "test"), f"Only `train` and `test` split available for {dataset_name}"
+        ds = Food101(root=root, split=split, transform=transform, download=download, **kwargs)
         # we use the default class names, we just replace "_" by spaces
         # to delimit words
         ds.classes = [cl.replace("_", " ") for cl in ds.classes]
@@ -352,13 +363,17 @@ def download_mscoco_split(target_split):
         ds = SUN397(root=root, transform=transform, download=download, **kwargs)
         ds.classes = [cl.replace("_", " ").replace("/", " ") for cl in ds.classes]
     elif dataset_name == "cars":
-        ds = StanfordCars(root=root, split="train" if train else "test", transform=transform, download=download, **kwargs)
+        assert split in ("train", "test"), f"Only `train` and `test` split available for {dataset_name}"
+        ds = StanfordCars(root=root, split=split, transform=transform, download=download, **kwargs)
     elif dataset_name == "fgvc_aircraft":
-        ds = FGVCAircraft(root=root, annotation_level="variant", split="train" if train else "test", transform=transform, download=download, **kwargs)
+        assert split in ("train", "val", "trainval", "test"), f"Only `train` and `val` and `trainval` and `test` split available for {dataset_name}"
+        ds = FGVCAircraft(root=root, annotation_level="variant", split=split, transform=transform, download=download, **kwargs)
     elif dataset_name == "dtd":
-        ds = DTD(root=root, split="train" if train else "test", transform=transform, download=download, **kwargs)
+        assert split in ("train", "val", "test"), f"Only `train` and `val` and `test` split available for {dataset_name}"
+        ds = DTD(root=root, split=split, transform=transform, download=download, **kwargs)
     elif dataset_name == "pets":
-        ds = OxfordIIITPet(root=root, split="train" if train else "test", target_types="category", transform=transform, download=download, **kwargs)
+        assert split in ("trainval", "test"), f"Only `trainval` and `test` split available for {dataset_name}"
+        ds = OxfordIIITPet(root=root, split=split, target_types="category", transform=transform, download=download, **kwargs)
     elif dataset_name == "caltech101":
         warnings.warn(f"split argument ignored for `{dataset_name}`, there are no pre-defined train/test splits for this dataset")
         # broken download link (can't download google drive), fixed by this PR https://github.com/pytorch/vision/pull/5645
@@ -367,7 +382,8 @@ def download_mscoco_split(target_split):
         ds = caltech101.Caltech101(root=root, target_type="category", transform=transform, download=download, **kwargs)
         ds.classes = classnames["caltech101"]
     elif dataset_name == "flowers":
-        ds = Flowers102(root=root, split="train" if train else "test", transform=transform, download=download, **kwargs)
+        assert split in ("train", "val", "test"), f"Only `train` and `val` and `test` split available for {dataset_name}"
+        ds = Flowers102(root=root, split=split, transform=transform, download=download, **kwargs)
         # class indices started by 1 until it was fixed in a PR (#TODO link of the PR)
         # if older torchvision version, fix it using a target transform that decrements label index
         # TODO figure out minimal torchvision version needed instead of decrementing
@@ -375,28 +391,35 @@ def download_mscoco_split(target_split):
             ds.target_transform = lambda y:y-1
         ds.classes = classnames["flowers"]
     elif dataset_name == "mnist":
+        assert split in ("train", "test"), f"Only `train` and `test` split available for {dataset_name}"
         ds = MNIST(root=root, train=train, transform=transform, download=download, **kwargs)
         ds.classes = classnames["mnist"]
     elif dataset_name == "stl10":
-        ds = STL10(root=root, split="train" if train else "test", transform=transform, download=download, **kwargs)
+        assert split in ("train", "test"), f"Only `train` and `test` split available for {dataset_name}"
+        ds = STL10(root=root, split=split, transform=transform, download=download, **kwargs)
     elif dataset_name == "eurosat":
         warnings.warn(f"split argument ignored for `{dataset_name}`, there are no pre-defined train/test splits for this dataset")
         ds = EuroSAT(root=root, transform=transform, download=download, **kwargs)
         ds.classes = classnames["eurosat"]
     elif dataset_name == "gtsrb":
-        ds = GTSRB(root=root, split="train" if train else "test", transform=transform, download=download, **kwargs)
+        assert split in ("train", "test"), f"Only `train` and `test` split available for {dataset_name}"
+        ds = GTSRB(root=root, split=split, transform=transform, download=download, **kwargs)
         ds.classes = classnames["gtsrb"]
     elif dataset_name == "country211":
-        ds = Country211(root=root, split="train" if train else "test", transform=transform, download=download, **kwargs)
+        assert split in ("train", "valid", "test"), f"Only `train` and `valid` and `test` split available for {dataset_name}"
+        ds = Country211(root=root, split=split, transform=transform, download=download, **kwargs)
         ds.classes = classnames["country211"]
     elif dataset_name == "pcam":
+        assert split in ("train", "val", "test"), f"Only `train` and `val` and `test` split available for {dataset_name}"
         # Dead link. Fixed by this PR on torchvision https://github.com/pytorch/vision/pull/5645
         # TODO figure out minimal torchvision version needed
-        ds = PCAM(root=root, split="train" if train else "test", transform=transform, download=download, **kwargs)
+        ds = PCAM(root=root, split=split, transform=transform, download=download, **kwargs)
         ds.classes = classnames["pcam"]
     elif dataset_name == "renderedsst2":
-        ds = RenderedSST2(root=root, split="train" if train else "test", transform=transform, download=download, **kwargs)
+        assert split in ("train", "val", "test"), f"Only `train` and `val` and `test` split available for {dataset_name}"
+        ds = RenderedSST2(root=root, split=split, transform=transform, download=download, **kwargs)
     elif dataset_name == "fer2013":
+        assert split in ("train", "test"), f"Only `train` and `test` split available for {dataset_name}"
         # Downloadable from https://www.kaggle.com/datasets/msambare/fer2013
         # `kaggle datasets download -d msambare/fer2013`
         if not os.path.exists(root):
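For context, a minimal sketch of how the new guards surface to callers. It assumes the enclosing `build_dataset` function in `clip_benchmark.datasets.builder` keeps its `dataset_name`, `root`, `split`, and `download` keyword arguments (the variables used throughout the diff); the `data/cifar10` root path is made up for illustration.

```python
from clip_benchmark.datasets.builder import build_dataset

# Supported split: the new assert passes and CIFAR-10 is built exactly as before.
ds = build_dataset(dataset_name="cifar10", root="data/cifar10", split="train", download=True)

# Unsupported split: previously `train = (split == "train")` silently mapped
# "val" to the test set; with this patch the call now fails fast instead.
try:
    build_dataset(dataset_name="cifar10", root="data/cifar10", split="val", download=True)
except AssertionError as err:
    print(err)  # Only `train` and `test` split available for cifar10
```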