diff --git a/m2cgen/assemblers/boosting.py b/m2cgen/assemblers/boosting.py index f9261962..4e3d7b7f 100644 --- a/m2cgen/assemblers/boosting.py +++ b/m2cgen/assemblers/boosting.py @@ -9,7 +9,7 @@ class BaseBoostingAssembler(ModelAssembler): classifier_name = None - def __init__(self, model, trees, base_score=0): + def __init__(self, model, trees, base_score=0, tree_limit=None): super().__init__(model) self.all_trees = trees self._base_score = base_score @@ -17,6 +17,9 @@ def __init__(self, model, trees, base_score=0): self._output_size = 1 self._is_classification = False + assert tree_limit is None or tree_limit > 0, "Unexpected tree limit" + self._tree_limit = tree_limit + model_class_name = type(model).__name__ if model_class_name == self.classifier_name: self._is_classification = True @@ -34,6 +37,9 @@ def assemble(self): self.all_trees, self._base_score) def _assemble_single_output(self, trees, base_score=0): + if self._tree_limit: + trees = trees[:self._tree_limit] + trees_ast = [self._assemble_tree(t) for t in trees] result_ast = utils.apply_op_to_expressions( ast.BinNumOpType.ADD, @@ -83,16 +89,14 @@ def __init__(self, model): } model_dump = model.get_booster().get_dump(dump_format="json") - - # Respect XGBoost ntree_limit - ntree_limit = getattr(model, "best_ntree_limit", 0) - - if ntree_limit > 0: - model_dump = model_dump[:ntree_limit] - trees = [json.loads(d) for d in model_dump] - super().__init__(model, trees, base_score=model.base_score) + # Limit the number of trees that should be used for + # assembling (if applicable). + best_ntree_limit = getattr(model, "best_ntree_limit", None) + + super().__init__(model, trees, base_score=model.base_score, + tree_limit=best_ntree_limit) def _assemble_tree(self, tree): if "leaf" in tree: diff --git a/tests/assemblers/test_xgboost.py b/tests/assemblers/test_xgboost.py index ebbb83b9..d1a39c6a 100644 --- a/tests/assemblers/test_xgboost.py +++ b/tests/assemblers/test_xgboost.py @@ -147,3 +147,84 @@ def test_regression_best_ntree_limit(): ast.BinNumOpType.ADD)) assert utils.cmp_exprs(actual, expected) + + +def test_multi_class_best_ntree_limit(): + base_score = 0.5 + estimator = xgboost.XGBClassifier(n_estimators=100, random_state=1, + max_depth=1, base_score=base_score) + + estimator.best_ntree_limit = 1 + + utils.train_model_classification(estimator) + + assembler = assemblers.XGBoostModelAssembler(estimator) + actual = assembler.assemble() + + estimator_exp_class1 = ast.ExpExpr( + ast.SubroutineExpr( + ast.BinNumExpr( + ast.NumVal(0.5), + ast.IfExpr( + ast.CompExpr( + ast.FeatureRef(2), + ast.NumVal(2.5999999), + ast.CompOpType.GTE), + ast.NumVal(-0.0731707439), + ast.NumVal(0.142857149)), + ast.BinNumOpType.ADD)), + to_reuse=True) + + estimator_exp_class2 = ast.ExpExpr( + ast.SubroutineExpr( + ast.BinNumExpr( + ast.NumVal(0.5), + ast.IfExpr( + ast.CompExpr( + ast.FeatureRef(2), + ast.NumVal(2.5999999), + ast.CompOpType.GTE), + ast.NumVal(0.0341463387), + ast.NumVal(-0.0714285821)), + ast.BinNumOpType.ADD)), + to_reuse=True) + + estimator_exp_class3 = ast.ExpExpr( + ast.SubroutineExpr( + ast.BinNumExpr( + ast.NumVal(0.5), + ast.IfExpr( + ast.CompExpr( + ast.FeatureRef(2), + ast.NumVal(4.85000038), + ast.CompOpType.GTE), + ast.NumVal(0.129441619), + ast.NumVal(-0.0681440532)), + ast.BinNumOpType.ADD)), + to_reuse=True) + + exp_sum = ast.BinNumExpr( + ast.BinNumExpr( + estimator_exp_class1, + estimator_exp_class2, + ast.BinNumOpType.ADD), + estimator_exp_class3, + ast.BinNumOpType.ADD, + to_reuse=True) + + expected = ast.VectorVal([ + ast.BinNumExpr( + estimator_exp_class1, + exp_sum, + ast.BinNumOpType.DIV), + ast.BinNumExpr( + estimator_exp_class2, + exp_sum, + ast.BinNumOpType.DIV), + ast.BinNumExpr( + estimator_exp_class3, + exp_sum, + ast.BinNumOpType.DIV) + ]) + + assert utils.cmp_exprs(actual, expected)