forked from damaggu/cellSAM_devel
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcheck_classifier_json.py
89 lines (68 loc) · 2.71 KB
/
check_classifier_json.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import json
import argparse
import numpy as np
import pandas as pd
def reverse_dict(dict):
    """Invert a mapping so that each value (or each element of a list value) points back to its key.

    A list value contributes one entry per element; any other value is used
    directly as a single new key. When the same item appears under several
    keys, the key encountered last while iterating ``dict.items()`` wins.

    Args:
        dict: Mapping to invert (parameter name kept for backward compatibility,
            even though it shadows the builtin).

    Returns:
        A new dict mapping each item to the key it was listed under.
    """
    inverted = {}
    for label, members in dict.items():
        # Treat a scalar value as a one-element group so both shapes share one loop.
        group = members if isinstance(members, list) else [members]
        for member in group:
            inverted[member] = label
    return inverted
def compute_acc(dict, annotations):
    """Return the fraction of annotated images whose prediction matches the label.

    Args:
        dict: Mapping from image name to predicted class. The strings
            'no_wsi' and 'flagged' are scored as 1 and 'regular' as 0; any
            other value is compared to the label unchanged, so it will
            generally not match an integer label and counts as wrong.
            (Parameter name kept for backward compatibility.)
        annotations: DataFrame whose column 0 holds image names and whose
            column 1 holds the integer ground-truth label for that row.

    Returns:
        Accuracy as a float in [0, 1].

    Raises:
        KeyError: if an annotated image has no entry in ``dict``.
        ValueError: if ``annotations`` has no rows.
    """
    n_images = len(annotations)
    if n_images == 0:
        raise ValueError('annotations is empty; cannot compute accuracy')

    label_to_int = {'no_wsi': 1, 'flagged': 1, 'regular': 0}
    correct_cnt = 0
    # Walk the name and label columns in lockstep: each row is scored against
    # its own label (O(n)) instead of re-filtering the frame per image (O(n^2)).
    for img, gt in zip(annotations[0], annotations[1]):
        pred = dict[img]
        pred = label_to_int.get(pred, pred)
        if gt == pred:
            correct_cnt += 1
    return correct_cnt / n_images
def _load_annotations(csv_file):
    """Load a headerless (image name, label) annotation CSV.

    Column 1 is coerced to int, with non-numeric or missing labels becoming 0.
    Column 0 is cut at its first '.' so names match the prediction keys,
    which are stripped the same way. Rows are returned sorted by column 0.
    """
    annotations = pd.read_csv(csv_file, header=None)
    annotations[1] = pd.to_numeric(annotations[1], errors='coerce').fillna(0).astype(int)
    annotations[0] = annotations[0].str.split('.').str[0]
    return annotations.sort_values(by=0, axis=0)


if __name__ == "__main__":
    import os  # only needed by the script entry point

    # Ground-truth sheet for each --ann_type value.
    annotation_csv_files = {
        'tune': 'bloodcellSheet_tuning.csv',
        'hidden': 'bloodcellSheet_hidden.csv',
    }

    parser = argparse.ArgumentParser()
    # `choices` rejects unknown annotation types before any file is read.
    parser.add_argument('--ann_type', type=str, default='tune',
                        choices=sorted(annotation_csv_files), help='tune or hidden')
    parser.add_argument('--json_folder', type=str, default='results_tune_test2', help='folder name')
    args = parser.parse_args()
    ANN_TYPE = args.ann_type

    # Load only the sheet being evaluated, so a missing sheet for the other
    # split cannot abort the run.
    annotations = _load_annotations(annotation_csv_files[ANN_TYPE])

    json_path = os.path.join('tmp_outs', args.json_folder, 'wsi_imgs_dict.json')
    with open(json_path, encoding='utf-8') as f:
        # NOTE(review): appears to map predicted class -> list of image file
        # names (reverse_dict expands list values) -- confirm against producer.
        data = json.load(f)

    # Invert to image -> predicted class and strip each name at its first '.'
    # so keys line up with the annotation names. Named so it does not rebind
    # the module-level reverse_dict function.
    pred_by_image = {key.split('.')[0]: value for key, value in reverse_dict(data).items()}

    acc = compute_acc(pred_by_image, annotations)
    print(f'{ANN_TYPE} accuracy: {acc}')