Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Direct ranger survival support #122

Merged
merged 2 commits into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# hstats 1.2.1

## Usability

- `ranger()` survival models now also work out-of-the-box without passing a tailored prediction function. Use the new argument `survival` in `hstats()`, `ice()`, and `partial_dep()` to choose between cumulative hazards (`"chf"`, the default) and survival probabilities (`"prob"`) per time point.

## Other changes

- Fixed wrong ORCID of Michael.
Expand Down
6 changes: 4 additions & 2 deletions R/H2_overall.R
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,9 @@ h2_overall.default <- function(object, ...) {

#' @describeIn h2_overall Overall interaction strength from "hstats" object.
#' @export
h2_overall.hstats <- function(object, normalize = TRUE, squared = TRUE,
sort = TRUE, zero = TRUE, ...) {
h2_overall.hstats <- function(
object, normalize = TRUE, squared = TRUE, sort = TRUE, zero = TRUE, ...
) {
get_hstats_matrix(
statistic = "h2_overall",
object = object,
Expand Down Expand Up @@ -113,3 +114,4 @@ h2_overall_raw <- function(x) {

list(num = num, denom = x[["mean_f2"]])
}

10 changes: 6 additions & 4 deletions R/H2_pairwise.R
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,14 @@ h2_pairwise.default <- function(object, ...) {

#' @describeIn h2_pairwise Pairwise interaction strength from "hstats" object.
#' @export
h2_pairwise.hstats <- function(object, normalize = TRUE, squared = TRUE,
sort = TRUE, zero = TRUE, ...) {
h2_pairwise.hstats <- function(
object, normalize = TRUE, squared = TRUE, sort = TRUE, zero = TRUE, ...
) {
get_hstats_matrix(
statistic = "h2_pairwise",
object = object,
normalize = normalize,
squared = squared,
normalize = normalize,
squared = squared,
sort = sort,
zero = zero
)
Expand Down Expand Up @@ -122,3 +123,4 @@ h2_pairwise_raw <- function(x) {

list(num = num, denom = denom)
}

10 changes: 6 additions & 4 deletions R/H2_threeway.R
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,14 @@ h2_threeway.default <- function(object, ...) {

#' @describeIn h2_threeway Three-way interaction strength from "hstats" object.
#' @export
h2_threeway.hstats <- function(object, normalize = TRUE, squared = TRUE,
sort = TRUE, zero = TRUE, ...) {
h2_threeway.hstats <- function(
object, normalize = TRUE, squared = TRUE, sort = TRUE, zero = TRUE, ...
) {
get_hstats_matrix(
statistic = "h2_threeway",
object = object,
normalize = normalize,
squared = squared,
normalize = normalize,
squared = squared,
sort = sort,
zero = zero
)
Expand Down Expand Up @@ -109,3 +110,4 @@ h2_threeway_raw <- function(x) {

list(num = num, denom = denom)
}

60 changes: 37 additions & 23 deletions R/average_loss.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,18 @@ average_loss <- function(object, ...) {

#' @describeIn average_loss Default method.
#' @export
average_loss.default <- function(object, X, y,
pred_fun = stats::predict,
loss = "squared_error",
agg_cols = FALSE,
BY = NULL, by_size = 4L,
w = NULL, ...) {
average_loss.default <- function(
object,
X,
y,
pred_fun = stats::predict,
loss = "squared_error",
agg_cols = FALSE,
BY = NULL,
by_size = 4L,
w = NULL,
...
) {
stopifnot(
is.matrix(X) || is.data.frame(X),
is.function(pred_fun)
Expand Down Expand Up @@ -109,13 +115,18 @@ average_loss.default <- function(object, X, y,

#' @describeIn average_loss Method for "ranger" models.
#' @export
average_loss.ranger <- function(object, X, y,
pred_fun = function(m, X, ...)
stats::predict(m, X, ...)$predictions,
loss = "squared_error",
agg_cols = FALSE,
BY = NULL, by_size = 4L,
w = NULL, ...) {
average_loss.ranger <- function(
object,
X,
y,
pred_fun = function(m, X, ...)
stats::predict(m, X, ...)$predictions,
loss = "squared_error",
agg_cols = FALSE,
BY = NULL, by_size = 4L,
w = NULL,
...
) {
average_loss.default(
object = object,
X = X,
Expand All @@ -132,16 +143,18 @@ average_loss.ranger <- function(object, X, y,

#' @describeIn average_loss Method for DALEX "explainer".
#' @export
average_loss.explainer <- function(object,
X = object[["data"]],
y = object[["y"]],
pred_fun = object[["predict_function"]],
loss = "squared_error",
agg_cols = FALSE,
BY = NULL,
by_size = 4L,
w = object[["weights"]],
...) {
average_loss.explainer <- function(
object,
X = object[["data"]],
y = object[["y"]],
pred_fun = object[["predict_function"]],
loss = "squared_error",
agg_cols = FALSE,
BY = NULL,
by_size = 4L,
w = object[["weights"]],
...
) {
average_loss.default(
object = object[["model"]],
X = X,
Expand All @@ -155,3 +168,4 @@ average_loss.explainer <- function(object,
...
)
}

105 changes: 77 additions & 28 deletions R/hstats.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@
#' @param eps Threshold below which numerator values are set to 0. Default is 1e-10.
#' @param w Optional vector of case weights. Can also be a column name of `X`.
#' @param verbose Should a progress bar be shown? The default is `TRUE`.
#' @param survival Should cumulative hazards ("chf", default) or survival
#' probabilities ("prob") per time point be predicted? Only relevant for `ranger()`
#' survival models.
#' @param ... Additional arguments passed to `pred_fun(object, X, ...)`,
#' for instance `type = "response"` in a [glm()] model, or `reshape = TRUE` in a
#' multiclass XGBoost model.
Expand Down Expand Up @@ -140,12 +142,21 @@ hstats <- function(object, ...) {

#' @describeIn hstats Default hstats method.
#' @export
hstats.default <- function(object, X, v = NULL,
pred_fun = stats::predict,
pairwise_m = 5L, threeway_m = 0L,
approx = FALSE, grid_size = 50L,
n_max = 500L, eps = 1e-10,
w = NULL, verbose = TRUE, ...) {
hstats.default <- function(
object,
X,
v = NULL,
pred_fun = stats::predict,
pairwise_m = 5L,
threeway_m = 0L,
approx = FALSE,
grid_size = 50L,
n_max = 500L,
eps = 1e-10,
w = NULL,
verbose = TRUE,
...
) {
stopifnot(
is.matrix(X) || is.data.frame(X),
is.function(pred_fun)
Expand Down Expand Up @@ -277,12 +288,28 @@ hstats.default <- function(object, X, v = NULL,

#' @describeIn hstats Method for "ranger" models.
#' @export
hstats.ranger <- function(object, X, v = NULL,
pred_fun = function(m, X, ...) stats::predict(m, X, ...)$predictions,
pairwise_m = 5L, threeway_m = 0L,
approx = FALSE, grid_size = 50L,
n_max = 500L, eps = 1e-10,
w = NULL, verbose = TRUE, ...) {
hstats.ranger <- function(
object,
X,
v = NULL,
pred_fun = NULL,
pairwise_m = 5L,
threeway_m = 0L,
approx = FALSE,
grid_size = 50L,
n_max = 500L,
eps = 1e-10,
w = NULL,
verbose = TRUE,
survival = c("chf", "prob"),
...
) {
survival <- match.arg(survival)

if (is.null(pred_fun)) {
pred_fun <- pred_ranger
}

hstats.default(
object = object,
X = X,
Expand All @@ -296,19 +323,28 @@ hstats.ranger <- function(object, X, v = NULL,
eps = eps,
w = w,
verbose = verbose,
survival = survival,
...
)
}

#' @describeIn hstats Method for DALEX "explainer".
#' @export
hstats.explainer <- function(object, X = object[["data"]],
v = NULL,
pred_fun = object[["predict_function"]],
pairwise_m = 5L, threeway_m = 0L,
approx = FALSE, grid_size = 50L,
n_max = 500L, eps = 1e-10,
w = object[["weights"]], verbose = TRUE, ...) {
hstats.explainer <- function(
object,
X = object[["data"]],
v = NULL,
pred_fun = object[["predict_function"]],
pairwise_m = 5L,
threeway_m = 0L,
approx = FALSE,
grid_size = 50L,
n_max = 500L,
eps = 1e-10,
w = object[["weights"]],
verbose = TRUE,
...
) {
hstats.default(
object = object[["model"]],
X = X,
Expand Down Expand Up @@ -353,8 +389,9 @@ print.hstats <- function(x, ...) {
#' "h2", "h2_overall", "h2_pairwise", "h2_threeway", all of class "hstats_matrix".
#' @export
#' @seealso See [hstats()] for examples.
summary.hstats <- function(object, normalize = TRUE, squared = TRUE,
sort = TRUE, zero = TRUE, ...) {
summary.hstats <- function(
object, normalize = TRUE, squared = TRUE, sort = TRUE, zero = TRUE, ...
) {
args <- list(
object = object,
normalize = normalize,
Expand Down Expand Up @@ -407,11 +444,21 @@ print.hstats_summary <- function(x, ...) {
#' @returns An object of class "ggplot".
#' @export
#' @seealso See [hstats()] for examples.
plot.hstats <- function(x, which = 1:3, normalize = TRUE, squared = TRUE,
sort = TRUE, top_m = 15L, zero = TRUE,
fill = getOption("hstats.fill"),
viridis_args = getOption("hstats.viridis_args"),
facet_scales = "free", ncol = 2L, rotate_x = FALSE, ...) {
plot.hstats <- function(
x,
which = 1:3,
normalize = TRUE,
squared = TRUE,
sort = TRUE,
top_m = 15L,
zero = TRUE,
fill = getOption("hstats.fill"),
viridis_args = getOption("hstats.viridis_args"),
facet_scales = "free",
ncol = 2L,
rotate_x = FALSE,
...
) {
if (is.null(viridis_args)) {
viridis_args <- list()
}
Expand Down Expand Up @@ -477,8 +524,9 @@ plot.hstats <- function(x, which = 1:3, normalize = TRUE, squared = TRUE,
#' @returns
#' A list with a named list of feature combinations (pairs or triples), and
#' corresponding centered partial dependencies.
mway <- function(object, v, X, pred_fun = stats::predict, w = NULL,
way = 2L, verb = TRUE, ...) {
mway <- function(
object, v, X, pred_fun = stats::predict, w = NULL, way = 2L, verb = TRUE, ...
) {
combs <- utils::combn(v, way, simplify = FALSE)
n_combs <- length(combs)
F_way <- vector("list", length = n_combs)
Expand Down Expand Up @@ -528,3 +576,4 @@ get_v <- function(H, m) {
}
v[v %in% v_cand]
}

Loading
Loading