Skip to content

Commit

Permalink
TST check that probability sum to one for all events
Browse files Browse the repository at this point in the history
  • Loading branch information
glemaitre committed Jul 2, 2024
1 parent 9f9f750 commit 7595c7f
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions hazardous/tests/test_survival_boost.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,24 +96,29 @@ def test_survival_boost_predict_proba(seed):
time_horizon = 0
y_pred = est.predict_proba(X_test, time_horizon=time_horizon)
assert y_pred.shape == (n_events, X_test.shape[0], 1)
assert_allclose(y_pred.sum(axis=0), 1.0)

time_horizon = [0, 10]
y_pred = est.predict_proba(X_test, time_horizon=time_horizon)
assert y_pred.shape == (n_events, X_test.shape[0], len(time_horizon))
assert_allclose(y_pred.sum(axis=0), 1.0)

time_horizon = 0
est.set_params(time_horizon=time_horizon)
y_pred = est.predict_proba(X_test)
assert y_pred.shape == (n_events, X_test.shape[0], 1)
assert_allclose(y_pred.sum(axis=0), 1.0)

time_horizon = [0, 10]
est.set_params(time_horizon=time_horizon)
y_pred = est.predict_proba(X_test)
assert y_pred.shape == (n_events, X_test.shape[0], len(time_horizon))
assert_allclose(y_pred.sum(axis=0), 1.0)

time_horizon = [0, 10, 20]
y_pred = est.predict_proba(X_test, time_horizon=time_horizon)
assert y_pred.shape == (n_events, X_test.shape[0], len(time_horizon))
assert_allclose(y_pred.sum(axis=0), 1.0)


@pytest.mark.parametrize("seed", SEED_RANGE)
Expand Down

0 comments on commit 7595c7f

Please sign in to comment.