-
Notifications
You must be signed in to change notification settings - Fork 5
/
CtLModel.py
67 lines (46 loc) · 1.95 KB
/
CtLModel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.tree import DecisionTreeClassifier
import numpy as np
'''
CtL : Concept-to-Label
Class for the transparent model representing a function from concepts to task labels.
Represents the decision-making process of a given black-box model, in the concept representation
'''
class CtLModel:
    """Transparent Concept-to-Label (CtL) model.

    Fits a scikit-learn estimator that maps concept values to task labels,
    and exposes it through :meth:`predict`.

    Args:
        c_data: Concept data. Must expose ``shape`` with at least two
            dimensions; ``shape[1]`` is taken as the number of concepts.
        y_data: Task labels, passed unchanged to the estimator's ``fit``.
            When ``class_names`` is not supplied, the labels are assumed to
            be non-negative, integer-valued class ids (``0 .. max``).
        **params: Optional settings:

            * ``method`` -- ``'DT'`` (default), ``'LR'`` or
              ``'LinearRegression'``.
            * ``concept_names`` -- names to use for the concepts; defaults to
              ``"Concept 0"``, ``"Concept 1"``, ...
            * ``class_names`` -- names to use for the classes; defaults to
              ``"0"``, ``"1"``, ... up to the largest label in ``y_data``.

    Attributes:
        params: The keyword arguments given at construction time.
        clf_type: The selected ``method`` string.
        n_concepts: Number of concept columns in ``c_data``.
        concept_names: Concept names (supplied or generated).
        class_names: Class names (supplied or generated).
        clf: The fitted scikit-learn estimator.

    Raises:
        ValueError: If ``method`` is not one of the recognised values.
    """

    def __init__(self, c_data, y_data, **params):
        # ``**params`` is already a fresh dict owned by this call, so it can
        # be stored directly without an extra copy.
        self.params = params

        # Classifier type; a decision tree is used when none is requested.
        self.clf_type = params.get("method", "DT")

        # Total number of concepts = number of columns in the concept data.
        self.n_concepts = c_data.shape[1]

        # Concept names: use the caller's if given, else generate them.
        if "concept_names" in params:
            self.concept_names = params["concept_names"]
        else:
            self.concept_names = [
                "Concept " + str(i) for i in range(self.n_concepts)
            ]

        # Class names: use the caller's if given, else derive from labels.
        if "class_names" in params:
            self.class_names = params["class_names"]
        else:
            self.class_names = self._default_class_names(y_data)

        # Train the estimator that predicts output labels from concept data.
        self.clf = self._train_label_classifier(c_data, y_data, self.clf_type)

    @staticmethod
    def _default_class_names(y_data):
        """Return ``["0", ..., str(max_label)]`` for the labels in *y_data*."""
        # int() is required because labels are often stored as floats
        # (e.g. 2.0); range() rejects a float even when it is integer-valued.
        n_classes = int(np.max(y_data)) + 1
        return [str(i) for i in range(n_classes)]

    def _train_label_classifier(self, c_data, y_data, method='DT'):
        """Create the estimator selected by *method* and fit it.

        Args:
            c_data: Concept data passed to ``fit`` as the inputs.
            y_data: Labels passed to ``fit`` as the targets.
            method: ``'DT'``, ``'LR'`` or ``'LinearRegression'``.

        Returns:
            The fitted estimator.

        Raises:
            ValueError: If *method* is not recognised.
        """
        if method == 'DT':
            clf = DecisionTreeClassifier(class_weight='balanced')
        elif method == 'LR':
            # ``multi_class`` is intentionally not passed: 'auto' has been the
            # default since scikit-learn 0.22 and the parameter itself is
            # deprecated from 1.5 onwards.
            clf = LogisticRegression(max_iter=200, solver='lbfgs')
        elif method == 'LinearRegression':
            clf = LinearRegression()
        else:
            raise ValueError("Unrecognised model type...")

        clf.fit(c_data, y_data)
        return clf

    def predict(self, c_data):
        """Return the fitted estimator's predictions for *c_data*."""
        return self.clf.predict(c_data)