Skip to content

Commit

Permalink
Fix #168. Enforce float32 type for split condition values for GBT models created using XGBoost (#188)
Browse files Browse the repository at this point in the history
  • Loading branch information
izeigerman authored Apr 5, 2020
1 parent 52c601b commit 04767ab
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 14 deletions.
2 changes: 1 addition & 1 deletion m2cgen/assemblers/boosting.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def _assemble_tree(self, tree):
if "leaf" in tree:
return ast.NumVal(tree["leaf"])

threshold = ast.NumVal(tree["split_condition"])
threshold = ast.NumVal(tree["split_condition"], dtype=np.float32)
split = tree["split"]
feature_idx = self._feature_name_to_idx.get(split, split)
feature_ref = ast.FeatureRef(feature_idx)
Expand Down
4 changes: 3 additions & 1 deletion m2cgen/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ class NumExpr(Expr):


class NumVal(NumExpr):
def __init__(self, value):
def __init__(self, value, dtype=None):
if dtype:
value = dtype(value)
self.value = value

def __str__(self):
Expand Down
35 changes: 23 additions & 12 deletions tests/e2e/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,50 +35,51 @@


# Set of helper functions to make parametrization less verbose.
def regression(model):
def regression(model, test_fraction=0.02):
return (
model,
utils.get_regression_model_trainer(),
utils.get_regression_model_trainer(test_fraction),
REGRESSION,
)


def classification(model):
def classification(model, test_fraction=0.02):
return (
model,
utils.get_classification_model_trainer(),
utils.get_classification_model_trainer(test_fraction),
CLASSIFICATION,
)


def classification_binary(model):
def classification_binary(model, test_fraction=0.02):
return (
model,
utils.get_binary_classification_model_trainer(),
utils.get_binary_classification_model_trainer(test_fraction),
CLASSIFICATION,
)


def regression_random(model):
def regression_random(model, test_fraction=0.02):
return (
model,
utils.get_regression_random_data_model_trainer(0.01),
utils.get_regression_random_data_model_trainer(test_fraction),
REGRESSION,
)


def classification_random(model):
def classification_random(model, test_fraction=0.02):
return (
model,
utils.get_classification_random_data_model_trainer(0.01),
utils.get_classification_random_data_model_trainer(test_fraction),
CLASSIFICATION,
)


def classification_binary_random(model):
def classification_binary_random(model, test_fraction=0.02):
return (
model,
utils.get_classification_binary_random_data_model_trainer(0.01),
utils.get_classification_binary_random_data_model_trainer(
test_fraction),
CLASSIFICATION,
)

Expand All @@ -92,6 +93,8 @@ def classification_binary_random(model):
FOREST_PARAMS = dict(n_estimators=10, random_state=RANDOM_SEED)
XGBOOST_PARAMS = dict(base_score=0.6, n_estimators=10,
random_state=RANDOM_SEED)
XGBOOST_HIST_PARAMS = dict(base_score=0.6, n_estimators=10,
tree_method="hist", random_state=RANDOM_SEED)
XGBOOST_PARAMS_LINEAR = dict(base_score=0.6, n_estimators=10,
feature_selector="shuffle", booster="gblinear",
random_state=RANDOM_SEED)
Expand Down Expand Up @@ -170,6 +173,14 @@ def classification_binary_random(model):
classification(xgboost.XGBClassifier(**XGBOOST_PARAMS)),
classification_binary(xgboost.XGBClassifier(**XGBOOST_PARAMS)),
# XGBoost (tree method "hist")
regression(xgboost.XGBRegressor(**XGBOOST_HIST_PARAMS),
test_fraction=0.2),
classification(xgboost.XGBClassifier(**XGBOOST_HIST_PARAMS),
test_fraction=0.2),
classification_binary(xgboost.XGBClassifier(**XGBOOST_HIST_PARAMS),
test_fraction=0.2),
# XGBoost (LINEAR)
regression(xgboost.XGBRegressor(**XGBOOST_PARAMS_LINEAR)),
classification(xgboost.XGBClassifier(**XGBOOST_PARAMS_LINEAR)),
Expand Down
7 changes: 7 additions & 0 deletions tests/test_ast.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
from m2cgen import ast


Expand Down Expand Up @@ -69,3 +70,9 @@ def test_count_all_exprs_types():
ast.BinNumOpType.MUL)

assert ast.count_exprs(expr) == 27


def test_num_val():
assert type(ast.NumVal(1).value) == int
assert type(ast.NumVal(1, dtype=np.float32).value) == np.float32
assert type(ast.NumVal(1, dtype=np.float64).value) == np.float64

0 comments on commit 04767ab

Please sign in to comment.