
Commit

Improved linear evaluation that achieves better results (#107)
teasgen authored Dec 1, 2023
1 parent 0c11b17 commit 0652bec
Showing 3 changed files with 256 additions and 97 deletions.
66 changes: 62 additions & 4 deletions clip_benchmark/cli.py
@@ -21,6 +21,11 @@ def get_parser_args():
parser_eval.add_argument('--dataset', type=str, default="cifar10", nargs="+", help="Dataset(s) to use for the benchmark. Can be the name of a dataset, or a collection name ('vtab', 'vtab+', 'imagenet_robustness', 'retrieval') or path of a text file where each line is a dataset name")
parser_eval.add_argument('--dataset_root', default="root", type=str, help="dataset root folder where the datasets are downloaded. Can be in the form of a template depending on dataset name, e.g., --dataset_root='datasets/{dataset}'. This is useful if you evaluate on multiple datasets.")
parser_eval.add_argument('--split', type=str, default="test", help="Dataset split to use")
parser_eval.add_argument('--test_split', dest="split", action='store', type=str, default="test", help="Dataset split to use")
parser_eval.add_argument('--train_split', type=str, nargs="+", default="train", help="Dataset(s) train split names")
mutually_exclusive = parser_eval.add_mutually_exclusive_group()
mutually_exclusive.add_argument('--val_split', default=None, type=str, nargs="+", help="Dataset(s) validation split names. Mutually exclusive with val_proportion.")
mutually_exclusive.add_argument('--val_proportion', default=None, type=float, nargs="+", help="Share of the train dataset to reserve for validation when no predefined validation split exists. Mutually exclusive with val_split.")
parser_eval.add_argument('--model', type=str, nargs="+", default=["ViT-B-32-quickgelu"], help="Model architecture to use from OpenCLIP")
parser_eval.add_argument('--pretrained', type=str, nargs="+", default=["laion400m_e32"], help="Model checkpoint name to use from OpenCLIP")
parser_eval.add_argument('--pretrained_model', type=str, default="", nargs="+", help="Pre-trained model(s) to use. Can be the full model name where `model` and `pretrained` are comma separated (e.g., --pretrained_model='ViT-B-32-quickgelu,laion400m_e32'), a model collection name ('openai' or 'openclip_base' or 'openclip_multilingual' or 'openclip_all'), or path of a text file where each line is a model fullname where model and pretrained are comma separated (e.g., ViT-B-32-quickgelu,laion400m_e32). --model and --pretrained are ignored if --pretrained_model is used.")
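The new split flags accept either a single value shared across all datasets or one value per dataset, and --val_split / --val_proportion are enforced as mutually exclusive by argparse. A minimal standalone sketch of that behavior (hypothetical invocation, parser trimmed to just the relevant options):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--train_split', type=str, nargs="+", default="train")
mutually_exclusive = parser.add_mutually_exclusive_group()
mutually_exclusive.add_argument('--val_split', default=None, type=str, nargs="+")
mutually_exclusive.add_argument('--val_proportion', default=None, type=float, nargs="+")

# One proportion, shared by every dataset:
args = parser.parse_args(['--val_proportion', '0.2'])
print(args.val_proportion)  # [0.2]

# Supplying both options is rejected by argparse:
# parser.parse_args(['--val_split', 'valid', '--val_proportion', '0.2'])
# error: argument --val_proportion: not allowed with argument --val_split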
@@ -35,6 +40,7 @@ def get_parser_args():
parser_eval.add_argument("--distributed", action="store_true", help="evaluation in parallel")
parser_eval.add_argument('--seed', default=0, type=int, help="random seed.")
parser_eval.add_argument('--batch_size', default=64, type=int)
parser_eval.add_argument('--normalize', default=True, type=bool, help="whether to apply feature normalization")
parser_eval.add_argument('--model_cache_dir', default=None, type=str, help="directory to where downloaded models are cached")
parser_eval.add_argument('--feature_root', default="features", type=str, help="feature root folder where the features are stored.")
parser_eval.add_argument('--annotation_file', default="", type=str, help="text annotation file for retrieval datasets. Only needed for when `--task` is `zeroshot_retrieval`.")
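One caveat with the new --normalize flag: argparse's type=bool applies Python truthiness to the raw string, so any non-empty value parses as True. A quick standalone check (not part of the commit):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--normalize', default=True, type=bool)

print(parser.parse_args([]).normalize)                        # True (default)
print(parser.parse_args(['--normalize', 'False']).normalize)  # True, because bool("False") is True
print(parser.parse_args(['--normalize', '']).normalize)       # False, only the empty string is falsy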
@@ -123,6 +129,24 @@ def main_eval(base):
# if not, assume it is simply the name of the dataset
datasets.append(name)

train_splits = _as_list(base.train_split)
train_splits = _single_option_to_multiple_datasets(train_splits, datasets, "train_split")
proportions, val_splits = None, None
if base.val_split is not None:
val_splits = _as_list(base.val_split)
val_splits = _single_option_to_multiple_datasets(val_splits, datasets, "val_split")
if base.val_proportion is not None:
proportions = _as_list(base.val_proportion)
proportions = _single_option_to_multiple_datasets(proportions, datasets, "val_proportion")

dataset_info = {}
for i in range(len(datasets)):
dataset_info[datasets[i]] = {
"train_split": train_splits[i],
"val_split": val_splits[i] if val_splits is not None else None,
"proportion": proportions[i] if proportions is not None else None
}

# Get list of languages to evaluate on
languages = _as_list(base.language)

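For illustration, a hypothetical run with --dataset cifar10 cifar100 --val_proportion 0.1 would broadcast the single proportion to both datasets, so the dataset_info mapping built above would come out as (a sketch; the dataset names and value are assumptions):

dataset_info = {
    "cifar10":  {"train_split": "train", "val_split": None, "proportion": 0.1},
    "cifar100": {"train_split": "train", "val_split": None, "proportion": 0.1},
}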
@@ -145,16 +169,30 @@ def run(args):
args.pretrained = pretrained
args.dataset = dataset
args.language = language
args.train_split = dataset_info[dataset]["train_split"]
args.val_split = dataset_info[dataset]["val_split"]
args.val_proportion = dataset_info[dataset]["proportion"]
run(args)

def _as_list(l):
if not l:
return []
return [l] if type(l) != list else l

def _single_option_to_multiple_datasets(cur_option, datasets, name):
cur_len = len(cur_option)
ds_len = len(datasets)
if cur_len != ds_len:
# If user wants to use same value for all datasets
if cur_len == 1:
return [cur_option[0]] * ds_len
else:
raise ValueError(f"The number of {name} values ({cur_len}) does not match the number of datasets ({ds_len})")
else:
return cur_option

def run(args):
"""Console script for clip_benchmark."""

if torch.cuda.is_available():
if args.distributed:
local_rank, rank, world_size = world_info_from_env()
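The helper _single_option_to_multiple_datasets defined above broadcasts a single option value to every dataset and otherwise insists on a one-to-one pairing (_as_list first wraps bare string defaults such as "train" into a list). A short illustration of both paths, with assumed dataset names:

datasets = ["cifar10", "cifar100", "imagenet1k"]

print(_single_option_to_multiple_datasets(["train"], datasets, "train_split"))
# ['train', 'train', 'train']

print(_single_option_to_multiple_datasets(["train", "train", "validation"], datasets, "train_split"))
# ['train', 'train', 'validation']

_single_option_to_multiple_datasets(["train", "test"], datasets, "train_split")
# ValueError: The number of train_split values (2) does not match the number of datasets (3)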
@@ -276,23 +314,41 @@ def run(args):
amp=args.amp,
)
elif task == "linear_probe":
# we also need the train and validation splits for linear probing.
train_dataset = build_dataset(
dataset_name=args.dataset,
root=dataset_root,
transform=transform,
split=args.train_split,
annotation_file=args.annotation_file,
download=True,
)
if args.val_split is not None:
val_dataset = build_dataset(
dataset_name=args.dataset,
root=dataset_root,
transform=transform,
split=args.val_split,
annotation_file=args.annotation_file,
download=True,
)
else:
train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [1 - args.val_proportion, args.val_proportion])

train_dataloader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.batch_size,
shuffle=False, num_workers=args.num_workers,
collate_fn=collate_fn, pin_memory=True,
)
val_dataloader = torch.utils.data.DataLoader(
val_dataset, batch_size=args.batch_size,
shuffle=False, num_workers=args.num_workers,
collate_fn=collate_fn, pin_memory=True,
)
metrics = linear_probe.evaluate(
model,
train_dataloader,
dataloader,
args.fewshot_k,
args.batch_size,
@@ -302,7 +358,9 @@ def run(args):
(args.model + '-' + args.pretrained + '-' + args.dataset).replace('/', '_'),
args.seed,
args.feature_root,
val_dataloader=val_dataloader,
device=args.device,
normalize=args.normalize,
amp=args.amp,
verbose=args.verbose,
)
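When --val_proportion is used, the fallback above relies on torch.utils.data.random_split accepting fractional lengths, which requires PyTorch 1.13 or newer. A self-contained sketch of that split-and-load flow, using a toy dataset in place of the real image dataset:

import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Toy stand-in for the training set: 100 samples, 10 classes.
full_train = TensorDataset(torch.randn(100, 8), torch.randint(0, 10, (100,)))

val_proportion = 0.1  # e.g., from --val_proportion 0.1
train_dataset, val_dataset = random_split(full_train, [1 - val_proportion, val_proportion])

train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=False, pin_memory=True)
val_dataloader = DataLoader(val_dataset, batch_size=64, shuffle=False, pin_memory=True)

print(len(train_dataset), len(val_dataset))  # 90 10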