Skip to content

Commit

Permalink
DOC add docstring to test
Browse files Browse the repository at this point in the history
  • Loading branch information
glemaitre committed Jul 2, 2024
1 parent c14c21a commit 9f9f750
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions hazardous/tests/test_survival_boost.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,15 @@ def test_survival_boost_incidence_and_survival(seed):

@pytest.mark.parametrize("seed", SEED_RANGE)
def test_survival_boost_predict_proba(seed):
"""Check the behaviour of the `predict_proba` method. Notably, we check:
- we raise an error when no `time_horizon` is set in the constructor nor
passed to `predict_proba`.
- we can pass `time_horizon` as a parameter to `predict_proba`.
- we can set `time_horizon` as a constructor parameter of `SurvivalBoost`.
- the `time_horizon` parameter passed to `predict_proba` is prioritized
over the constructor parameter.
"""
X, y = make_synthetic_competing_weibull(return_X_y=True, random_state=seed)
assert sorted(y["event"].unique()) == [0, 1, 2, 3]
n_events = 4
Expand All @@ -78,15 +87,12 @@ def test_survival_boost_predict_proba(seed):
est = SurvivalBoost(n_iter=3, show_progressbar=False, random_state=seed)
est.fit(X_train, y_train)

# Raise an error when `time_horizon` was not set in the constructor nor in
# passed to `predict_proba`.
err_msg = (
"The time_horizon parameter is required to use SurvivalBoost as a classifier."
)
with pytest.raises(ValueError, match=err_msg):
est.predict_proba(X_test)

# Passing `time_horizon` as a parameter to `predict_proba`.
time_horizon = 0
y_pred = est.predict_proba(X_test, time_horizon=time_horizon)
assert y_pred.shape == (n_events, X_test.shape[0], 1)
Expand All @@ -95,7 +101,6 @@ def test_survival_boost_predict_proba(seed):
y_pred = est.predict_proba(X_test, time_horizon=time_horizon)
assert y_pred.shape == (n_events, X_test.shape[0], len(time_horizon))

# Setting `time_horizon` as a constructor parameter of `SurvivalBoost`.
time_horizon = 0
est.set_params(time_horizon=time_horizon)
y_pred = est.predict_proba(X_test)
Expand All @@ -106,8 +111,6 @@ def test_survival_boost_predict_proba(seed):
y_pred = est.predict_proba(X_test)
assert y_pred.shape == (n_events, X_test.shape[0], len(time_horizon))

# Check that passing `time_horizon` as a parameter to `predict_proba` is
# prioritized over the constructor parameter.
time_horizon = [0, 10, 20]
y_pred = est.predict_proba(X_test, time_horizon=time_horizon)
assert y_pred.shape == (n_events, X_test.shape[0], len(time_horizon))
Expand Down

0 comments on commit 9f9f750

Please sign in to comment.