Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH Cleanup multiview API and enable oblique multiview #265

Open
wants to merge 40 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
3ddb98b
Fix multiview API and enable oblique multiview
adam2392 Mar 14, 2024
923c171
Clean up other unused kwarg path
adam2392 Mar 14, 2024
500dca9
add changelog
adam2392 Mar 14, 2024
954c6fc
Fix examples
adam2392 Mar 16, 2024
626b049
Merge branch 'main' into mvapi
adam2392 Apr 9, 2024
efc4440
Merging
adam2392 Jun 11, 2024
038af43
Merge branch 'main' into mvapi
adam2392 Jun 20, 2024
16f4af4
Merging main
adam2392 Jun 24, 2024
e6ea30b
Almost working
adam2392 Jun 24, 2024
ffdef5b
Merge branch 'main' into mvapi
adam2392 Jun 24, 2024
1218680
Fix changelog
adam2392 Jun 24, 2024
451112b
Merge branch 'mvapi' of https://github.com/adam2392/scikit-tree into …
adam2392 Jun 24, 2024
bba2e8f
Update submodule
adam2392 Jun 26, 2024
01cab41
WIP
adam2392 Jul 3, 2024
04daabe
WIP
adam2392 Jul 3, 2024
ac87b07
Working prototype for multiview oblique
adam2392 Jul 5, 2024
81493fe
Working prototype for multiview oblique
adam2392 Jul 5, 2024
307fb85
Merge branch 'main' into mvapi
adam2392 Jul 5, 2024
5fe234d
Merge branch 'mvapi' of https://github.com/adam2392/scikit-tree into …
adam2392 Jul 5, 2024
c6824b9
Add mvrf
adam2392 Jul 5, 2024
f546194
Enable multiview oblique rf tests
adam2392 Jul 5, 2024
a1c6313
Add to api.rst
adam2392 Jul 5, 2024
48997cf
Add to api.rst
adam2392 Jul 5, 2024
ef1dc6b
Fix unit tests
adam2392 Jul 5, 2024
9dbac9b
Fix unit tests
adam2392 Jul 5, 2024
7aeb6a6
Fix unit tests
adam2392 Jul 5, 2024
64edda3
Remove runtime checks in cython
adam2392 Jul 5, 2024
65b0f30
Fix docs
adam2392 Jul 5, 2024
952081d
Merging
adam2392 Jul 9, 2024
f1673ea
Merge branch 'main' into mvapi
adam2392 Jul 9, 2024
aafeb69
Removing
adam2392 Jul 9, 2024
a003d5a
Fix
adam2392 Jul 9, 2024
c1f9257
New submodule
adam2392 Jul 9, 2024
9f48e41
Merge branch 'mvapi' of https://github.com/adam2392/treeple into mvapi
adam2392 Jul 9, 2024
b0e0dff
Fix import
adam2392 Jul 9, 2024
392a729
Fix import
adam2392 Jul 9, 2024
347dddb
Fix import
adam2392 Jul 9, 2024
f673a0e
Merge branch 'main' into mvapi
PSSF23 Jul 15, 2024
61144d5
Merge branch 'main' into mvapi
sampan501 Jul 23, 2024
90320b1
Merge branch 'main' into mvapi
adam2392 Aug 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ how scikit-learn builds trees.
PatchObliqueRandomForestRegressor
HonestForestClassifier
MultiViewRandomForestClassifier
MultiViewObliqueRandomForestClassifier

.. currentmodule:: treeple.tree
.. autosummary::
Expand All @@ -77,6 +78,7 @@ how scikit-learn builds trees.
PatchObliqueDecisionTreeRegressor
HonestTreeClassifier
MultiViewDecisionTreeClassifier
MultiViewObliqueDecisionTreeClassifier

Unsupervised
------------
Expand Down
10 changes: 10 additions & 0 deletions doc/whats_new/v0.9.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,16 @@ Note that the previous version of the package will still be available under the
Changelog
---------

- |API| :class:`treeple.tree.MultiViewDecisionTreeClassifier` do not have the
``apply_max_features_per_feature_set`` argument anymore. Instead, the
``max_features`` argument is used to control the number of features to
consider when looking for the best split within each feature set explicitly.
By `Adam Li`_ :pr:`#265`.

- |Feature| :class:`treeple.tree.MultiViewObliqueDecisionTreeClassifier` is implemented
along with its forest version :class:`treeple.MultiViewObliqueRandomForestClassifier`.
By `Adam Li`_ :pr:`#265`.

