diff --git a/hazardous/_survival_boost.py b/hazardous/_survival_boost.py index b85bdef..0165c8c 100644 --- a/hazardous/_survival_boost.py +++ b/hazardous/_survival_boost.py @@ -280,17 +280,18 @@ def fit(self, X, y, times=None): def predict_proba(self, X, time_horizon=None): """Estimate the probability of all incidences for a specific time horizon. - See the docstring for the `time_horizon` parameter for more details. - - Returns a (n_events + 1)d array with shape (X.shape[0], n_events + 1). - The first column holds the survival probability to any event and others the - incicence probabilities for each event. - Parameters ---------- X : array-like of shape (n_samples, n_features) The input samples. - time_horizon : float or list, default=None + time_horizon : float or array-like, default=None + The time horizon at which to estimate the probabilities. If `None`, the + `time_horizon` passed at the constructor is used. + + Returns + ------- + incidence_probabilities : ndarray of shape (n_events + 1, n_samples, n_times) + The incidence probabilities for each event at the given time horizon. """ if time_horizon is None: if self.time_horizon is None: