diff --git a/hazardous/tests/test_survival_boost.py b/hazardous/tests/test_survival_boost.py index b0749ce..e8b3518 100644 --- a/hazardous/tests/test_survival_boost.py +++ b/hazardous/tests/test_survival_boost.py @@ -70,6 +70,15 @@ def test_survival_boost_incidence_and_survival(seed): @pytest.mark.parametrize("seed", SEED_RANGE) def test_survival_boost_predict_proba(seed): + """Check the behaviour of the `predict_proba` method. Notably, we check: + + - we raise an error when no `time_horizon` is set in the constructor nor + passed to `predict_proba`. + - we can pass `time_horizon` as a parameter to `predict_proba`. + - we can set `time_horizon` as a constructor parameter of `SurvivalBoost`. + - the `time_horizon` parameter passed to `predict_proba` is prioritized + over the constructor parameter. + """ X, y = make_synthetic_competing_weibull(return_X_y=True, random_state=seed) assert sorted(y["event"].unique()) == [0, 1, 2, 3] n_events = 4 @@ -78,15 +87,12 @@ def test_survival_boost_predict_proba(seed): est = SurvivalBoost(n_iter=3, show_progressbar=False, random_state=seed) est.fit(X_train, y_train) - # Raise an error when `time_horizon` was not set in the constructor nor in - # passed to `predict_proba`. err_msg = ( "The time_horizon parameter is required to use SurvivalBoost as a classifier." ) with pytest.raises(ValueError, match=err_msg): est.predict_proba(X_test) - # Passing `time_horizon` as a parameter to `predict_proba`. time_horizon = 0 y_pred = est.predict_proba(X_test, time_horizon=time_horizon) assert y_pred.shape == (n_events, X_test.shape[0], 1) @@ -95,7 +101,6 @@ def test_survival_boost_predict_proba(seed): y_pred = est.predict_proba(X_test, time_horizon=time_horizon) assert y_pred.shape == (n_events, X_test.shape[0], len(time_horizon)) - # Setting `time_horizon` as a constructor parameter of `SurvivalBoost`. time_horizon = 0 est.set_params(time_horizon=time_horizon) y_pred = est.predict_proba(X_test) @@ -106,8 +111,6 @@ def test_survival_boost_predict_proba(seed): y_pred = est.predict_proba(X_test) assert y_pred.shape == (n_events, X_test.shape[0], len(time_horizon)) - # Check that passing `time_horizon` as a parameter to `predict_proba` is - # prioritized over the constructor parameter. time_horizon = [0, 10, 20] y_pred = est.predict_proba(X_test, time_horizon=time_horizon) assert y_pred.shape == (n_events, X_test.shape[0], len(time_horizon))