- |API| Rename the package to ``treeple``. By `SUKI-O`_ (:pr:`#292`)
- |Fix| Fixed a bug in the predict_proba function of the :class:`treeple.HonestForestClassifier` where posteriors
estimated on empty leaf with ``ignore`` prior would result in ``np.nan``
Expand Down
145 changes: 137 additions & 8 deletions examples/splitters/plot_multiview_axis_aligned_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@
from matplotlib.colors import ListedColormap

from treeple._lib.sklearn.tree._criterion import Gini
from treeple.tree._oblique_splitter import MultiViewSplitterTester
from treeple.tree._oblique_splitter import MultiViewObliqueSplitterTester, MultiViewSplitterTester

criterion = Gini(1, np.array((0, 1)))
max_features = 5
max_features = 6
min_samples_leaf = 1
min_weight_leaf = 0.0
random_state = np.random.RandomState(10)
Expand All @@ -40,7 +40,7 @@
feature_set_ends = np.array([3, 5, 9], dtype=np.intp)
n_feature_sets = len(feature_set_ends)

max_features_per_set_ = None
max_features_per_set_ = np.array([2, 2, 2])
feature_combinations = 1
monotonic_cst = None
missing_value_feature_mask = None
Expand Down Expand Up @@ -99,7 +99,11 @@
for iend in feature_set_ends[1:]:
ax.axvline(iend - 0.5, color="black", linewidth=1)

ax.set(title="Sampled Projection Matrix", xlabel="Feature Index", ylabel="Projection Vector Index")
ax.set(
title="Sampled Projection Matrix: \nMultiview Axis Aligned Split with Equal Max_Features",
xlabel="Feature Index",
ylabel="Projection Vector Index",
)
ax.set_xticks(np.arange(feature_set_ends[-1]))
ax.set_yticks(np.arange(max_features))
ax.set_yticklabels(np.arange(max_features, dtype=int) + 1)
Expand All @@ -115,6 +119,7 @@
colorbar.set_label("Projection Weight (I.e. Sampled Feature From a Feature Set)")
colorbar.ax.set_yticklabels(["0", "1"])

fig.tight_layout()
plt.show()

# %%
Expand All @@ -127,9 +132,6 @@
# more than the second feature set, we can specify ``max_features_per_set`` as follows:
# ``max_features_per_set = [3, 1]``. This will sample from the first feature set three times
# and the second feature set once.
#
# .. note:: In practice, this is controlled by the ``apply_max_features_per_feature_set`` parameter
# in :class:`treeple.tree.MultiViewDecisionTreeClassifier`.

max_features_per_set_ = np.array([1, 2, 3], dtype=int)
max_features = np.sum(max_features_per_set_)
Expand Down Expand Up @@ -163,7 +165,11 @@
for iend in feature_set_ends[1:]:
ax.axvline(iend - 0.5, color="black", linewidth=1)

ax.set(title="Sampled Projection Matrix", xlabel="Feature Index", ylabel="Projection Vector Index")
ax.set(
title="Sampled Projection Matrix:\n Multiview Axis-aligned Splitter",
xlabel="Feature Index",
ylabel="Projection Vector Index",
)
ax.set_xticks(np.arange(feature_set_ends[-1]))
ax.set_yticks(np.arange(max_features))
ax.set_yticklabels(np.arange(max_features, dtype=int) + 1)
Expand All @@ -179,6 +185,129 @@
colorbar.set_label("Projection Weight (I.e. Sampled Feature From a Feature Set)")
colorbar.ax.set_yticklabels(["0", "1"])

fig.tight_layout()
plt.show()

# %%
# Sampling multiview oblique splits
# ---------------------------------
# The multi-view splitter can also sample oblique splits. The oblique splits are
# generated by sampling a projection matrix and then transforming the data into the
# projected space.

feature_combinations = 1.5
cross_feature_set_sampling = False
splitter = MultiViewObliqueSplitterTester(
criterion,
max_features,
min_samples_leaf,
min_weight_leaf,
random_state,
monotonic_cst,
feature_combinations,
feature_set_ends,
n_feature_sets,
max_features_per_set_,
cross_feature_set_sampling,
)
splitter.init_test(X, y, sample_weight, missing_value_feature_mask)

# sample the projection matrix
projection_matrix = splitter.sample_projection_matrix_py()
print(projection_matrix)

cmap = ListedColormap(["orange", "white", "green"])

# Create a heatmap to visualize the indices
fig, ax = plt.subplots(figsize=(6, 6))

