fix incorrect and missing splits #106

Merged
merged 1 commit on Aug 16, 2023
67 changes: 45 additions & 22 deletions clip_benchmark/datasets/builder.py
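The diff below applies one pattern across the dataset branches in builder.py: assert that the requested `split` is one the dataset actually provides, then pass `split` straight to the constructor instead of deriving it from the boolean `train` flag, which silently mapped anything other than "train" to the test set. A minimal sketch of that pattern, using food101 as an example; the helper name is illustrative and the merged logic lives in the `elif dataset_name == ...` dispatch, not in a standalone function:

from torchvision.datasets import Food101

def build_food101_sketch(root, split, transform=None, download=True):
    # Illustrative helper only; the actual change is inside the dataset dispatch
    # of clip_benchmark/datasets/builder.py rather than a function like this.
    assert split in ("train", "test"), f"Only `train` and `test` split available for food101, got `{split}`"
    # Pass the validated split through unchanged instead of collapsing it to
    # `"train" if train else "test"`.
    ds = Food101(root=root, split=split, transform=transform, download=download)
    # Same classname cleanup as the builder: "_" becomes a space.
    ds.classes = [cl.replace("_", " ") for cl in ds.classes]
    return ds

With this shape, a request like split="val" for a dataset that has no validation split raises immediately with a readable message instead of quietly evaluating on the wrong split.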
@@ -97,15 +97,19 @@ def download_imagenet(r):

train = (split == "train")
if dataset_name == "cifar10":
assert split in ("train", "test"), f"Only `train` and `test` split available for {dataset_name}"
ds = CIFAR10(root=root, train=train, transform=transform, download=download, **kwargs)
elif dataset_name == "cifar100":
assert split in ("train", "test"), f"Only `train` and `test` split available for {dataset_name}"
ds = CIFAR100(root=root, train=train, transform=transform, download=download, **kwargs)
elif dataset_name == "imagenet1k":
assert split in ("train", "test"), f"Only `train` and `test` split available for {dataset_name}"
if not os.path.exists(root):
download_imagenet(root)
ds = ImageNet(root=root, split="train" if train else "val", transform=transform, **kwargs)
ds.classes = classnames["imagenet1k"]
elif dataset_name == "imagenet-w":
assert split in ("train", "test"), f"Only `train` and `test` split available for {dataset_name}"
from imagenet_w import AddWatermark
from torchvision.transforms import Normalize, CenterCrop
if not os.path.exists(root):
@@ -123,24 +127,26 @@ def download_imagenet(r):
ds = ImageNet(root=root, split="train" if train else "val", transform=transform, **kwargs)
ds.classes = classnames["imagenet1k"]
elif dataset_name == "babel_imagenet":
assert split in ("train", "test"), f"Only `train` and `test` split available for {dataset_name}"
# babel ImageNet from https://github.com/gregor-ge/Babel-ImageNet
if not os.path.exists(root):
download_imagenet(root)
idxs, classnames = classnames
ds = babel_imagenet.BabelImageNet(root=root, idxs=idxs, split="train" if train else "val", transform=transform, **kwargs)
ds.classes = classnames
elif dataset_name == "imagenet1k-unverified":
assert split in ("train", "test"), f"Only `train` and `test` split available for {dataset_name}"
split = "train" if train else "val"
ds = ImageFolder(root=os.path.join(root, split), transform=transform, **kwargs)
# use classnames from OpenAI
ds.classes = classnames["imagenet1k"]
elif dataset_name == "imagenetv2":
assert split == "test", f"Only test split available for {dataset_name}"
assert split == "test", f"Only `test` split available for {dataset_name}"
os.makedirs(root, exist_ok=True)
ds = imagenetv2.ImageNetV2Dataset(variant="matched-frequency", transform=transform, location=root)
ds.classes = classnames["imagenet1k"]
elif dataset_name == "imagenet_sketch":
assert split == "test", f"Only test split available for {dataset_name}"
assert split == "test", f"Only `test` split available for {dataset_name}"
# Downloadable from https://drive.google.com/open?id=1Mj0i5HBthqH1p_yeXzsg22gZduvgoNeA
if not os.path.exists(root):
# Automatic download
@@ -157,7 +163,7 @@ def download_imagenet(r):
ds = ImageFolder(root=root, transform=transform, **kwargs)
ds.classes = classnames["imagenet1k"]
elif dataset_name == "imagenet-a":
assert split == "test", f"Only test split available for {dataset_name}"
assert split == "test", f"Only `test` split available for {dataset_name}"
# Downloadable from https://people.eecs.berkeley.edu/~hendrycks/imagenet-a.tar
if not os.path.exists(root):
print("Downloading imagenet-a...")
@@ -171,7 +177,7 @@ def download_imagenet(r):
imagenet_a_mask = [wnid in set(imagenet_a_wnids) for wnid in all_imagenet_wordnet_ids]
ds.classes = [cl for cl, mask in zip(ds.classes, imagenet_a_mask) if mask]
elif dataset_name == "imagenet-r":
assert split == "test", f"Only test split available for {dataset_name}"
assert split == "test", f"Only `test` split available for {dataset_name}"
# downloadable from https://people.eecs.berkeley.edu/~hendrycks/imagenet-r.tar
if not os.path.exists(root):
print("Downloading imagenet-r...")
@@ -185,7 +191,7 @@ def download_imagenet(r):
ds.classes = classnames["imagenet1k"]
ds.classes = [cl for cl, mask in zip(ds.classes, imagenet_r_mask) if mask]
elif dataset_name == "imagenet-o":
assert split == "test", f"Only test split available for {dataset_name}"
assert split == "test", f"Only `test` split available for {dataset_name}"
# downloadable from https://people.eecs.berkeley.edu/~hendrycks/imagenet-o.tar
if not os.path.exists(root):
print("Downloading imagenet-o...")
@@ -199,7 +205,7 @@ def download_imagenet(r):
imagenet_o_mask = [wnid in set(imagenet_o_wnids) for wnid in all_imagenet_wordnet_ids]
ds.classes = [cl for cl, mask in zip(ds.classes, imagenet_o_mask) if mask]
elif dataset_name == "objectnet":
assert split == "test", f"Only test split available for {dataset_name}"
assert split == "test", f"Only `test` split available for {dataset_name}"
# downloadable from https://objectnet.dev/downloads/objectnet-1.0.zip or https://www.dropbox.com/s/raw/cxeztdtm16nzvuw/objectnet-1.0.zip
if not os.path.exists(root):
print("Downloading objectnet...")
@@ -211,14 +217,16 @@ def download_imagenet(r):
call(f"cp {root}/objectnet-1.0/mappings/* {root}", shell=True)
ds = objectnet.ObjectNetDataset(root=root, transform=transform)
elif dataset_name == "voc2007":
ds = voc2007.PASCALVoc2007Cropped(root=root, set="train" if train else "test", transform=transform, download=download, **kwargs)
assert split in ("train", "test"), f"Only `train` and `test` split available for {dataset_name}"
ds = voc2007.PASCALVoc2007Cropped(root=root, set=split, transform=transform, download=download, **kwargs)
elif dataset_name == "voc2007_multilabel":
ds = voc2007.PASCALVoc2007(root=root, set="train" if train else "test", transform=transform, download=download, **kwargs)
assert split in ("train", "test"), f"Only `train` and `test` split available for {dataset_name}"
ds = voc2007.PASCALVoc2007(root=root, set=split, transform=transform, download=download, **kwargs)
elif dataset_name.startswith("sugar_crepe"):
# https://github.com/RAIVNLab/sugar-crepe/tree/main
_, task = dataset_name.split("/")
assert task in ("add_att", "add_obj", "replace_att", "replace_obj", "replace_rel", "swap_att", "swap_obj"), f"Unknown task {task} for {dataset_name}"
assert split == "test", f"Only test split available for {dataset_name}"
assert split == "test", f"Only `test` split available for {dataset_name}"
archive_name = "val2017.zip"
root_split = os.path.join(root, archive_name.replace(".zip", ""))
if not os.path.exists(root_split):
@@ -238,7 +246,7 @@ def download_imagenet(r):
elif split in ("val", "test"):
archive_name = "val2014.zip"
else:
raise ValueError(f"split should be train or val or test for `{dataset_name}`")
raise ValueError(f"split should be `train` or `val` or `test` for `{dataset_name}`")
root_split = os.path.join(root, archive_name.replace(".zip", ""))
if not os.path.exists(root_split):
print(f"Downloading mscoco_captions {archive_name}...")
@@ -261,7 +269,7 @@ def get_archive_name(target_split):
elif target_split in ("val", "test"):
return "val2014.zip"
else:
raise ValueError(f"split should be train or val or test for `{dataset_name}`")
raise ValueError(f"split should be `train` or `val` or `test` for `{dataset_name}`")

def download_mscoco_split(target_split):
archive_name = get_archive_name(target_split)
@@ -285,6 +293,7 @@ def download_mscoco_split(target_split):
# downloadable from https://www.kaggle.com/datasets/adityajn105/flickr30k
# https://github.com/mehdidc/retrieval_annotations/releases/tag/1.0.0(annotations)
# `kaggle datasets download -d adityajn105/flickr30k`
assert split in ("train", "val", "test"), f"Only `train` and `val` and `test` split available for {dataset_name}"
if not os.path.exists(root):
# Automatic download
print("Downloading flickr30k...")
@@ -311,6 +320,7 @@ def download_mscoco_split(target_split):
raise ValueError(f"Unsupported language {language} for `{dataset_name}`")
ds = flickr.Flickr(root=root, ann_file=annotation_file, transform=transform, **kwargs)
elif dataset_name == "flickr8k":
assert split in ("train", "val", "test"), f"Only `train` and `val` and `test` split available for {dataset_name}"
# downloadable from https://www.kaggle.com/datasets/adityajn105/flickr8k
# `kaggle datasets download -d adityajn105/flickr8k`
# https://github.com/mehdidc/retrieval_annotations/releases/tag/1.0.0(annotations)
@@ -341,7 +351,8 @@ def download_mscoco_split(target_split):
raise ValueError(f"Unsupported language {language} for `{dataset_name}`")
ds = flickr.Flickr(root=root, ann_file=annotation_file, transform=transform, **kwargs)
elif dataset_name == "food101":
ds = Food101(root=root, split="train" if train else "test", transform=transform, download=download, **kwargs)
assert split in ("train", "test"), f"Only `train` and `test` split available for {dataset_name}"
ds = Food101(root=root, split=split, transform=transform, download=download, **kwargs)
# we use the default class names, we just replace "_" by spaces
# to delimit words
ds.classes = [cl.replace("_", " ") for cl in ds.classes]
@@ -352,13 +363,17 @@ def download_mscoco_split(target_split):
ds = SUN397(root=root, transform=transform, download=download, **kwargs)
ds.classes = [cl.replace("_", " ").replace("/", " ") for cl in ds.classes]
elif dataset_name == "cars":
ds = StanfordCars(root=root, split="train" if train else "test", transform=transform, download=download, **kwargs)
assert split in ("train", "test"), f"Only `train` and `test` split available for {dataset_name}"
ds = StanfordCars(root=root, split=split, transform=transform, download=download, **kwargs)
elif dataset_name == "fgvc_aircraft":
ds = FGVCAircraft(root=root, annotation_level="variant", split="train" if train else "test", transform=transform, download=download, **kwargs)
assert split in ("train", "val", "trainval", "test"), f"Only `train` and `val` and `trainval` and `test` split available for {dataset_name}"
ds = FGVCAircraft(root=root, annotation_level="variant", split=split, transform=transform, download=download, **kwargs)
elif dataset_name == "dtd":
ds = DTD(root=root, split="train" if train else "test", transform=transform, download=download, **kwargs)
assert split in ("train", "val", "test"), f"Only `train` and `val` and `test` split available for {dataset_name}"
ds = DTD(root=root, split=split, transform=transform, download=download, **kwargs)
elif dataset_name == "pets":
ds = OxfordIIITPet(root=root, split="train" if train else "test", target_types="category", transform=transform, download=download, **kwargs)
assert split in ("trainval", "test"), f"Only `trainval` and `test` split available for {dataset_name}"
ds = OxfordIIITPet(root=root, split=split, target_types="category", transform=transform, download=download, **kwargs)
elif dataset_name == "caltech101":
warnings.warn(f"split argument ignored for `{dataset_name}`, there are no pre-defined train/test splits for this dataset")
# broken download link (can't download google drive), fixed by this PR https://github.com/pytorch/vision/pull/5645
@@ -367,36 +382,44 @@ def download_mscoco_split(target_split):
ds = caltech101.Caltech101(root=root, target_type="category", transform=transform, download=download, **kwargs)
ds.classes = classnames["caltech101"]
elif dataset_name == "flowers":
ds = Flowers102(root=root, split="train" if train else "test", transform=transform, download=download, **kwargs)
assert split in ("train", "val", "test"), f"Only `train` and `val` and `test` split available for {dataset_name}"
ds = Flowers102(root=root, split=split, transform=transform, download=download, **kwargs)
# class indices started by 1 until it was fixed in a PR (#TODO link of the PR)
# if older torchvision version, fix it using a target transform that decrements label index
# TODO figure out minimal torchvision version needed instead of decrementing
if ds[0][1] == 1:
ds.target_transform = lambda y:y-1
ds.classes = classnames["flowers"]
elif dataset_name == "mnist":
assert split in ("train", "test"), f"Only `train` and `test` split available for {dataset_name}"
ds = MNIST(root=root, train=train, transform=transform, download=download, **kwargs)
ds.classes = classnames["mnist"]
elif dataset_name == "stl10":
ds = STL10(root=root, split="train" if train else "test", transform=transform, download=download, **kwargs)
assert split in ("train", "test"), f"Only `train` and `test` split available for {dataset_name}"
ds = STL10(root=root, split=split, transform=transform, download=download, **kwargs)
elif dataset_name == "eurosat":
warnings.warn(f"split argument ignored for `{dataset_name}`, there are no pre-defined train/test splits for this dataset")
ds = EuroSAT(root=root, transform=transform, download=download, **kwargs)
ds.classes = classnames["eurosat"]
elif dataset_name == "gtsrb":
ds = GTSRB(root=root, split="train" if train else "test", transform=transform, download=download, **kwargs)
assert split in ("train", "test"), f"Only `train` and `test` split available for {dataset_name}"
ds = GTSRB(root=root, split=split, transform=transform, download=download, **kwargs)
ds.classes = classnames["gtsrb"]
elif dataset_name == "country211":
ds = Country211(root=root, split="train" if train else "test", transform=transform, download=download, **kwargs)
assert split in ("train", "valid", "test"), f"Only `train` and `valid` and `test` split available for {dataset_name}"
ds = Country211(root=root, split=split, transform=transform, download=download, **kwargs)
ds.classes = classnames["country211"]
elif dataset_name == "pcam":
assert split in ("train", "val", "test"), f"Only `train` and `val` and `test` split available for {dataset_name}"
# Dead link. Fixed by this PR on torchvision https://github.com/pytorch/vision/pull/5645
# TODO figure out minimal torchvision version needed
ds = PCAM(root=root, split="train" if train else "test", transform=transform, download=download, **kwargs)
ds = PCAM(root=root, split=split, transform=transform, download=download, **kwargs)
ds.classes = classnames["pcam"]
elif dataset_name == "renderedsst2":
ds = RenderedSST2(root=root, split="train" if train else "test", transform=transform, download=download, **kwargs)
assert split in ("train", "val", "test"), f"Only `train` and `val` and `test` split available for {dataset_name}"
ds = RenderedSST2(root=root, split=split, transform=transform, download=download, **kwargs)
elif dataset_name == "fer2013":
assert split in ("train", "test"), f"Only `train` and `test` split available for {dataset_name}"
# Downloadable from https://www.kaggle.com/datasets/msambare/fer2013
# `kaggle datasets download -d msambare/fer2013`
if not os.path.exists(root):
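For callers, the practical effect of the stricter checks is twofold: splits a dataset really provides (for example `val` for dtd or flowers) are now passed through instead of being collapsed to `test`, and splits a dataset does not provide fail fast. A hedged usage sketch, assuming builder.py exposes a `build_dataset(...)` entry point with `dataset_name`, `root`, `split`, `transform` and `download` keyword arguments; verify the exact signature in the module before relying on it:

from clip_benchmark.datasets.builder import build_dataset  # entry point assumed, not shown in this diff

# `val` now reaches DTD unchanged, so this really is the validation split.
dtd_val = build_dataset(dataset_name="dtd", root="data/dtd", split="val", transform=None, download=True)

# food101 has no `val` split, so the new assert raises instead of quietly
# returning the test set as the pre-PR `train` boolean did.
try:
    build_dataset(dataset_name="food101", root="data/food101", split="val", transform=None, download=True)
except AssertionError as err:
    print(err)  # Only `train` and `test` split available for food101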