-
Notifications
You must be signed in to change notification settings - Fork 0
/
create_test_set_selection.py
57 lines (41 loc) · 1.99 KB
/
create_test_set_selection.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import argparse
import json
import os
import numpy as np
import config
from dataloaders.datasets import get_dataloaders
def create_test_set_selection(
    dataset_name: str,
    output_dir: str,
    seed: int = 1994,
    data_root_dir: str = "Data/",
) -> None:
    """
    Create and save a seeded random ordering of one dataset's test-set indices.

    The shuffled indices are written as `{"indices": [...]}` to
    `<output_dir>/metadata/<dataset_name>_test_indices.json`. Both directories
    are created if they do not exist.

    Args:
        dataset_name (str): Name of the dataset, must be one of `config.dataset_names`.
        output_dir (str): Directory where experiment outputs and metadata will be stored.
        seed (int): Seed for random number generation. Default is 1994.
        data_root_dir (str): Root directory passed to `get_dataloaders`.
            Default is "Data/" (the previously hard-coded value).

    Raises:
        AssertionError: If `dataset_name` is not in `config.dataset_names`.
    """
    assert dataset_name in config.dataset_names, (
        f"Unknown dataset {dataset_name!r}; expected one of {list(config.dataset_names)}"
    )

    # Seed NumPy's global generator. This happens before the dataloaders are
    # built and the call order is kept as-is so previously generated index
    # files stay reproducible.
    # NOTE(review): whether get_dataloaders itself draws from np.random is not
    # visible here -- confirm before reordering these calls.
    np.random.seed(seed)

    # Create the output directory and its "metadata" subdirectory.
    os.makedirs(output_dir, exist_ok=True)
    metadata_dir = os.path.join(output_dir, "metadata")
    os.makedirs(metadata_dir, exist_ok=True)

    # Only the third loader returned (the test loader) is needed.
    _, _, test_loader = get_dataloaders(
        data_root_dir=data_root_dir, dataset_name=dataset_name, batch_size=1
    )

    # Total number of samples in the test set.
    total_samples = len(test_loader.dataset)
    print(f"* {dataset_name=}, {total_samples=}")

    # Convert to a plain list of Python ints (JSON-serializable) and shuffle
    # it in place with the seeded global generator.
    indices = np.arange(total_samples).tolist()
    np.random.shuffle(indices)

    # Save indices to a JSON file; the encoding is explicit so the output does
    # not depend on the platform's locale.
    json_filename = os.path.join(metadata_dir, f"{dataset_name}_test_indices.json")
    with open(json_filename, "w", encoding="utf-8") as json_file:
        json.dump({"indices": indices}, json_file)
if __name__ == "__main__":
    # Command-line entry point: parse the arguments and write the shuffled
    # test-set indices for the requested dataset.
    parser = argparse.ArgumentParser(description="Create test set selections for datasets.")
    # Both positional inputs of create_test_set_selection are mandatory; marking
    # them required makes argparse report a clear usage error instead of passing
    # None through to the function.
    parser.add_argument(
        "--dataset_name",
        type=str,
        choices=config.dataset_names,
        required=True,
        help="Name of the dataset",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Directory for experiment outputs and metadata.",
    )
    parser.add_argument("--seed", type=int, default=1994, help="Seed for random number generation.")
    args = parser.parse_args()
    create_test_set_selection(args.dataset_name, args.output_dir, args.seed)