diff --git a/tests/assemblers/test_boosting_lightgbm.py b/tests/assemblers/test_boosting_lightgbm.py index 743c83c5..90d1dcdc 100644 --- a/tests/assemblers/test_boosting_lightgbm.py +++ b/tests/assemblers/test_boosting_lightgbm.py @@ -64,18 +64,18 @@ def test_regression(): expected = ast.BinNumExpr( ast.IfExpr( ast.CompExpr( - ast.FeatureRef(12), - ast.NumVal(9.725), + ast.FeatureRef(8), + ast.NumVal(1.0000000180025095e-35), ast.CompOpType.GT), - ast.NumVal(22.030283219508686), - ast.NumVal(23.27840740210207)), + ast.NumVal(156.64462853604854), + ast.NumVal(148.40956590509697)), ast.IfExpr( ast.CompExpr( - ast.FeatureRef(5), - ast.NumVal(6.8375), + ast.FeatureRef(2), + ast.NumVal(0.00780560282464346), ast.CompOpType.GT), - ast.NumVal(1.2777791671888081), - ast.NumVal(-0.2686772850549309)), + ast.NumVal(4.996373375352607), + ast.NumVal(-3.1063596100284814)), ast.BinNumOpType.ADD) assert utils.cmp_exprs(actual, expected) @@ -93,18 +93,18 @@ def test_regression_random_forest(): ast.BinNumExpr( ast.IfExpr( ast.CompExpr( - ast.FeatureRef(12), - ast.NumVal(9.605), + ast.FeatureRef(2), + ast.NumVal(0.00780560282464346), ast.CompOpType.GT), - ast.NumVal(17.398543657369768), - ast.NumVal(29.851408659650296)), + ast.NumVal(210.27118647591766), + ast.NumVal(120.45454548930705)), ast.IfExpr( ast.CompExpr( - ast.FeatureRef(5), - ast.NumVal(6.888), - ast.CompOpType.GT), - ast.NumVal(37.2235298136268), - ast.NumVal(19.948122884684025)), + ast.FeatureRef(2), + ast.NumVal(-0.007822672246629598), + ast.CompOpType.LTE), + ast.NumVal(114.24161077349474), + ast.NumVal(194.84868424576604)), ast.BinNumOpType.ADD), ast.NumVal(0.5), ast.BinNumOpType.MUL) @@ -159,19 +159,19 @@ def test_simple_sigmoid_output_transform(): ast.BinNumExpr( ast.IfExpr( ast.CompExpr( - ast.FeatureRef(12), - ast.NumVal(19.23), - ast.CompOpType.GT), - ast.NumVal(4.002437528537838), - ast.NumVal(4.090096709787509)), + ast.FeatureRef(8), + ast.NumVal(-0.0028501970360456344), + ast.CompOpType.LTE), + ast.NumVal(5.8325360677435345), + ast.NumVal(5.891973988308211)), ast.IfExpr( ast.CompExpr( - ast.FeatureRef(12), - ast.NumVal(14.895), - ast.CompOpType.GT), - ast.NumVal(-0.0417499606641773), - ast.NumVal(0.02069953712454655)), - ast.BinNumOpType.ADD)) + ast.FeatureRef(8), + ast.NumVal(-0.005612778088288765), + ast.CompOpType.LTE), + ast.NumVal(-0.027170480653266372), + ast.NumVal(0.026423953384869338)), + ast.BinNumOpType.ADD) assert utils.cmp_exprs(actual, expected) @@ -188,18 +188,18 @@ def test_log1p_exp_output_transform(): ast.BinNumExpr( ast.IfExpr( ast.CompExpr( - ast.FeatureRef(12), - ast.NumVal(19.23), - ast.CompOpType.GT), - ast.NumVal(0.6622623010380544), - ast.NumVal(0.6684065452877841)), + ast.FeatureRef(8), + ast.NumVal(-0.0028501970360456344), + ast.CompOpType.LTE), + ast.NumVal(0.693713164308067), + ast.NumVal(0.694435273176687)), ast.IfExpr( ast.CompExpr( - ast.FeatureRef(12), - ast.NumVal(15.145), - ast.CompOpType.GT), - ast.NumVal(0.1404975120475147), - ast.NumVal(0.14535916856709272)), + ast.FeatureRef(8), + ast.NumVal(-0.005612778088288765), + ast.CompOpType.LTE), + ast.NumVal(0.14830023030115363), + ast.NumVal(0.14902176200722345)), ast.BinNumOpType.ADD))) assert utils.cmp_exprs(actual, expected) @@ -216,18 +216,18 @@ def test_maybe_sqr_output_transform(): ast.BinNumExpr( ast.IfExpr( ast.CompExpr( - ast.FeatureRef(12), - ast.NumVal(9.725), + ast.FeatureRef(8), + ast.NumVal(1.0000000180025095e-35), ast.CompOpType.GT), - ast.NumVal(4.569350528717041), - ast.NumVal(4.663526439666748)), + ast.NumVal(12.094032478332519), + ast.NumVal(11.671793556213379)), ast.IfExpr( ast.CompExpr( - ast.FeatureRef(12), - ast.NumVal(11.655), - ast.CompOpType.GT), - ast.NumVal(-0.04462450027465819), - ast.NumVal(0.033305134773254384)), + ast.FeatureRef(8), + ast.NumVal(-0.00468258384360457), + ast.CompOpType.LTE), + ast.NumVal(-0.18738342285156248), + ast.NumVal(0.19059675216674812)), ast.BinNumOpType.ADD), to_reuse=True) @@ -250,18 +250,18 @@ def test_exp_output_transform(): ast.BinNumExpr( ast.IfExpr( ast.CompExpr( - ast.FeatureRef(12), - ast.NumVal(9.725), + ast.FeatureRef(8), + ast.NumVal(1.0000000180025095e-35), ast.CompOpType.GT), - ast.NumVal(3.1043985065105892), - ast.NumVal(3.1318783133960197)), + ast.NumVal(5.040167360736721), + ast.NumVal(5.013324518244505)), ast.IfExpr( ast.CompExpr( - ast.FeatureRef(5), - ast.NumVal(6.8375), + ast.FeatureRef(2), + ast.NumVal(0.00780560282464346), ast.CompOpType.GT), - ast.NumVal(0.028409619436010138), - ast.NumVal(-0.0060740730485278754)), + ast.NumVal(0.016475080997255653), + ast.NumVal(-0.010346335106608635)), ast.BinNumOpType.ADD)) assert utils.cmp_exprs(actual, expected) diff --git a/tests/test_cli.py b/tests/test_cli.py index 7b7681b0..4dc0fc0a 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -76,8 +76,8 @@ def test_generate_code(pickled_model): verify_python_model_is_expected( generated_code, - [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13], - expected_output=-44.40540274041321) + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + expected_output=11089.941259597403) def test_function_name(pickled_model): @@ -151,5 +151,5 @@ def test_unsupported_args_are_ignored(pickled_model): verify_python_model_is_expected( generated_code, - [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13], - expected_output=-44.40540274041321) + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + expected_output=11089.941259597403) diff --git a/tests/utils.py b/tests/utils.py index a1b393c3..eea5d062 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -265,7 +265,7 @@ def verify_python_model_is_expected(model_code, input, expected_output): context = {} exec(code, context) - print(context["result"]) + assert np.isclose(context["result"], expected_output)