import argparse
import os

import numpy as np
import torch
import yaml
from sklearn.metrics import log_loss
from torch import nn
from torch.utils import data

from loader import get_loader
from loss import get_loss_function
from models import get_model
from optimizers import get_optimizer
from utils import AverageMeter, accuracy, Confidence_Diagram, brier_score


def eval(cfg, logdir):
    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_classes = cfg['data']['num_classes']

    # Setup model, forwarding only the config entries that are actually set
    model_cfg = {}
    for item in cfg["model"]:
        if cfg["model"][item] is not None:
            model_cfg[item] = cfg["model"][item]
    model_cfg['calibration'] = cfg['testing']['calibration']
    model = get_model(**model_cfg, num_classes=n_classes, device=device).to(device)
    model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
    model.eval()

    # Resume pre-trained model
    if cfg["testing"]["resume"] is not None:
        if os.path.isfile(cfg["testing"]["resume"]):
            checkpoint = torch.load(cfg["testing"]["resume"], map_location=device)
            pretrained_dict = checkpoint['state_dict']
            model_dict = model.state_dict()
            # Keep only checkpoint tensors whose names and shapes match the current model
            pretrained_dict = {k: v for k, v in pretrained_dict.items()
                               if k in model_dict and v.shape == model_dict[k].shape}
            model_dict.update(pretrained_dict)
            model.load_state_dict(model_dict)
        else:
            print("No checkpoint found at '{}'".format(cfg["testing"]["resume"]))
    ECE = []
    NLL = []
    BRIER = []
    ACC = []

    # ================================ AutoTune NLL ==========================================
    # Calculate the exponential mapping constant from feature-norm statistics on the validation set
    if cfg['testing']['exponential_map']:
        print('Calculating mean and std')
        v_loader = get_loader(cfg, 'val')
        valloader = data.DataLoader(v_loader,
                                    batch_size=cfg["testing"]["batch_size"],
                                    num_workers=cfg["testing"]["n_workers"])
        with torch.no_grad():
            norms = None
            for image, _ in valloader:
                x_norms = model(image.to(device), is_feature_norm=True)
                if norms is None:
                    norms = x_norms
                else:
                    norms = torch.cat((x_norms, norms), dim=0)
            x_mu = torch.mean(norms)
            x_std = torch.std(norms)
            c = -np.log(cfg['testing']['exponential_map']) / (x_mu - x_std).item()
    else:
        c = None
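    # Derivation of c (a reading of the line above, not stated in the source):
    # solving exp(-c * x) = p at x = x_mu - x_std gives c = -ln(p) / (x_mu - x_std),
    # with p = cfg['testing']['exponential_map']. Under that mapping, a feature
    # norm one standard deviation below the mean is sent to confidence p; how c
    # is consumed inside model.forward(image, c=c) is defined in models.py.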
    for held_out in [1, 2, 3, 4, 5]:
        # Tune calibration parameters on the validation set
        if cfg['testing']['calibration'] != 'none':
            print('============================== start auto-tuning ==============================================')
            # Initialize data loader and set up optimizer
            v_loader = get_loader(cfg, 'val', held_out=held_out, calibration=True)
            valloader = data.DataLoader(
                v_loader,
                batch_size=cfg["testing"]["batch_size"],
                num_workers=cfg["testing"]["n_workers"]
            )
            optimizer_cls = get_optimizer(cfg["testing"]["optimizer"])
            optimizer_params = {k: v for k, v in cfg["testing"]["optimizer"].items() if k != "name"}
            # Select the parameters to optimize for the chosen calibration method:
            # decomp_a/decomp_b for alpha-beta scaling, otherwise the temp module
            if cfg['testing']['calibration'] == 'alpha-beta':
                optimizer = optimizer_cls(
                    nn.ModuleList([model.module.decomp_a, model.module.decomp_b]).parameters(),
                    **optimizer_params)
            else:
                optimizer = optimizer_cls(model.module.temp.parameters(), **optimizer_params)
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg['testing']['tune_epoch'])
            loss_fn = get_loss_function("CrossEntropy")
            for epoch in range(cfg['testing']['tune_epoch']):
                for i, (image, target) in enumerate(valloader):
                    image = image.to(device)
                    target = target.to(device)
                    logit, _, _, _ = model(image, grad=False)
                    loss = loss_fn(logit, target)
                    # ODIR (Off-Diagonal and Intercept Regularisation) for matrix/Dirichlet
                    # scaling, https://arxiv.org/pdf/1910.12656.pdf: penalize the off-diagonal
                    # entries of the calibration weight matrix and its bias
                    if cfg['testing']['calibration'] in ('matrix', 'dirichlet'):
                        weight_reg = 1e-7 * torch.norm(model.module.temp.weight * (1 - torch.eye(n_classes).to(device))) / (n_classes * (n_classes - 1))
                        bias_reg = 1e-7 * torch.norm(model.module.temp.bias) / n_classes
                        loss = loss + weight_reg + bias_reg
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                scheduler.step()
                print("Auto Tune (NLL) epoch:{}, lr:{lr:.5f}, loss:{loss:.3f}".format(
                    epoch, lr=optimizer.param_groups[-1]['lr'], loss=loss.item()))
            print('============================== end auto-tuning ==============================================')
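        # Evaluation protocol for this fold: metrics are accumulated per
        # degradation condition, averaged across conditions for the fold, and
        # the per-fold averages are averaged again over all five folds for the
        # final summary printed at the end.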
        # ================================ Evaluation ==========================================
        print('============================== start evaluation ==============================================')
        # Ground truth and prediction lists
        gt_list = []
        pred_list = []
        # Calibration metrics for this fold: ECE, NLL, Brier
        ece_list = []
        nll_list = []
        brier_list = []
        # (the key 'degredation' is spelled this way in the config schema, so it is kept as-is)
        if 'degredation' in cfg['testing'] and cfg['testing']['degredation']['type'] is not None:
            types = cfg['testing']['degredation']['type']
            noise_levels = cfg['testing']['degredation']['value']
        else:
            # No degradation sweep: a single pass over clean test data
            types = [None]
            noise_levels = [None]
        with torch.no_grad():
            for noise_type in types:
                for level in noise_levels:
                    # Initialize accuracy and calibration meters for each condition
                    eval_top1 = AverageMeter('Acc@1', ':6.2f')
                    eval_top5 = AverageMeter('Acc@5', ':6.2f')
                    nll_avg_meter = AverageMeter('negative_log_likelihood', ':6.2f')
                    brier_avg_meter = AverageMeter('brier_score', ':6.2f')
                    ece_meter = Confidence_Diagram(n_classes)
                    # Update the test config for this condition (v_cfg aliases cfg;
                    # this is safe because types/noise_levels were read out above)
                    v_cfg = cfg
                    if noise_type is not None:
                        v_cfg['testing']['degredation']['type'] = noise_type
                        v_cfg['testing']['degredation']['value'] = level
                    v_loader = get_loader(v_cfg, 'test', held_out=held_out, calibration=False)
                    valloader = data.DataLoader(
                        v_loader,
                        batch_size=cfg["testing"]["batch_size"],
                        num_workers=cfg["testing"]["n_workers"]
                    )
                    # Loop over the test loader
                    for i, (image, target) in enumerate(valloader):
                        image = image.to(device)
                        target = target.to(device)
                        batch_sz = image.shape[0]
                        logit, _, _, _ = model(image, c=c)
                        pred_dist = torch.nn.functional.softmax(logit, dim=1)
                        # Save ground truth and predicted class
                        gt_list.extend(target.cpu().numpy())
                        pred_list.extend(pred_dist.argmax(1).cpu().numpy())
                        # Save accuracy
                        acc1, acc5 = accuracy(pred_dist, target, topk=(1, 5))
                        eval_top1.update(acc1[0], batch_sz)
                        eval_top5.update(acc5[0], batch_sz)
                        # Accumulate expected calibration error statistics
                        ece_meter.aggregate_stats(pred_dist, target)
                        # Calculate and save negative log likelihood
                        nll_avg_meter.update(log_loss(target.cpu(), pred_dist.cpu(), labels=list(range(n_classes))))
                        # Calculate and save Brier score
                        brier_avg_meter.update(brier_score(target, pred_dist).mean().cpu())
                        if i % 100 == 0:
                            output = ('Test: [{0}/{1}]\t'
                                      'Prec@1 {top1.avg:.3f}\t'
                                      'Prec@5 {top5.avg:.3f}'.format(i, len(valloader), top1=eval_top1, top5=eval_top5))
                            print(output)
                    # Append the averaged ECE, NLL, and Brier for this noise type and severity level
                    ece_meter.compute_ece()
                    ece_list.append(ece_meter.ece)
                    nll_list.append(nll_avg_meter.avg.item())
                    brier_list.append(brier_avg_meter.avg.item())
        print('Test {}: \t' 'Prec@1 {top1.avg:.3f}\t' 'ECE {ece:.4f}\t' 'NLL {nll:.4f}\t' 'Brier {brier:.4f}\t'
              .format(held_out, top1=eval_top1, ece=np.mean(ece_list), nll=np.mean(nll_list), brier=np.mean(brier_list)))
        # Per-fold averages, aggregated across the five held-out folds
        ECE.append(np.mean(ece_list))
        NLL.append(np.mean(nll_list))
        BRIER.append(np.mean(brier_list))
        ACC.append(eval_top1.avg.item())

    calibration_summary = ('Test overall: \t' 'Prec@1 {top1:.3f}\t' 'ECE {ece:.4f}\t' 'NLL {nll:.4f}\t' 'Brier {brier:.4f}\t'
                           .format(top1=np.mean(ACC), ece=np.mean(ECE), nll=np.mean(NLL), brier=np.mean(BRIER)))
    print(calibration_summary)

    # Write calibration results to a text file
    out_dir = os.path.join(logdir, cfg['testing']['calibration'])
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, cfg['testing']['dataset'] + '_calibration.txt'), 'w') as file:
        file.write(calibration_summary)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="config")
    parser.add_argument(
        "--config",
        nargs="?",
        type=str,
        default="./configs/resnet18/wide.yaml",
        help="Configuration file to use",
    )
    args = parser.parse_args()

    with open(args.config) as fp:
        cfg = yaml.safe_load(fp)

    logdir = os.path.join("runs", 'test', cfg["data"]["name"], cfg["model"]["arch"], cfg['id'])
    os.makedirs(logdir, exist_ok=True)
    print("RUNDIR: {}".format(logdir))

    eval(cfg, logdir)
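# Usage example (the config path is the argparse default above; point it at an
# existing YAML file in your checkout):
#   python evalx.py --config ./configs/resnet18/wide.yaml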