ax.imshow(
projection_matrix, cmap=cmap, aspect=feature_set_ends[-1] / max_features, interpolation="none"
)
ax.axvline(feature_set_ends[0] - 0.5, color="black", linewidth=1, label="Feature Sets")
for iend in feature_set_ends[1:]:
ax.axvline(iend - 0.5, color="black", linewidth=1)

ax.set(
title="Sampled Projection Matrix:\n Multiview Oblique Splits W/O Cross-Feature Sampling",
xlabel="Feature Index",
ylabel="Projection Vector Index",
)
ax.set_xticks(np.arange(feature_set_ends[-1]))
ax.set_yticks(np.arange(max_features))
ax.set_yticklabels(np.arange(max_features, dtype=int) + 1)
ax.set_xticklabels(np.arange(feature_set_ends[-1], dtype=int) + 1)
ax.legend()

# Create a mappable object
sm = ScalarMappable(cmap=cmap)
sm.set_array([]) # You can set an empty array or values here

# Create a color bar with labels for each feature set
colorbar = fig.colorbar(sm, ax=ax, ticks=[0, 0.5, 1], format="%d")
colorbar.set_label("Projection Weight")
colorbar.ax.set_yticklabels(["-1", "0", "1"])

fig.tight_layout()
plt.show()

# %%
# Sampling multiview oblique splits with cross-feature-set sampling.
# Now, we can also sample across feature sets within each projection vector.

cross_feature_set_sampling = True
splitter = MultiViewObliqueSplitterTester(
criterion,
max_features,
min_samples_leaf,
min_weight_leaf,
random_state,
monotonic_cst,
feature_combinations,
feature_set_ends,
n_feature_sets,
max_features_per_set_,
cross_feature_set_sampling,
)
splitter.init_test(X, y, sample_weight, missing_value_feature_mask)

# sample the projection matrix
projection_matrix = splitter.sample_projection_matrix_py()
print(projection_matrix)

cmap = ListedColormap(["orange", "white", "green"])

# Create a heatmap to visualize the indices
fig, ax = plt.subplots(figsize=(6, 6))

ax.imshow(
projection_matrix, cmap=cmap, aspect=feature_set_ends[-1] / max_features, interpolation="none"
)
ax.axvline(feature_set_ends[0] - 0.5, color="black", linewidth=1, label="Feature Sets")
for iend in feature_set_ends[1:]:
ax.axvline(iend - 0.5, color="black", linewidth=1)

ax.set(
title="Sampled Projection Matrix:\n Multiview Oblique Splits W/ Cross-Feature Sampling",
xlabel="Feature Index",
ylabel="Projection Vector Index",
)
ax.set_xticks(np.arange(feature_set_ends[-1]))
ax.set_yticks(np.arange(max_features))
ax.set_yticklabels(np.arange(max_features, dtype=int) + 1)
ax.set_xticklabels(np.arange(feature_set_ends[-1], dtype=int) + 1)
ax.legend()

# Create a mappable object
sm = ScalarMappable(cmap=cmap)
sm.set_array([]) # You can set an empty array or values here

# Create a color bar with labels for each feature set
colorbar = fig.colorbar(sm, ax=ax, ticks=[0, 0.5, 1], format="%d")
colorbar.set_label("Projection Weight")
colorbar.ax.set_yticklabels(["-1", "0", "1"])

fig.tight_layout()
plt.show()

# %%
Expand Down
7 changes: 6 additions & 1 deletion treeple/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,11 @@
ExtraTreesRegressor,
)
from .neighbors import NearestNeighborsMetaEstimator
from .ensemble import ExtendedIsolationForest, MultiViewRandomForestClassifier
from .ensemble import (
ExtendedIsolationForest,
MultiViewRandomForestClassifier,
MultiViewObliqueRandomForestClassifier,
)
from .ensemble._unsupervised_forest import (
UnsupervisedRandomForest,
UnsupervisedObliqueRandomForest,
Expand Down Expand Up @@ -88,4 +92,5 @@
"ExtraTreesRegressor",
"ExtendedIsolationForest",
"MultiViewRandomForestClassifier",
"MultiViewObliqueRandomForestClassifier",
]
2 changes: 1 addition & 1 deletion treeple/ensemble/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from ._eiforest import ExtendedIsolationForest
from ._honest_forest import HonestForestClassifier
from ._multiview import MultiViewRandomForestClassifier
from ._multiview import MultiViewObliqueRandomForestClassifier, MultiViewRandomForestClassifier
from ._supervised_forest import (
ExtraObliqueRandomForestClassifier,
ExtraObliqueRandomForestRegressor,
Expand Down
Loading
Loading