Skip to content

Commit

Permalink
Fix unleveled bug with bert (#138)
Browse files Browse the repository at this point in the history
  • Loading branch information
mirand863 authored Dec 5, 2024
1 parent a344014 commit 44d8e0f
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 18 deletions.
4 changes: 3 additions & 1 deletion hiclass/HierarchicalClassifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,9 @@ def _pre_fit(self, X, y, sample_weight):
)
else:
self.X_ = np.array(X)
self.y_ = np.array(y)
self.y_ = check_array(
make_leveled(y), dtype=None, ensure_2d=False, allow_nd=True
)

if sample_weight is not None:
self.sample_weight_ = _check_sample_weight(sample_weight, X)
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
"ray",
"shap==0.44.1",
"xarray==2023.1.0",
"bert-sklearn @ git+https://github.com/charles9n/bert-sklearn.git#egg=bert-sklearn",
],
}

Expand Down
38 changes: 37 additions & 1 deletion tests/test_LocalClassifierPerParentNode.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
import networkx as nx
import numpy as np
import pytest
from numpy.testing import assert_array_equal, assert_array_almost_equal
from bert_sklearn import BertClassifier
from numpy.testing import assert_array_almost_equal, assert_array_equal
from scipy.sparse import csr_matrix
from sklearn.exceptions import NotFittedError
from sklearn.linear_model import LogisticRegression
from sklearn.utils.estimator_checks import parametrize_with_checks
from sklearn.utils.validation import check_is_fitted

from hiclass import LocalClassifierPerParentNode
from hiclass._calibration.Calibrator import _Calibrator
from hiclass.HierarchicalClassifier import make_leveled
Expand Down Expand Up @@ -393,3 +395,37 @@ def test_fit_calibrate_predict_predict_proba_bert():
classifier.calibrate(x, y)
classifier.predict(x)
classifier.predict_proba(x)


# Note: bert only works with the local classifier per parent node.
# It does not have the attribute classes_, which is necessary
# for the local classifiers per level and per node.
def test_fit_bert():
    """Fit a per-parent-node classifier backed by BERT and check its predictions.

    Builds a ``LocalClassifierPerParentNode`` around a ``BertClassifier``
    with ``bert=True``, fits it on two text samples with two-level labels,
    verifies the estimator reports as fitted, and asserts that predicting
    on the training inputs returns the training labels.
    """
    texts = ["Batman", "rorschach"]
    labels = [
        ["Action", "The Dark Night"],
        ["Action", "Watchmen"],
    ]
    model = LocalClassifierPerParentNode(
        local_classifier=BertClassifier(),
        bert=True,
    )

    model.fit(texts, labels)
    check_is_fitted(model)

    predicted = model.predict(texts)
    assert_array_equal(labels, predicted)


def test_bert_unleveled():
    """Fit the BERT-backed per-parent-node classifier on labels of unequal depth.

    One training label has two levels and the other has a single level.
    The test checks that fitting succeeds, the estimator reports as fitted,
    and the predictions on the training inputs equal the expected paths.
    """
    texts = ["Batman", "Jaws"]
    # Label paths deliberately differ in length (two levels vs. one level).
    labels = [["Action", "The Dark Night"], ["Thriller"]]
    expected = [["Action", "The Dark Night"], ["Action", "The Dark Night"]]

    model = LocalClassifierPerParentNode(
        local_classifier=BertClassifier(),
        bert=True,
    )
    model.fit(texts, labels)
    check_is_fitted(model)

    predicted = model.predict(texts)
    assert_array_equal(expected, predicted)
17 changes: 1 addition & 16 deletions tests/test_LocalClassifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from sklearn.utils.validation import check_is_fitted

from hiclass import (
LocalClassifierPerNode,
LocalClassifierPerLevel,
LocalClassifierPerNode,
LocalClassifierPerParentNode,
)
from hiclass.ConstantClassifier import ConstantClassifier
Expand Down Expand Up @@ -75,21 +75,6 @@ def test_empty_levels(empty_levels, classifier):
assert_array_equal(ground_truth, predictions)


@pytest.mark.parametrize("classifier", classifiers)
def test_fit_bert(classifier):
    """Fit each parametrized classifier with ``bert=True`` on raw text inputs.

    A ``ConstantClassifier`` is used as the local classifier. The test
    checks that the model reports as fitted and that predictions on the
    training inputs equal the training labels.
    """
    texts = ["Text 1", "Text 2"]
    labels = ["a", "a"]

    model = classifier(
        local_classifier=ConstantClassifier(),
        bert=True,
    )
    model.fit(texts, labels)
    check_is_fitted(model)

    predicted = model.predict(texts)
    assert_array_equal(labels, predicted)


@pytest.mark.parametrize("classifier", classifiers)
def test_knn(classifier):
knn = KNeighborsClassifier(
Expand Down

0 comments on commit 44d8e0f

Please sign in to comment.