-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel_validator.py
234 lines (190 loc) · 12.3 KB
/
model_validator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
import pandas as pd
from sklearn.model_selection import GridSearchCV
from sklearn.svm import LinearSVC
from sklearn.metrics import roc_curve, roc_auc_score, confusion_matrix
import numpy as np
import matplotlib.pyplot as plt
from dimension_manager import get_top_usage_ratio_dimensions, get_top_ratio_usage_dimensions, \
get_top_usage_ratio_dimensions_by_types, get_top_ratio_usage_dimensions_by_types
import itertools
from vector_space_construction import get_vector_space_dict, get_X_y
from utils import find_closest_values
# Parallelism settings for the grid search; get_best_params forwards them to
# GridSearchCV as `n_jobs` and `pre_dispatch`.
# Change N_JOBS at runtime with set_const_n_jobs().
N_JOBS = 1
# Change PRE_DISPATCH at runtime with set_const_pre_dispatch().
PRE_DISPATCH = '2*n_jobs'
def set_const_n_jobs(n_jobs: int) -> None:
    """Override the module-level ``N_JOBS`` setting.

    The stored value is the one later handed to ``GridSearchCV`` as ``n_jobs``
    by ``get_best_params``.

    :param n_jobs: new value for ``N_JOBS``; stored as given, without validation.
    """
    global N_JOBS
    N_JOBS = n_jobs
def set_const_pre_dispatch(pre_dispatch) -> None:
    """Override the module-level ``PRE_DISPATCH`` setting.

    The stored value is the one later handed to ``GridSearchCV`` as
    ``pre_dispatch`` by ``get_best_params``.

    :param pre_dispatch: new value for ``PRE_DISPATCH``; stored as given,
        without validation.
    """
    global PRE_DISPATCH
    PRE_DISPATCH = pre_dispatch
def get_best_params(X_train_val, y_train_val, nb_splits, scorer, param_grid, estimator):
    """Run a cross-validated grid search and summarise its outcome.

    :param X_train_val: feature matrix passed to ``GridSearchCV.fit``.
    :param y_train_val: labels passed to ``GridSearchCV.fit``.
    :param nb_splits: forwarded to ``GridSearchCV`` as ``cv``.
    :param scorer: forwarded to ``GridSearchCV`` as ``scoring``.
    :param param_grid: hyper-parameter grid to explore.
    :param estimator: estimator whose parameters are searched.
    :return: dict with the keys ``best_params``, ``best_score``,
        ``best_estimator`` and ``cv_results`` (the latter as a DataFrame).
    """
    # Parallelism is controlled by the module-level N_JOBS / PRE_DISPATCH settings.
    search = GridSearchCV(
        estimator=estimator,
        param_grid=param_grid,
        scoring=scorer,
        cv=nb_splits,
        n_jobs=N_JOBS,
        pre_dispatch=PRE_DISPATCH,
    )
    search.fit(X_train_val, y_train_val)

    cv_table = pd.DataFrame(search.cv_results_)
    return dict(
        best_params=search.best_params_,
        best_score=search.best_score_,
        best_estimator=search.best_estimator_,
        cv_results=cv_table,
    )
def get_best_LinearSVC_param(X_train_val, y_train_val, nb_splits=10, scorer='roc_auc',
                             c_range=np.logspace(-5, 2, num=100)):
    """Grid-search the ``C`` parameter of a default ``LinearSVC``.

    :param X_train_val: feature matrix used for the cross-validated search.
    :param y_train_val: labels used for the cross-validated search.
    :param nb_splits: cross-validation setting forwarded to ``get_best_params``.
    :param scorer: scoring setting forwarded to ``get_best_params``.
    :param c_range: candidate values for ``C``.
    :return: the dict produced by ``get_best_params`` whose ``cv_results``
        table is trimmed to the columns ``param_C``, ``mean_test_score`` and
        ``rank_test_score``.
    """
    result = get_best_params(X_train_val, y_train_val, nb_splits, scorer,
                             {'C': c_range}, LinearSVC())

    # Keep only the columns needed to inspect how the score varies with C.
    kept_columns = ('param_C', 'mean_test_score', 'rank_test_score')
    full_table = result['cv_results']
    result['cv_results'] = pd.DataFrame({column: full_table[column] for column in kept_columns})
    return result
