diff --git a/hazardous/tests/test_survival_boost.py b/hazardous/tests/test_survival_boost.py index e8b3518..53eb1f0 100644 --- a/hazardous/tests/test_survival_boost.py +++ b/hazardous/tests/test_survival_boost.py @@ -96,24 +96,29 @@ def test_survival_boost_predict_proba(seed): time_horizon = 0 y_pred = est.predict_proba(X_test, time_horizon=time_horizon) assert y_pred.shape == (n_events, X_test.shape[0], 1) + assert_allclose(y_pred.sum(axis=0), 1.0) time_horizon = [0, 10] y_pred = est.predict_proba(X_test, time_horizon=time_horizon) assert y_pred.shape == (n_events, X_test.shape[0], len(time_horizon)) + assert_allclose(y_pred.sum(axis=0), 1.0) time_horizon = 0 est.set_params(time_horizon=time_horizon) y_pred = est.predict_proba(X_test) assert y_pred.shape == (n_events, X_test.shape[0], 1) + assert_allclose(y_pred.sum(axis=0), 1.0) time_horizon = [0, 10] est.set_params(time_horizon=time_horizon) y_pred = est.predict_proba(X_test) assert y_pred.shape == (n_events, X_test.shape[0], len(time_horizon)) + assert_allclose(y_pred.sum(axis=0), 1.0) time_horizon = [0, 10, 20] y_pred = est.predict_proba(X_test, time_horizon=time_horizon) assert y_pred.shape == (n_events, X_test.shape[0], len(time_horizon)) + assert_allclose(y_pred.sum(axis=0), 1.0) @pytest.mark.parametrize("seed", SEED_RANGE)