metrics.py
import numpy as np
import sklearn.metrics
import torch
from sklearn.metrics import accuracy_score


def klue_re_micro_f1(preds, labels):
"""KLUE-RE micro f1 (except no_relation)"""
label_list = [
"no_relation",
"org:top_members/employees",
"org:members",
"org:product",
"per:title",
"org:alternate_names",
"per:employee_of",
"org:place_of_headquarters",
"per:product",
"org:number_of_employees/members",
"per:children",
"per:place_of_residence",
"per:alternate_names",
"per:other_family",
"per:colleagues",
"per:origin",
"per:siblings",
"per:spouse",
"org:founded",
"org:political/religious_affiliation",
"org:member_of",
"per:parents",
"org:dissolved",
"per:schools_attended",
"per:date_of_death",
"per:date_of_birth",
"per:place_of_birth",
"per:place_of_death",
"org:founded_by",
"per:religion",
]
no_relation_label_idx = label_list.index("no_relation")
label_indices = list(range(len(label_list)))
label_indices.remove(no_relation_label_idx)
return (
sklearn.metrics.f1_score(labels, preds, average="micro", labels=label_indices)
* 100.0
)
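
# A minimal illustration with made-up values (not from the original repo):
# because "no_relation" (index 0) is excluded from the label set passed to
# f1_score, only the non-"no_relation" classes contribute to the micro F1.
# For example, klue_re_micro_f1(preds=[1, 2, 0], labels=[1, 2, 0]) evaluates
# to 100.0, since the correct "no_relation" prediction is simply ignored.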


def klue_re_auprc(probs, labels):
    """KLUE-RE AUPRC (including no_relation)."""
    probs = np.array(probs)
    # One-hot encode the integer labels over the 30 relation classes.
    labels = np.eye(30)[labels]
    score = np.zeros((30,))
    # Compute the area under the precision-recall curve per class, then average.
    for c in range(30):
        targets_c = labels.take([c], axis=1).ravel()
        preds_c = probs.take([c], axis=1).ravel()
        precision, recall, _ = sklearn.metrics.precision_recall_curve(
            targets_c, preds_c
        )
        score[c] = sklearn.metrics.auc(recall, precision)
    return np.average(score) * 100.0
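
# Shape note (an assumption inferred from the code above, not stated in the
# repo): `probs` is expected to be a (num_samples, 30) array of per-class
# scores, and `labels` a 1-D array of integer class indices in [0, 30).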


def compute_metrics(keys, logitss):
    """Metrics function for validation."""
    labels = np.array(keys, dtype=np.int64)
    logitss = torch.tensor(logitss)
    preds = torch.argmax(logitss, dim=-1)
    probs = logitss

    f1 = klue_re_micro_f1(preds, labels)
    auprc = klue_re_auprc(probs, labels)
    # Accuracy is reported for reference only; it is not part of the
    # leaderboard evaluation.
    acc = accuracy_score(labels, preds)
    return f1, auprc, acc


class Metrics:
    """Running-average meter: tracks the latest value, sum, count, and mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        # `val` is the latest measurement; `n` is how many samples it covers
        # (e.g. a batch-mean loss over a batch of size n).
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
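

# A minimal smoke test with made-up data (not part of the original repo),
# sketching how compute_metrics and Metrics are meant to be called. The
# labels and logits below are assumptions for illustration only.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # 90 samples with each of the 30 classes appearing three times, so every
    # class has positive examples for the per-class PR curves.
    dummy_labels = np.tile(np.arange(30), 3)
    dummy_logits = rng.normal(size=(90, 30))
    f1, auprc, acc = compute_metrics(dummy_labels, dummy_logits)
    print(f"micro F1={f1:.2f}  AUPRC={auprc:.2f}  acc={acc:.4f}")

    # Track a running average of a per-batch loss with the Metrics meter.
    loss_meter = Metrics()
    for batch_loss, batch_size in [(0.9, 32), (0.7, 32), (0.5, 16)]:
        loss_meter.update(batch_loss, n=batch_size)
    print(f"avg loss={loss_meter.avg:.4f}")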