Skip to content

Commit

Permalink
Conflits Ok
Browse files Browse the repository at this point in the history
  • Loading branch information
szczepanskiNicolas committed Nov 25, 2023
2 parents e7f8cd8 + 8506264 commit 0a6ae18
Show file tree
Hide file tree
Showing 12 changed files with 216 additions and 27 deletions.
15 changes: 15 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Changelog

### 1.0.10
- Contrastive for BT classification (binary classes)
- change function name in explainer (unset_specific_features -> unset_excluded_features)

### 1.0.0
- Regression for boosted trees
- Adding thoeries
- Easier import model
- Graphical user interface: displaying, loading, saving explanations
- Data preprocessing
- unit tests
## 0.X
- Initial release
9 changes: 7 additions & 2 deletions pyxai/examples/BT/builder-simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

BTs = Builder.BoostedTrees([tree1, tree2, tree3], n_classes=2)

instance = (4, 3, 1, 1)
instance = (4, 3, 1, 0)
print("instance:", instance)

explainer = Explainer.initialize(BTs, instance)
Expand All @@ -45,7 +45,7 @@
print("direct reason:", direct)
direct_features = explainer.to_features(direct)
print("to_features:", direct_features)
assert direct_features == ('f1 > 2', 'f2 > 1', 'f3 == 1', 'f4 == 1'), "The direct reason is not correct."
#assert direct_features == ('f1 > 2', 'f2 > 1', 'f3 == 1', 'f4 == 1'), "The direct reason is not correct."

print("---------------------------------------------------")
tree_specific = explainer.tree_specific_reason()
Expand All @@ -55,6 +55,11 @@
print("is_tree_specific:", explainer.is_tree_specific_reason(tree_specific))
print("is_sufficient_reason:", explainer.is_sufficient_reason(tree_specific))

print("---------------------------------------------------")
contrastive_reason = explainer.minimal_contrastive_reason()
print("contrastive reason:", explainer.to_features(contrastive_reason))
print("is contrastive: ", explainer.is_contrastive_reason(contrastive_reason))

print("---------------------------------------------------")
sufficient = explainer.sufficient_reason()
print("sufficient reason:", sufficient)
Expand Down
27 changes: 17 additions & 10 deletions pyxai/examples/BT/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

# Explanation part
print(instance)
explainer = Explainer.initialize(model, instance)
explainer = Explainer.initialize(model, instance, features_type={"numerical": Learning.DEFAULT})
direct_reason = explainer.direct_reason()
print("len direct: ", len(direct_reason))
print("is a reason (for 50 checks):", explainer.is_reason(direct_reason, n_samples=50))
Expand All @@ -18,15 +18,22 @@
tree_specific_reason = explainer.tree_specific_reason(n_iterations=10)
print("\nlen tree_specific: ", len(tree_specific_reason))
print("\ntree_specific: ", explainer.to_features(tree_specific_reason, eliminate_redundant_features=True))
instances = learner.get_instances(n=100)

print(instances)
for inst, p in instances:
explainer.set_instance(inst)
direct_reason = explainer.direct_reason()

tree_specific_reason = explainer.tree_specific_reason(n_iterations=100)
print("is a tree specific", explainer.is_tree_specific_reason(tree_specific_reason))

explainer.set_excluded_features(["score_factor"])
contrastive_reason = explainer.minimal_contrastive_reason(n=2)
print("\n\ncontrastive reason: ", explainer.to_features(contrastive_reason, contrastive=True))
print("is contrastive: ", explainer.is_contrastive_reason(contrastive_reason))
print("elapsed time: ", explainer.elapsed_time)
print()
#instances = learner.get_instances(n=100)

#print(instances)
#for inst, p in instances:
# explainer.set_instance(inst)
# direct_reason = explainer.direct_reason()
#
# tree_specific_reason = explainer.tree_specific_reason(n_iterations=100)
# print("is a tree specific", explainer.is_tree_specific_reason(tree_specific_reason))

#explainer.show()
#minimal_tree_specific_reason = explainer.minimal_tree_specific_reason(time_limit=20)
Expand Down
4 changes: 2 additions & 2 deletions pyxai/sources/core/explainer/Explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ def set_excluded_features(self, excluded_features):
@param excluded_features (list[str] | tuple[str]): the features names to be excluded
"""
if len(excluded_features) == 0:
self.unset_specific_features()
self.unset_excluded_features()
return
self._excluded_features = excluded_features
if self.instance is None:
Expand All @@ -443,7 +443,7 @@ def _set_specific_features(self, specific_features): # TODO a changer en je veu
self.set_excluded_features(excluded)


def unset_specific_features(self):
def unset_excluded_features(self):
"""
Unset the features set with the set_excluded_features method.
"""
Expand Down
29 changes: 22 additions & 7 deletions pyxai/sources/core/explainer/explainerBT.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pyxai.sources.solvers.CSP.AbductiveV1 import AbductiveModelV1
from pyxai.sources.solvers.CSP.TSMinimalV2 import TSMinimal
from pyxai.sources.solvers.GRAPH.TreeDecomposition import TreeDecomposition
from pyxai.sources.solvers.CPLEX.ContrastiveBT import ContrastiveBT


class ExplainerBT(Explainer):
Expand All @@ -36,12 +37,13 @@ def set_instance(self, instance):
def _to_binary_representation(self, instance):
return self._boosted_trees.instance_to_binaries(instance)


def _theory_clauses(self):
return self._boosted_trees.get_theory(self._binary_representation)


def is_implicant(self, abductive):

if self._boosted_trees.n_classes == 2:
# 2-classes case
sum_weights = []
Expand All @@ -55,7 +57,7 @@ def is_implicant(self, abductive):

return self.target_prediction == prediction
else:

# multi-classes case
worst_one = self.compute_weights_class(abductive, self.target_prediction, king="worst")
best_ones = [self.compute_weights_class(abductive, cl, king="best") for cl
Expand Down Expand Up @@ -179,7 +181,7 @@ def sufficient_reason(self, *, n=1, seed=0, time_limit=None):
"""
if self._instance is None:
raise ValueError("Instance is not set")

