diff --git a/src/gurobi_ml/modeling/decision_tree/decision_tree_model.py b/src/gurobi_ml/modeling/decision_tree/decision_tree_model.py index c8f1e7cb..62b2e182 100644 --- a/src/gurobi_ml/modeling/decision_tree/decision_tree_model.py +++ b/src/gurobi_ml/modeling/decision_tree/decision_tree_model.py @@ -126,6 +126,8 @@ def _leaf_formulation( # We should attain 1 leaf gp_model.addConstr(leafs_vars.sum(axis=1) == 1) + gp_model.addConstr(output <= np.max(tree["value"], axis=0)) + gp_model.addConstr(output >= np.min(tree["value"], axis=0)) if verbose: timer.timing(f"Added {nex} linear constraints") @@ -207,8 +209,8 @@ def _paths_formulation(gp_model, _input, output, tree, epsilon, _name_var): for i in range(outdim) ) - output.LB = np.min(tree.value) - output.UB = np.max(tree.value) + gp_model.addConstr(output <= np.max(tree["value"], axis=0)) + gp_model.addConstr(output >= np.min(tree["value"], axis=0)) class AbstractTreeEstimator(AbstractPredictorConstr): diff --git a/tests/test_lightgbm/test_lightgbm_formulations.py b/tests/test_lightgbm/test_lightgbm_formulations.py index 8d264712..7eda3afe 100644 --- a/tests/test_lightgbm/test_lightgbm_formulations.py +++ b/tests/test_lightgbm/test_lightgbm_formulations.py @@ -14,7 +14,7 @@ class TestLGBMhModel(FixedRegressionModel): basedir = os.path.join(os.path.dirname(__file__), "..", "predictors") - def test_diabetes_xgboost_pairs(self): + def test_diabetes_lightgbm_pairs(self): data = datasets.load_diabetes() X = data["data"] y = data["target"] @@ -25,7 +25,7 @@ def test_diabetes_xgboost_pairs(self): self.do_one_case(one_case, X, 6, "pairs", float_type=np.float32) - def test_diabetes_xgboost_pairs_pipeline(self): + def test_diabetes_lightgbm_pairs_pipeline(self): data = datasets.load_diabetes() X = data["data"] y = data["target"] @@ -37,7 +37,7 @@ def test_diabetes_xgboost_pairs_pipeline(self): self.do_one_case(one_case, X, 6, "pairs", float_type=np.float32) - def test_diabetes_xgboost_all(self): + def test_diabetes_lightgbm_all(self): data = datasets.load_diabetes() X = data["data"] y = data["target"]