forked from shift-happens-benchmark/icml-2022
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathimagenet_r.py
81 lines (65 loc) · 2.67 KB
/
imagenet_r.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
"""Example for a Shift Happens task: ImageNet-R"""
import dataclasses
import os
import numpy as np
import torchvision.datasets as tv_datasets
import torchvision.transforms as tv_transforms
import shifthappens.data.base as sh_data
import shifthappens.data.torch as sh_data_torch
import shifthappens.utils as sh_utils
from shifthappens import benchmark as sh_benchmark
from shifthappens.data.base import DataLoader
from shifthappens.models import base as sh_models
from shifthappens.models.base import PredictionTargets
from shifthappens.tasks.base import Task
from shifthappens.tasks.metrics import Metric
from shifthappens.tasks.task_result import TaskResult
@sh_benchmark.register_task(
name="ImageNet-R", relative_data_folder="imagenet_r", standalone=True
)
@dataclasses.dataclass
class ImageNetR(Task):
resources = [
(
"imagenet-r.tar",
"https://people.eecs.berkeley.edu/~hendrycks/imagenet-r.tar",
"A61312130A589D0CA1A8FCA1F2BD3337",
)
]
def setup(self):
dataset_folder = os.path.join(self.data_root, "imagenet-r")
if not os.path.exists(dataset_folder):
# download data
for file_name, url, md5 in self.resources:
sh_utils.download_and_extract_archive(
url, self.data_root, md5, file_name
)
test_transform = tv_transforms.Compose(
[
tv_transforms.ToTensor(),
tv_transforms.Lambda(lambda x: x.permute(1, 2, 0)),
]
)
self.ch_dataset = tv_datasets.ImageFolder(
root=dataset_folder, transform=test_transform
)
self.images_only_dataset = sh_data_torch.IndexedTorchDataset(
sh_data_torch.ImagesOnlyTorchDataset(self.ch_dataset)
)
def _prepare_dataloader(self) -> DataLoader:
return sh_data.DataLoader(self.images_only_dataset, max_batch_size=None)
def _evaluate(self, model: sh_models.Model) -> TaskResult:
dataloader = self._prepare_dataloader()
all_predicted_labels_list = []
for predictions in model.predict(
dataloader, PredictionTargets(class_labels=True)
):
all_predicted_labels_list.append(predictions.class_labels)
all_predicted_labels = np.concatenate(all_predicted_labels_list, 0)
accuracy = all_predicted_labels == np.array(self.ch_dataset.targets)
return TaskResult(
accuracy=accuracy, summary_metrics={Metric.Robustness: "accuracy"}
)
if __name__ == "__main__":
from shifthappens.models.torchvision import ResNet18
sh_benchmark.evaluate_model(ResNet18(device="cpu", max_batch_size=128), "data")