raise NotImplementedError("In progress")
assert n == 1, "To do implement that"
if self._boosted_trees.n_classes > 2:
raise NotImplementedError
Expand Down Expand Up @@ -241,8 +243,6 @@ def tree_specific_reason(self, *, n_iterations=50, time_limit=None, seed=0, hist
The method used (in c++), for a given seed, compute several tree specific reasons and return the best.
For that, the algorithm is executed either during a given time or or until a certain number of reasons is calculated.
The parameter 'reason_expressivity' have to be fixed either by ReasonExpressivity.Features or ReasonExpressivity.Conditions.
Args:
n_iterations (int, optional): _description_. Defaults to 50.
time_limit (int, optional): _description_. Defaults to None.
Expand All @@ -262,10 +262,10 @@ def tree_specific_reason(self, *, n_iterations=50, time_limit=None, seed=0, hist
if self.c_BT is None:
# Preprocessing to give all trees in the c++ library
self.c_BT = c_explainer.new_classifier_BT(self._boosted_trees.n_classes)

for tree in self._boosted_trees.forest:
c_explainer.add_tree(self.c_BT, tree.raw_data_for_CPP())

c_explainer.set_excluded(self.c_BT, tuple(self._excluded_literals))
if self._theory:
c_explainer.set_theory(self.c_BT, tuple(self._boosted_trees.get_theory(self._binary_representation)))
Expand Down Expand Up @@ -351,6 +351,21 @@ def is_tree_specific_reason(self, reason, check_minimal_inclusion=False):
return False
return True


def minimal_contrastive_reason(self, *, n=1, time_limit=None):
if self._instance is None:
raise ValueError("Instance is not set")
if self._boosted_trees.n_classes > 2:
raise NotImplementedError("Minimal contrastive reason is not implemented for the multi class case")

starting_time = -time.process_time()
contrastive_bt = ContrastiveBT()
c = contrastive_bt.create_model_and_solve(self, None if self._theory == False else self._theory_clauses(), self._excluded_literals, 1, time_limit)
time_used = starting_time + time.process_time()
self._elapsed_time = time_used if time_limit is None or time_used < time_limit else Explainer.TIMEOUT
self.add_history(self._instance, self.__class__.__name__, self.minimal_contrastive_reason.__name__, c)
return c

# def check_sufficient(self, reason, n_samples=1000):
# """
# Check if the ''reason'' is abductive and check if the reasons with one selected literal in less are not abductives. This allows to check
Expand Down
3 changes: 2 additions & 1 deletion pyxai/sources/core/explainer/explainerRF.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@ def minimal_contrastive_reason(self, *, n=1, time_limit=None):
"""
if self._instance is None:
raise ValueError("Instance is not set")

if self._random_forest.n_classes > 2:
raise NotImplementedError("Minimal contrastive reason is not implemented for the multi class case")
n = n if type(n) == int else float('inf')
first_call = True
time_limit = 0 if time_limit is None else time_limit
Expand Down
17 changes: 15 additions & 2 deletions pyxai/sources/core/structure/binaryMapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,21 @@ def instance_to_binaries(self, instance, preference_order=None):

def get_id_features(self, binary_representation):
return tuple(self.map_id_binaries_to_features[abs(lit)][0] for lit in binary_representation)




# Return a map 'dict_id_features_to_id_binaries': : dict[id_feature] -> [list of id_binaries]
def get_id_binaries(self):
dict_id_features_to_id_binaries = dict()
for key in self.map_features_to_id_binaries.keys():
id_feature, _, _ = key
id_binary = self.map_features_to_id_binaries[key][0]
if id_feature in dict_id_features_to_id_binaries.keys():
dict_id_features_to_id_binaries[id_feature].append(id_binary)
else:
dict_id_features_to_id_binaries[id_feature] = [id_binary]
return dict_id_features_to_id_binaries


def convert_features_to_dict_features(self, features, feature_names):
dict_features = dict()
dict_features_categorical = dict()
Expand Down
125 changes: 125 additions & 0 deletions pyxai/sources/solvers/CPLEX/ContrastiveBT.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
from ortools.linear_solver import pywraplp
from pyxai.sources.core.structure.type import TypeLeaf
from pyxai.sources.core.explainer.Explainer import Explainer


class ContrastiveBT:
def __init__(self):
pass


def create_model_and_solve(self, explainer, theory, excluded, n, time_limit):
forest = explainer._boosted_trees.forest
leaves = [tree.get_leaves() for tree in forest]
bin_len = len(explainer.binary_representation)
solver = pywraplp.Solver.CreateSolver("SCIP")
features_to_bin = explainer._boosted_trees.get_id_binaries()

if time_limit is not None:
solver.SetTimeLimit(time_limit * 1000) # time limit in milisecond

# Model variables
instance = [solver.BoolVar(f"x[{i}]") for i in range(bin_len)] # The instance

active_leaves = []
for j, tree in enumerate(forest):
active_leaves.append([solver.BoolVar(f"y[{j}][{i}]") for i in range(len(tree.get_leaves()))]) # Actives leaves

flipped = [solver.BoolVar(f"z[{i}]") for i in range(bin_len)] # The flipped variables

# Constraints related to tree structure
for j, tree in enumerate(forest):
for i, leave in enumerate(tree.get_leaves()):
t = TypeLeaf.LEFT if leave.parent.left == leave else TypeLeaf.RIGHT
cube = forest[j].create_cube(leave.parent, t)
nb_neg = sum((1 for l in cube if l < 0))
nb_pos = sum((1 for l in cube if l > 0))
constraint = solver.RowConstraint(-solver.infinity(), nb_neg)
constraint.SetCoefficient(active_leaves[j][i], nb_pos + nb_neg)
for l in cube:
constraint.SetCoefficient(instance[abs(l) - 1], -1 if l > 0 else 1)

# Only one leave activated per tree
for j, tree in enumerate(forest):
constraint = solver.RowConstraint(1, 1)
for v in active_leaves[j]:
constraint.SetCoefficient(v, 1)

# Change the prediction
if explainer.target_prediction == 1:
constraint_target = solver.RowConstraint(-solver.infinity(), 0)
else:
constraint_target = solver.RowConstraint(0, solver.infinity())
for j, tree in enumerate(forest):
for i, leave in enumerate(tree.get_leaves()):
constraint_target.SetCoefficient(active_leaves[j][i], leave.value)

# Constraints related to theory
if theory is not None:
for clause in theory:
constraint = solver.RowConstraint(-solver.infinity(), 0)
for l in clause:
constraint.SetCoefficient(instance[abs(l) - 1], 1 if l < 0 else -1)

# links between instance and flipped
for i in range(bin_len):
const1 = solver.RowConstraint(-solver.infinity(), 1 if explainer.binary_representation[i] > 0 else 0)
const1.SetCoefficient(instance[i], 1)
const1.SetCoefficient(flipped[i], -1)
const2 = solver.RowConstraint(-solver.infinity(), -1 if explainer.binary_representation[i] > 0 else 0)
const2.SetCoefficient(instance[i], -1)
const2.SetCoefficient(flipped[i], -1)

# Set excluded features
for lit in excluded:
constraint = solver.RowConstraint(0, 0)
constraint.SetCoefficient(flipped[abs(lit) - 1], 1)


if theory is None: # the same encoding for RF : if theory minimal wrt features else wrt bin...
# TODO : let the possibilit for the user to choose
# Objective function
objective = solver.Objective()
for i in range(bin_len):
objective.SetCoefficient(flipped[i], 1)
objective.SetMinimization()
else:
# links between features and flipped
dist_features = [solver.BoolVar(f"fd{i}") for i in range(len(features_to_bin))]
i = 0
for f, binaries in features_to_bin.items():
constraint = solver.RowConstraint(-solver.infinity(), 0)
constraint.SetCoefficient(dist_features[i], -1)
for lit in binaries:
constraint.SetCoefficient(flipped[abs(lit -1)], 1 / len(binaries))
i = i + 1
# Objective function
objective = solver.Objective()
for d in dist_features:
objective.SetCoefficient(d, 1)
objective.SetMinimization()


# print(solver.ExportModelAsLpFormat(obfuscated=False))

# Solve the problem and extract n solutions
results = []
first = True
best_objective = -1
while True:
if first:
status = solver.Solve()
else:
status = solver.NextSolution()
if status not in [pywraplp.Solver.OPTIMAL, pywraplp.Solver.FEASIBLE]:
break
solution = [explainer.binary_representation[i] for i in range(len(flipped)) if flipped[i].solution_value() >= 0.5]
if first:
best_objective = len(solution)
first = False
if len(solution) > best_objective:
break
results.append(solution)
if len(results) == n:
break
return Explainer.format(results, n)
2 changes: 1 addition & 1 deletion pyxai/tests/compas.csv
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
0,1,2,3,4,5,6,7,8,9,10,11
Number_of_Priors,score_factor,Age_Above_FourtyFive,Age_Below_TwentyFive,Origin_African_American,Origin_Asian,Origin_Hispanic,Origin_Native_American,Origin_Other,Female,Misdemeanor,Two_yr_Recidivism
0,0,1,0,0,0,0,0,1,0,0,0
0,0,0,0,1,0,0,0,0,0,0,1
4,0,0,1,1,0,0,0,0,0,0,1
Expand Down
8 changes: 8 additions & 0 deletions pyxai/tests/explaining/bt.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,14 @@ def test_excluded(self):
tree_specific_reason = explainer.tree_specific_reason()
self.assertFalse(explainer.reason_contains_features(tree_specific_reason, 'Female'))

def test_contrastive(self):
learner, model = self.init()
explainer = Explainer.initialize(model)
instances = learner.get_instances(model, n=5)
for instance, prediction in instances:
explainer.set_instance(instance)
contrastive_reason = explainer.minimal_contrastive_reason()
self.assertTrue(len(contrastive_reason) > 0 and explainer.is_contrastive_reason(contrastive_reason))

if __name__ == '__main__':
unittest.main(verbosity=0)
2 changes: 1 addition & 1 deletion pyxai/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.0.9
1.0.10
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
classifiers=['Topic :: Scientific/Engineering :: Artificial Intelligence', 'Topic :: Education'],
packages=find_packages(), # exclude=["problems/g7_todo/"]),
package_dir={'pyxai': 'pyxai'},
install_requires=['lxml', 'numpy', 'wheel', 'pandas', 'termcolor', 'shap', 'wordfreq', 'python-sat[pblib,aiger]', 'xgboost==1.7.3', 'pycsp3', 'matplotlib', 'dill', 'lightgbm', 'docplex'],
install_requires=['lxml', 'numpy', 'wheel', 'pandas', 'termcolor', 'shap', 'wordfreq', 'python-sat[pblib,aiger]', 'xgboost==1.7.3', 'pycsp3', 'matplotlib', 'dill', 'lightgbm', 'docplex', 'ortools'],
extras_require={
"gui": ['pyqt6'],
},
Expand Down

0 comments on commit 0a6ae18

Please sign in to comment.