def get_best_top_usage_ratio_LinearSVC_param(apks_train_val, nb_splits=10, scorer='roc_auc',
                                             c_range=np.logspace(-5, 2, num=100),
                                             limit_range=np.append(np.arange(100, 1001, 100), [None]),
                                             malware_usage_range=np.linspace(0.1, 1, 10),
                                             malware_ratio_range=np.linspace(0.1, 1, 10)):
    """Search for the best LinearSVC setup using ``get_top_usage_ratio_dimensions``.

    Thin wrapper: every argument is forwarded unchanged to
    ``get_best_reduced_dimensions_LinearSVC_param`` with
    ``get_top_usage_ratio_dimensions`` as the dimension retriever.

    :return: whatever ``get_best_reduced_dimensions_LinearSVC_param`` returns.
    """
    search_settings = {
        'nb_splits': nb_splits,
        'scorer': scorer,
        'c_range': c_range,
        'limit_range': limit_range,
        'malware_usage_range': malware_usage_range,
        'malware_ratio_range': malware_ratio_range,
    }
    return get_best_reduced_dimensions_LinearSVC_param(
        apks_train_val=apks_train_val,
        dimension_retriever=get_top_usage_ratio_dimensions,
        **search_settings)
def get_best_top_ratio_usage_LinearSVC_param(apks_train_val, nb_splits=10, scorer='roc_auc',
                                             c_range=np.logspace(-5, 2, num=100),
                                             limit_range=np.append(np.arange(100, 1001, 100), [None]),
                                             malware_usage_range=np.linspace(0.1, 1, 10),
                                             malware_ratio_range=np.linspace(0.1, 1, 10)):
    """Search for the best LinearSVC setup using ``get_top_ratio_usage_dimensions``.

    Thin wrapper: every argument is forwarded unchanged to
    ``get_best_reduced_dimensions_LinearSVC_param`` with
    ``get_top_ratio_usage_dimensions`` as the dimension retriever.

    :return: whatever ``get_best_reduced_dimensions_LinearSVC_param`` returns.
    """
    search_settings = {
        'nb_splits': nb_splits,
        'scorer': scorer,
        'c_range': c_range,
        'limit_range': limit_range,
        'malware_usage_range': malware_usage_range,
        'malware_ratio_range': malware_ratio_range,
    }
    return get_best_reduced_dimensions_LinearSVC_param(
        apks_train_val=apks_train_val,
        dimension_retriever=get_top_ratio_usage_dimensions,
        **search_settings)
def get_best_top_usage_ratio_by_types_LinearSVC_param(apks_train_val, nb_splits=10, scorer='roc_auc',
                                                      c_range=np.logspace(-5, 2, num=100),
                                                      limit_range=np.append(np.arange(100, 1001, 100), [None]),
                                                      malware_usage_range=np.linspace(0.1, 1, 10),
                                                      malware_ratio_range=np.linspace(0.1, 1, 10)):
    """Search for the best LinearSVC setup using ``get_top_usage_ratio_dimensions_by_types``.

    Thin wrapper: every argument is forwarded unchanged to
    ``get_best_reduced_dimensions_LinearSVC_param`` with
    ``get_top_usage_ratio_dimensions_by_types`` as the dimension retriever.

    :return: whatever ``get_best_reduced_dimensions_LinearSVC_param`` returns.
    """
    search_settings = {
        'nb_splits': nb_splits,
        'scorer': scorer,
        'c_range': c_range,
        'limit_range': limit_range,
        'malware_usage_range': malware_usage_range,
        'malware_ratio_range': malware_ratio_range,
    }
    return get_best_reduced_dimensions_LinearSVC_param(
        apks_train_val=apks_train_val,
        dimension_retriever=get_top_usage_ratio_dimensions_by_types,
        **search_settings)
def get_best_top_ratio_usage_by_types_LinearSVC_param(apks_train_val, nb_splits=10, scorer='roc_auc',
                                                      c_range=np.logspace(-5, 2, num=100),
                                                      limit_range=np.append(np.arange(100, 1001, 100), [None]),
                                                      malware_usage_range=np.linspace(0.1, 1, 10),
                                                      malware_ratio_range=np.linspace(0.1, 1, 10)):
    """Search for the best LinearSVC setup using ``get_top_ratio_usage_dimensions_by_types``.

    Thin wrapper: every argument is forwarded unchanged to
    ``get_best_reduced_dimensions_LinearSVC_param`` with
    ``get_top_ratio_usage_dimensions_by_types`` as the dimension retriever.

    :return: whatever ``get_best_reduced_dimensions_LinearSVC_param`` returns.
    """
    search_settings = {
        'nb_splits': nb_splits,
        'scorer': scorer,
        'c_range': c_range,
        'limit_range': limit_range,
        'malware_usage_range': malware_usage_range,
        'malware_ratio_range': malware_ratio_range,
    }
    return get_best_reduced_dimensions_LinearSVC_param(
        apks_train_val=apks_train_val,
        dimension_retriever=get_top_ratio_usage_dimensions_by_types,
        **search_settings)
