diff --git a/m2cgen/assemblers/boosting.py b/m2cgen/assemblers/boosting.py
index e327b774..714da1b6 100644
--- a/m2cgen/assemblers/boosting.py
+++ b/m2cgen/assemblers/boosting.py
@@ -134,7 +134,7 @@ def _assemble_tree(self, tree):
         if "leaf" in tree:
             return ast.NumVal(tree["leaf"])
 
-        threshold = ast.NumVal(tree["split_condition"])
+        threshold = ast.NumVal(tree["split_condition"], dtype=np.float32)
         split = tree["split"]
         feature_idx = self._feature_name_to_idx.get(split, split)
         feature_ref = ast.FeatureRef(feature_idx)
diff --git a/m2cgen/ast.py b/m2cgen/ast.py
index 9b49936b..f4195af7 100644
--- a/m2cgen/ast.py
+++ b/m2cgen/ast.py
@@ -30,7 +30,9 @@ class NumExpr(Expr):
 
 
 class NumVal(NumExpr):
-    def __init__(self, value):
+    def __init__(self, value, dtype=None):
+        if dtype:
+            value = dtype(value)
         self.value = value
 
     def __str__(self):
diff --git a/tests/e2e/test_e2e.py b/tests/e2e/test_e2e.py
index b92185f8..8b6e8baa 100644
--- a/tests/e2e/test_e2e.py
+++ b/tests/e2e/test_e2e.py
@@ -35,50 +35,51 @@
 
 
 # Set of helper functions to make parametrization less verbose.
-def regression(model):
+def regression(model, test_fraction=0.02):
     return (
         model,
-        utils.get_regression_model_trainer(),
+        utils.get_regression_model_trainer(test_fraction),
         REGRESSION,
     )
 
 
-def classification(model):
+def classification(model, test_fraction=0.02):
     return (
         model,
-        utils.get_classification_model_trainer(),
+        utils.get_classification_model_trainer(test_fraction),
         CLASSIFICATION,
     )
 
 
-def classification_binary(model):
+def classification_binary(model, test_fraction=0.02):
     return (
         model,
-        utils.get_binary_classification_model_trainer(),
+        utils.get_binary_classification_model_trainer(test_fraction),
         CLASSIFICATION,
     )
 
 
-def regression_random(model):
+def regression_random(model, test_fraction=0.02):
     return (
         model,
-        utils.get_regression_random_data_model_trainer(0.01),
+        utils.get_regression_random_data_model_trainer(test_fraction),
         REGRESSION,
     )
 
 
-def classification_random(model):
+def classification_random(model, test_fraction=0.02):
     return (
         model,
-        utils.get_classification_random_data_model_trainer(0.01),
+        utils.get_classification_random_data_model_trainer(test_fraction),
         CLASSIFICATION,
     )
 
 
-def classification_binary_random(model):
+def classification_binary_random(model, test_fraction=0.02):
     return (
         model,
-        utils.get_classification_binary_random_data_model_trainer(0.01),
+        utils.get_classification_binary_random_data_model_trainer(
+            test_fraction),
         CLASSIFICATION,
     )
 
@@ -92,6 +93,8 @@ def classification_binary_random(model):
 FOREST_PARAMS = dict(n_estimators=10, random_state=RANDOM_SEED)
 XGBOOST_PARAMS = dict(base_score=0.6, n_estimators=10,
                       random_state=RANDOM_SEED)
+XGBOOST_HIST_PARAMS = dict(base_score=0.6, n_estimators=10,
+                           tree_method="hist", random_state=RANDOM_SEED)
 XGBOOST_PARAMS_LINEAR = dict(base_score=0.6, n_estimators=10,
                              feature_selector="shuffle",
                              booster="gblinear", random_state=RANDOM_SEED)
@@ -170,6 +173,14 @@ def classification_binary_random(model):
     classification(xgboost.XGBClassifier(**XGBOOST_PARAMS)),
     classification_binary(xgboost.XGBClassifier(**XGBOOST_PARAMS)),
 
+    # XGBoost (tree method "hist")
+    regression(xgboost.XGBRegressor(**XGBOOST_HIST_PARAMS),
+               test_fraction=0.2),
+    classification(xgboost.XGBClassifier(**XGBOOST_HIST_PARAMS),
+                   test_fraction=0.2),
+    classification_binary(xgboost.XGBClassifier(**XGBOOST_HIST_PARAMS),
+                          test_fraction=0.2),
+
     # XGBoost (LINEAR)
     regression(xgboost.XGBRegressor(**XGBOOST_PARAMS_LINEAR)),
     classification(xgboost.XGBClassifier(**XGBOOST_PARAMS_LINEAR)),
diff --git a/tests/test_ast.py b/tests/test_ast.py
index 36b9421a..32c4b4f4 100644
--- a/tests/test_ast.py
+++ b/tests/test_ast.py
@@ -1,3 +1,4 @@
+import numpy as np
 from m2cgen import ast
 
 
@@ -69,3 +70,9 @@ def test_count_all_exprs_types():
                           ast.BinNumOpType.MUL)
 
     assert ast.count_exprs(expr) == 27
+
+
+def test_num_val():
+    assert type(ast.NumVal(1).value) == int
+    assert type(ast.NumVal(1, dtype=np.float32).value) == np.float32
+    assert type(ast.NumVal(1, dtype=np.float64).value) == np.float64
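
Note: a minimal usage sketch of the new NumVal dtype argument, mirroring the added test_num_val test. The float32 cast in boosting.py appears to matter because XGBoost stores split thresholds at float32 precision (most visibly with tree_method="hist"), so generated code should compare features against the same rounded values; the threshold value below is made up for illustration.

    import numpy as np
    from m2cgen import ast

    # Without dtype the value is stored as-is (here a Python int).
    plain = ast.NumVal(1)
    assert type(plain.value) == int

    # With dtype the value is cast on construction, e.g. to float32 so a
    # threshold keeps the same rounding XGBoost used when splitting.
    threshold = ast.NumVal(1.2000000476837158, dtype=np.float32)
    assert type(threshold.value) == np.float32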