From 76d7c816b5af1b85b00433fc108ccefca9796c86 Mon Sep 17 00:00:00 2001 From: Vincent M Date: Fri, 26 Jan 2024 17:14:20 +0100 Subject: [PATCH] Add survival proba along with events in SurvTRACE (#45) --- hazardous/survtrace/_model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/hazardous/survtrace/_model.py b/hazardous/survtrace/_model.py index 46ddc01..3e0f71b 100644 --- a/hazardous/survtrace/_model.py +++ b/hazardous/survtrace/_model.py @@ -299,7 +299,9 @@ def predict_survival_function(self, X): return surv def predict_cumulative_incidence(self, X): - return 1 - self.predict_survival_function(X) + risks = 1 - self.predict_survival_function(X) + surv = (1 - risks.sum(axis=0))[None, :, :] + return np.concatenate([surv, risks], axis=0) class _SurvTRACEModule(nn.Module):