def get_best_reduced_dimensions_LinearSVC_param(apks_train_val, dimension_retriever,
                                                nb_splits=10, scorer='roc_auc',
                                                c_range=np.logspace(-5, 2, num=100),
                                                limit_range=np.append(np.arange(100, 1001, 100), [None]),
                                                malware_usage_range=np.linspace(0.1, 1, 10),
                                                malware_ratio_range=np.linspace(0.1, 1, 10)):
    """Jointly search the dimension-selection settings and the LinearSVC ``C``.

    For every combination of ``limit``, ``malware_usage`` and ``malware_ratio``
    the dimensions returned by ``dimension_retriever`` are turned into a vector
    space, the APKs are projected onto it, and ``get_best_LinearSVC_param`` is
    run on the resulting data. The combination with the highest
    cross-validated score wins; ties keep the combination seen first.

    :param apks_train_val: APK collection handed to ``get_X_y``.
    :param dimension_retriever: callable accepting the keyword arguments
        ``limit``, ``malware_usage`` and ``malware_ratio`` and returning a
        sized collection of dimensions.
    :param nb_splits: forwarded to ``get_best_LinearSVC_param``.
    :param scorer: forwarded to ``get_best_LinearSVC_param``.
    :param c_range: candidate values for ``C``.
    :param limit_range: candidate values for ``limit``.
    :param malware_usage_range: candidate values for ``malware_usage``.
    :param malware_ratio_range: candidate values for ``malware_ratio``.
    :return: the winning grid-search dict, whose ``best_params`` holds ``C``,
        ``limit``, ``malware_usage`` and ``malware_ratio``. When no combination
        could be evaluated, a placeholder dict with ``best_params`` set to
        ``None`` and ``best_score`` set to ``0`` is returned.
    """
    # Start below every real score so that scorers yielding values <= 0
    # (negated losses, signed correlation metrics) can still produce a winner.
    # A NaN score never compares greater, so failed searches are skipped.
    best_score = float('-inf')
    best_grid_search_result = None

    for limit, malware_usage, malware_ratio in itertools.product(
            limit_range, malware_usage_range, malware_ratio_range):
        dimensions = dimension_retriever(limit=limit, malware_usage=malware_usage,
                                         malware_ratio=malware_ratio)
        if len(dimensions) == 0:  # no dimensions selected: move on to the next combination
            continue

        vector_space_dict = get_vector_space_dict(dimensions)
        X_train_val, y_train_val = get_X_y(vector_space_dict, apks_train_val)
        grid_search_result = get_best_LinearSVC_param(X_train_val, y_train_val, nb_splits=nb_splits,
                                                      scorer=scorer, c_range=c_range)

        if not grid_search_result['best_score'] > best_score:
            continue

        best_score = grid_search_result['best_score']
        # Record the dimension-selection settings next to the winning C.
        grid_search_result['best_params'] = {'C': grid_search_result['best_params']['C'],
                                             'limit': limit,
                                             'malware_usage': malware_usage,
                                             'malware_ratio': malware_ratio}
        best_grid_search_result = grid_search_result

    if best_grid_search_result is None:
        # Nothing was evaluated: keep the historical placeholder result.
        return {
            'best_params': None,
            'best_score': 0,
            'best_estimator': None,
            'cv_results': None
        }
    return best_grid_search_result
def get_plot(x_values, y_values, line_label, x_label, y_label, linewidth):
    """Draw a single curve on the current matplotlib axes.

    Convenience wrapper around ``get_plots`` for one series, always using the
    ``'r'`` format string.

    :param x_values: x coordinates of the curve.
    :param y_values: y coordinates of the curve.
    :param line_label: legend label of the curve.
    :param x_label: label of the x axis.
    :param y_label: label of the y axis.
    :param linewidth: line width handed to matplotlib.
    """
    get_plots(
        x_values=[x_values],
        y_values=[y_values],
        line_labels=[line_label],
        colors=['r'],
        x_label=x_label,
        y_label=y_label,
        linewidth=linewidth,
    )
def get_plots(x_values, y_values, line_labels, colors, x_label, y_label, linewidth):
    """Draw several curves on the current matplotlib axes and label the axes.

    The four sequences are consumed in lockstep; extra entries in the longer
    ones are ignored.

    :param x_values: sequence of x-coordinate series, one per curve.
    :param y_values: sequence of y-coordinate series, one per curve.
    :param line_labels: legend label of each curve.
    :param colors: matplotlib format string of each curve (e.g. ``'r'``).
    :param x_label: label of the x axis.
    :param y_label: label of the y axis.
    :param linewidth: line width shared by all curves.
    """
    curves = zip(x_values, y_values, line_labels, colors)
    for xs, ys, legend, fmt in curves:
        plt.plot(xs, ys, fmt, label=legend, linewidth=linewidth)

    plt.xlabel(x_label)
    plt.ylabel(y_label)
def get_plot_with_markers(x_values, y_values, line_label, x_label, y_label, linewidth, markers=None):
    """Draw one red curve and optionally annotate selected points.

    :param x_values: x coordinates of the curve.
    :param y_values: y coordinates of the curve.
    :param line_label: legend label of the curve.
    :param x_label: label of the x axis.
    :param y_label: label of the y axis.
    :param linewidth: line width handed to matplotlib.
    :param markers: optional iterable of ``(x, y)`` tuples; each one is
        annotated with its own coordinates formatted to two decimals.
    """
    plt.plot(x_values, y_values, color='r', label=line_label, linewidth=linewidth)

    points_to_mark = () if markers is None else markers
    for point in points_to_mark:
        caption = '(%2.2f, %2.2f)' % point
        plt.annotate(caption, xy=point, textcoords='data')

    plt.xlabel(x_label)
    plt.ylabel(y_label)
def get_plot_linearSVC_roc_curve(linearSVC: 'LinearSVC', X_test, y_test, line_label, linewidth):
    """Plot the ROC curve of a fitted LinearSVC on the current matplotlib axes.

    The legend label is ``line_label`` followed by the AUC measured on the
    same data.

    :param linearSVC: fitted classifier exposing ``decision_function``.
    :param X_test: samples to score.
    :param y_test: true labels of ``X_test``.
    :param line_label: prefix of the legend label.
    :param linewidth: line width handed to matplotlib.
    """
    # Score the samples once; both the curve and the AUC use the same values.
    decision_scores = linearSVC.decision_function(X_test)
    fprs, tprs, _ = roc_curve(y_test, decision_scores)
    auc_score = roc_auc_score(y_test, decision_scores)
    get_plot(fprs, tprs,
             line_label + ', AUC= ' + str(auc_score),
             'FPR', 'TPR', linewidth=linewidth)
def get_plot_linearSVC_class_separation(linearSVC: 'LinearSVC', X_test, y_test):
    """Plot histograms of the decision scores, one per actual class.

    Samples labelled ``1`` and samples labelled ``0`` are drawn as two
    overlapping histograms with unit-width bins, so the overlap shows how well
    the decision function separates the classes. A class with no sample is
    skipped.

    :param linearSVC: fitted classifier exposing ``decision_function``.
    :param X_test: samples to score.
    :param y_test: true labels of ``X_test`` (``0`` or ``1``); any array-like,
        including a plain list.
    """
    y_decision_score = linearSVC.decision_function(X_test)
    # Coerce to an array so `== label` is element-wise; with a plain list the
    # comparison is a single False and nothing would be plotted.
    labels = np.asarray(y_test)

    # NOTE(review): the legend texts say "True positives/negatives" although the
    # groups are built from the actual labels, not from the predictions.
    groups = (
        (1, 'True positives', 'b'),
        (0, 'True negatives', 'r'),
    )
    for label_value, legend, color in groups:
        class_scores = y_decision_score[labels == label_value]
        if len(class_scores) == 0:
            continue
        # Integral parts of (min - 1) and (max + 1) give integer bin edges that
        # enclose every score of the class.
        _, lowest_edge = np.modf(class_scores.min() - 1)
        _, highest_edge = np.modf(class_scores.max() + 1)
        bin_edges = np.arange(lowest_edge, highest_edge + 1)
        plt.hist(class_scores, bins=bin_edges, alpha=0.5, label=legend, color=color, edgecolor='black')

    plt.xlabel('SVM decision_function values')
    plt.ylabel('Number of data points')
