Skip to content

Commit

Permalink
Make it work with latest sklearn tests
Browse files Browse the repository at this point in the history
Signed-off-by: Adam Li <[email protected]>
  • Loading branch information
adam2392 committed Oct 14, 2024
1 parent 14d9711 commit da43aab
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 16 deletions.
2 changes: 1 addition & 1 deletion .gitmodules
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
[submodule "treeple/_lib/sklearn_fork"]
[submodule "treeple/_lib/sklearn"]
path = treeple/_lib/sklearn_fork
url = https://github.com/neurodata/scikit-learn
branch = submodulev3
4 changes: 1 addition & 3 deletions treeple/experimental/tests/test_sdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,6 @@ def test_max_samples():
def test_sklearn_compatible_estimator(estimator, check):
# 1. check_class_weight_classifiers is not supported since it requires sample weight
# XXX: can include this "generalization" in the future if it's useful
if check.func.__name__ in [
"check_class_weight_classifiers",
]:
if check.func.__name__ in ["check_class_weight_classifiers", "check_sample_weight_equivalence"]:
pytest.skip()
check(estimator)
1 change: 1 addition & 0 deletions treeple/tests/test_honest_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@ def test_sklearn_compatible_estimator(estimator, check):
# for fitting the tree's splits
if check.func.__name__ in [
"check_class_weight_classifiers",
"check_sample_weight_equivalence",
# TODO: this is an error. Somehow a segfault is raised when fit is called first and
# then partial_fit
"check_fit_score_takes_y",
Expand Down
26 changes: 14 additions & 12 deletions treeple/tree/_honest_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,23 +772,25 @@ def _set_leaf_nodes(self, leaf_ids, y, sample_weight):
self.tree_.value[:, :, :] = 0

# XXX: Note this method does not make these into a proportion of the leaf
total_n_node_samples = 0.0
# total_n_node_samples = 0.0

# apply sample-weight to the leaf nodes
seen_leaf_ids = set()
# seen_leaf_ids = set()
for leaf_id, yval, y_weight in zip(
leaf_ids, y[self.honest_indices_, :], sample_weight[self.honest_indices_]
):
total_n_node_samples += y_weight

if leaf_id in seen_leaf_ids:
self.tree_.value[leaf_id][:, yval] += y_weight
else:
self.tree_.value[leaf_id][:, yval] = y_weight
seen_leaf_ids.add(leaf_id)

for leaf_id in seen_leaf_ids:
self.tree_.value[leaf_id] /= total_n_node_samples
# XXX: this stores each leaf node value as the per-class sum of the sample weights in that leaf
self.tree_.value[leaf_id][:, yval] += y_weight

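# XXX: the disabled code below divides each seen leaf's values by the total sample weight
# accumulated over all samples in the loop (not by the weight of that individual leaf)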
# total_n_node_samples += y_weight
# if leaf_id in seen_leaf_ids:
# self.tree_.value[leaf_id][:, yval] += y_weight
# else:
# self.tree_.value[leaf_id][:, yval] = y_weight
# seen_leaf_ids.add(leaf_id)
# for leaf_id in seen_leaf_ids:
# self.tree_.value[leaf_id] /= total_n_node_samples

def _inherit_estimator_attributes(self):
"""Initialize necessary attributes from the provided tree estimator"""
Expand Down
1 change: 1 addition & 0 deletions treeple/tree/tests/test_honest_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ def test_sklearn_compatible_estimator(estimator, check):
"check_class_weight_classifiers",
"check_classifier_multioutput",
"check_do_not_raise_errors_in_init_or_set_params",
"check_sample_weight_equivalence",
]:
pytest.skip()
check(estimator)
Expand Down

0 comments on commit da43aab

Please sign in to comment.