def get_plot_linearSVC_fpr_per_threshold(linearSVC: 'LinearSVC', X_test, y_test, line_label, linewidth,
                                         marked_thresholds: list = None):
    """Plot the false positive rate as a function of the decision threshold.

    :param linearSVC: fitted classifier exposing ``decision_function``.
    :param X_test: samples to score.
    :param y_test: true labels of ``X_test``.
    :param line_label: legend label of the curve.
    :param linewidth: line width handed to matplotlib.
    :param marked_thresholds: optional threshold values to highlight; each is
        replaced by the value ``find_closest_values`` picks for it among the
        ROC thresholds, and the matching point is annotated.
    """
    fprs, _, thresholds = roc_curve(y_test, linearSVC.decision_function(X_test))
    xy = None
    if marked_thresholds is not None:
        # presumably returns members of `thresholds` - verify against utils.find_closest_values
        closest_thresholds = find_closest_values(thresholds, marked_thresholds)
        selected = np.isin(thresholds, closest_thresholds)
        # Take x and y from the same mask so each threshold is paired with its
        # own FPR regardless of the order find_closest_values returns.
        xy = list(zip(thresholds[selected], fprs[selected]))
    get_plot_with_markers(thresholds, fprs, line_label, 'Threshold values', 'FPR',
                          linewidth=linewidth, markers=xy)
def get_plot_linearSVC_tpr_per_threshold(linearSVC: 'LinearSVC', X_test, y_test, line_label, linewidth,
                                         marked_thresholds: list = None):
    """Plot the true positive rate as a function of the decision threshold.

    :param linearSVC: fitted classifier exposing ``decision_function``.
    :param X_test: samples to score.
    :param y_test: true labels of ``X_test``.
    :param line_label: legend label of the curve.
    :param linewidth: line width handed to matplotlib.
    :param marked_thresholds: optional threshold values to highlight; each is
        replaced by the value ``find_closest_values`` picks for it among the
        ROC thresholds, and the matching point is annotated.
    """
    _, tprs, thresholds = roc_curve(y_test, linearSVC.decision_function(X_test))
    xy = None
    if marked_thresholds is not None:
        # presumably returns members of `thresholds` - verify against utils.find_closest_values
        closest_thresholds = find_closest_values(thresholds, marked_thresholds)
        selected = np.isin(thresholds, closest_thresholds)
        # Take x and y from the same mask so each threshold is paired with its
        # own TPR regardless of the order find_closest_values returns.
        xy = list(zip(thresholds[selected], tprs[selected]))
    get_plot_with_markers(thresholds, tprs, line_label, 'Threshold values', 'TPR',
                          linewidth=linewidth, markers=xy)
def get_linearSVC_validation_report(linearSVC: 'LinearSVC', X_test, y_test):
    """Compute binary-classification metrics of a fitted LinearSVC.

    The AUC is derived from the decision scores; every other metric is derived
    from the confusion matrix of ``predict`` for the labels ``0`` and ``1``.
    A ratio whose denominator is zero (for instance precision when the model
    predicts no positive sample) is reported as ``0.0`` instead of NaN.

    :param linearSVC: fitted classifier exposing ``decision_function`` and
        ``predict``.
    :param X_test: samples to evaluate.
    :param y_test: true labels of ``X_test``.
    :return: dict with the keys ``positives``, ``negatives``, ``auc``,
        ``accuracy``, ``recall``, ``precision``, ``f1``, ``tp``, ``tn``,
        ``fp``, ``fn``, ``tpr`` and ``fpr``.
    """
    def _ratio(numerator, denominator):
        # Dividing NumPy scalars by zero yields NaN/inf plus a RuntimeWarning.
        return numerator / denominator if denominator else 0.0

    auc = roc_auc_score(y_test, linearSVC.decision_function(X_test))
    # Rows are the true labels and columns the predicted labels, ordered [0, 1].
    conf_matrix = confusion_matrix(y_test, linearSVC.predict(X_test), labels=[0, 1])
    tn = conf_matrix[0, 0]
    tp = conf_matrix[1, 1]
    fp = conf_matrix[0, 1]
    fn = conf_matrix[1, 0]

    accuracy = _ratio(tp + tn, tp + tn + fp + fn)
    recall = _ratio(tp, tp + fn)
    precision = _ratio(tp, tp + fp)
    fpr = _ratio(fp, fp + tn)
    f1 = _ratio(2 * (precision * recall), precision + recall)

    return {
        'positives': tp + fn,
        'negatives': tn + fp,
        'auc': auc,
        'accuracy': accuracy,
        'recall': recall,
        'precision': precision,
        'f1': f1,
        'tp': tp,
        'tn': tn,
        'fp': fp,
        'fn': fn,
        'tpr': recall,
        'fpr': fpr
    }