diff --git a/NEWS.md b/NEWS.md index dba2f86..8ba829b 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,9 @@ # hstats 1.2.1 +## Usability + +- `ranger()` survival models now also work out-of-the-box without passing a tailored prediction function. Use the new argument `survival = "chf"` in `hstats()`, `ice()`, and `partial_dep()` to distinguish cumulative hazards (default) and survival probabilities ("prob") per time point. + ## Other changes - Fixed wrong ORCID of Michael. diff --git a/R/H2_overall.R b/R/H2_overall.R index 7f8ad2e..bd5ec36 100644 --- a/R/H2_overall.R +++ b/R/H2_overall.R @@ -80,8 +80,9 @@ h2_overall.default <- function(object, ...) { #' @describeIn h2_overall Overall interaction strength from "hstats" object. #' @export -h2_overall.hstats <- function(object, normalize = TRUE, squared = TRUE, - sort = TRUE, zero = TRUE, ...) { +h2_overall.hstats <- function( + object, normalize = TRUE, squared = TRUE, sort = TRUE, zero = TRUE, ... + ) { get_hstats_matrix( statistic = "h2_overall", object = object, @@ -113,3 +114,4 @@ h2_overall_raw <- function(x) { list(num = num, denom = x[["mean_f2"]]) } + diff --git a/R/H2_pairwise.R b/R/H2_pairwise.R index 46ff02c..6c3f3de 100644 --- a/R/H2_pairwise.R +++ b/R/H2_pairwise.R @@ -81,13 +81,14 @@ h2_pairwise.default <- function(object, ...) { #' @describeIn h2_pairwise Pairwise interaction strength from "hstats" object. #' @export -h2_pairwise.hstats <- function(object, normalize = TRUE, squared = TRUE, - sort = TRUE, zero = TRUE, ...) { +h2_pairwise.hstats <- function( + object, normalize = TRUE, squared = TRUE, sort = TRUE, zero = TRUE, ... + ) { get_hstats_matrix( statistic = "h2_pairwise", object = object, - normalize = normalize, - squared = squared, + normalize = normalize, + squared = squared, sort = sort, zero = zero ) @@ -122,3 +123,4 @@ h2_pairwise_raw <- function(x) { list(num = num, denom = denom) } + diff --git a/R/H2_threeway.R b/R/H2_threeway.R index 62677ac..dd4dfe0 100644 --- a/R/H2_threeway.R +++ b/R/H2_threeway.R @@ -65,13 +65,14 @@ h2_threeway.default <- function(object, ...) { #' @describeIn h2_threeway Pairwise interaction strength from "hstats" object. #' @export -h2_threeway.hstats <- function(object, normalize = TRUE, squared = TRUE, - sort = TRUE, zero = TRUE, ...) { +h2_threeway.hstats <- function( + object, normalize = TRUE, squared = TRUE, sort = TRUE, zero = TRUE, ... + ) { get_hstats_matrix( statistic = "h2_threeway", object = object, - normalize = normalize, - squared = squared, + normalize = normalize, + squared = squared, sort = sort, zero = zero ) @@ -109,3 +110,4 @@ h2_threeway_raw <- function(x) { list(num = num, denom = denom) } + diff --git a/R/average_loss.R b/R/average_loss.R index a1e322a..5a41788 100644 --- a/R/average_loss.R +++ b/R/average_loss.R @@ -66,12 +66,18 @@ average_loss <- function(object, ...) { #' @describeIn average_loss Default method. #' @export -average_loss.default <- function(object, X, y, - pred_fun = stats::predict, - loss = "squared_error", - agg_cols = FALSE, - BY = NULL, by_size = 4L, - w = NULL, ...) { +average_loss.default <- function( + object, + X, + y, + pred_fun = stats::predict, + loss = "squared_error", + agg_cols = FALSE, + BY = NULL, + by_size = 4L, + w = NULL, + ... + ) { stopifnot( is.matrix(X) || is.data.frame(X), is.function(pred_fun) @@ -109,13 +115,18 @@ average_loss.default <- function(object, X, y, #' @describeIn average_loss Method for "ranger" models. #' @export -average_loss.ranger <- function(object, X, y, - pred_fun = function(m, X, ...) - stats::predict(m, X, ...)$predictions, - loss = "squared_error", - agg_cols = FALSE, - BY = NULL, by_size = 4L, - w = NULL, ...) { +average_loss.ranger <- function( + object, + X, + y, + pred_fun = function(m, X, ...) + stats::predict(m, X, ...)$predictions, + loss = "squared_error", + agg_cols = FALSE, + BY = NULL, by_size = 4L, + w = NULL, + ... + ) { average_loss.default( object = object, X = X, @@ -132,16 +143,18 @@ average_loss.ranger <- function(object, X, y, #' @describeIn average_loss Method for DALEX "explainer". #' @export -average_loss.explainer <- function(object, - X = object[["data"]], - y = object[["y"]], - pred_fun = object[["predict_function"]], - loss = "squared_error", - agg_cols = FALSE, - BY = NULL, - by_size = 4L, - w = object[["weights"]], - ...) { +average_loss.explainer <- function( + object, + X = object[["data"]], + y = object[["y"]], + pred_fun = object[["predict_function"]], + loss = "squared_error", + agg_cols = FALSE, + BY = NULL, + by_size = 4L, + w = object[["weights"]], + ... + ) { average_loss.default( object = object[["model"]], X = X, @@ -155,3 +168,4 @@ average_loss.explainer <- function(object, ... ) } + diff --git a/R/hstats.R b/R/hstats.R index 5cbb756..1ee1bd8 100644 --- a/R/hstats.R +++ b/R/hstats.R @@ -53,6 +53,8 @@ #' @param eps Threshold below which numerator values are set to 0. Default is 1e-10. #' @param w Optional vector of case weights. Can also be a column name of `X`. #' @param verbose Should a progress bar be shown? The default is `TRUE`. +#' @param survival Should cumulative hazards ("chf", default) or survival +#' probabilities ("prob") per time be predicted? Only in `ranger()` survival models. #' @param ... Additional arguments passed to `pred_fun(object, X, ...)`, #' for instance `type = "response"` in a [glm()] model, or `reshape = TRUE` in a #' multiclass XGBoost model. @@ -140,12 +142,21 @@ hstats <- function(object, ...) { #' @describeIn hstats Default hstats method. #' @export -hstats.default <- function(object, X, v = NULL, - pred_fun = stats::predict, - pairwise_m = 5L, threeway_m = 0L, - approx = FALSE, grid_size = 50L, - n_max = 500L, eps = 1e-10, - w = NULL, verbose = TRUE, ...) { +hstats.default <- function( + object, + X, + v = NULL, + pred_fun = stats::predict, + pairwise_m = 5L, + threeway_m = 0L, + approx = FALSE, + grid_size = 50L, + n_max = 500L, + eps = 1e-10, + w = NULL, + verbose = TRUE, + ... + ) { stopifnot( is.matrix(X) || is.data.frame(X), is.function(pred_fun) @@ -277,12 +288,28 @@ hstats.default <- function(object, X, v = NULL, #' @describeIn hstats Method for "ranger" models. #' @export -hstats.ranger <- function(object, X, v = NULL, - pred_fun = function(m, X, ...) stats::predict(m, X, ...)$predictions, - pairwise_m = 5L, threeway_m = 0L, - approx = FALSE, grid_size = 50L, - n_max = 500L, eps = 1e-10, - w = NULL, verbose = TRUE, ...) { +hstats.ranger <- function( + object, + X, + v = NULL, + pred_fun = NULL, + pairwise_m = 5L, + threeway_m = 0L, + approx = FALSE, + grid_size = 50L, + n_max = 500L, + eps = 1e-10, + w = NULL, + verbose = TRUE, + survival = c("chf", "prob"), + ... + ) { + survival <- match.arg(survival) + + if (is.null(pred_fun)) { + pred_fun <- pred_ranger + } + hstats.default( object = object, X = X, @@ -296,19 +323,28 @@ hstats.ranger <- function(object, X, v = NULL, eps = eps, w = w, verbose = verbose, + survival = survival, ... ) } #' @describeIn hstats Method for DALEX "explainer". #' @export -hstats.explainer <- function(object, X = object[["data"]], - v = NULL, - pred_fun = object[["predict_function"]], - pairwise_m = 5L, threeway_m = 0L, - approx = FALSE, grid_size = 50L, - n_max = 500L, eps = 1e-10, - w = object[["weights"]], verbose = TRUE, ...) { +hstats.explainer <- function( + object, + X = object[["data"]], + v = NULL, + pred_fun = object[["predict_function"]], + pairwise_m = 5L, + threeway_m = 0L, + approx = FALSE, + grid_size = 50L, + n_max = 500L, + eps = 1e-10, + w = object[["weights"]], + verbose = TRUE, + ... + ) { hstats.default( object = object[["model"]], X = X, @@ -353,8 +389,9 @@ print.hstats <- function(x, ...) { #' "h2", "h2_overall", "h2_pairwise", "h2_threeway", all of class "hstats_matrix". #' @export #' @seealso See [hstats()] for examples. -summary.hstats <- function(object, normalize = TRUE, squared = TRUE, - sort = TRUE, zero = TRUE, ...) { +summary.hstats <- function( + object, normalize = TRUE, squared = TRUE, sort = TRUE, zero = TRUE, ... + ) { args <- list( object = object, normalize = normalize, @@ -407,11 +444,21 @@ print.hstats_summary <- function(x, ...) { #' @returns An object of class "ggplot". #' @export #' @seealso See [hstats()] for examples. -plot.hstats <- function(x, which = 1:3, normalize = TRUE, squared = TRUE, - sort = TRUE, top_m = 15L, zero = TRUE, - fill = getOption("hstats.fill"), - viridis_args = getOption("hstats.viridis_args"), - facet_scales = "free", ncol = 2L, rotate_x = FALSE, ...) { +plot.hstats <- function( + x, + which = 1:3, + normalize = TRUE, + squared = TRUE, + sort = TRUE, + top_m = 15L, + zero = TRUE, + fill = getOption("hstats.fill"), + viridis_args = getOption("hstats.viridis_args"), + facet_scales = "free", + ncol = 2L, + rotate_x = FALSE, + ... + ) { if (is.null(viridis_args)) { viridis_args <- list() } @@ -477,8 +524,9 @@ plot.hstats <- function(x, which = 1:3, normalize = TRUE, squared = TRUE, #' @returns #' A list with a named list of feature combinations (pairs or triples), and #' corresponding centered partial dependencies. -mway <- function(object, v, X, pred_fun = stats::predict, w = NULL, - way = 2L, verb = TRUE, ...) { +mway <- function( + object, v, X, pred_fun = stats::predict, w = NULL, way = 2L, verb = TRUE, ... + ) { combs <- utils::combn(v, way, simplify = FALSE) n_combs <- length(combs) F_way <- vector("list", length = n_combs) @@ -528,3 +576,4 @@ get_v <- function(H, m) { } v[v %in% v_cand] } + diff --git a/R/ice.R b/R/ice.R index 8b36dba..919a25e 100644 --- a/R/ice.R +++ b/R/ice.R @@ -58,11 +58,20 @@ ice <- function(object, ...) { #' @describeIn ice Default method. #' @export -ice.default <- function(object, v, X, pred_fun = stats::predict, - BY = NULL, grid = NULL, grid_size = 49L, - trim = c(0.01, 0.99), - strategy = c("uniform", "quantile"), na.rm = TRUE, - n_max = 100L, ...) { +ice.default <- function( + object, + v, + X, + pred_fun = stats::predict, + BY = NULL, + grid = NULL, + grid_size = 49L, + trim = c(0.01, 0.99), + strategy = c("uniform", "quantile"), + na.rm = TRUE, + n_max = 100L, + ... + ) { stopifnot( is.matrix(X) || is.data.frame(X), is.function(pred_fun), @@ -150,12 +159,27 @@ ice.default <- function(object, v, X, pred_fun = stats::predict, #' @describeIn ice Method for "ranger" models. #' @export -ice.ranger <- function(object, v, X, - pred_fun = function(m, X, ...) stats::predict(m, X, ...)$predictions, - BY = NULL, grid = NULL, grid_size = 49L, - trim = c(0.01, 0.99), - strategy = c("uniform", "quantile"), na.rm = TRUE, - n_max = 100L, ...) { +ice.ranger <- function( + object, + v, + X, + pred_fun = NULL, + BY = NULL, + grid = NULL, + grid_size = 49L, + trim = c(0.01, 0.99), + strategy = c("uniform", "quantile"), + na.rm = TRUE, + n_max = 100L, + survival = c("chf", "prob"), + ... + ) { + survival <- match.arg(survival) + + if (is.null(pred_fun)) { + pred_fun <- pred_ranger + } + ice.default( object = object, v = v, @@ -168,18 +192,27 @@ ice.ranger <- function(object, v, X, strategy = strategy, na.rm = na.rm, n_max = n_max, + survival = survival, ... ) } #' @describeIn ice Method for DALEX "explainer". #' @export -ice.explainer <- function(object, v = v, X = object[["data"]], - pred_fun = object[["predict_function"]], - BY = NULL, grid = NULL, grid_size = 49L, - trim = c(0.01, 0.99), - strategy = c("uniform", "quantile"), na.rm = TRUE, - n_max = 100L, ...) { +ice.explainer <- function( + object, + v = v, + X = object[["data"]], + pred_fun = object[["predict_function"]], + BY = NULL, + grid = NULL, + grid_size = 49L, + trim = c(0.01, 0.99), + strategy = c("uniform", "quantile"), + na.rm = TRUE, + n_max = 100L, + ... + ) { ice.default( object = object[["model"]], v = v, @@ -226,12 +259,17 @@ print.ice <- function(x, n = 3L, ...) { #' @export #' @returns An object of class "ggplot". #' @seealso See [ice()] for examples. -plot.ice <- function(x, center = FALSE, alpha = 0.2, - color = getOption("hstats.color"), - swap_dim = FALSE, - viridis_args = getOption("hstats.viridis_args"), - facet_scales = "fixed", - rotate_x = FALSE, ...) { +plot.ice <- function( + x, + center = FALSE, + alpha = 0.2, + color = getOption("hstats.color"), + swap_dim = FALSE, + viridis_args = getOption("hstats.viridis_args"), + facet_scales = "fixed", + rotate_x = FALSE, + ... + ) { v <- x[["v"]] K <- x[["K"]] data <- x[["data"]] @@ -292,3 +330,4 @@ plot.ice <- function(x, center = FALSE, alpha = 0.2, } p } + diff --git a/R/losses.R b/R/losses.R index f9b40de..096b779 100644 --- a/R/losses.R +++ b/R/losses.R @@ -205,3 +205,4 @@ get_loss_fun <- function(loss) { stop("Unknown loss function.") ) } + diff --git a/R/onLoad.R b/R/onLoad.R index 2382dd1..b2506de 100644 --- a/R/onLoad.R +++ b/R/onLoad.R @@ -14,3 +14,4 @@ # Fix undefined global variable note utils::globalVariables(c("varying_", "value_", "id_", "variable_", "obs_", "error_")) + diff --git a/R/partial_dep.R b/R/partial_dep.R index 7e7ca03..04fa534 100644 --- a/R/partial_dep.R +++ b/R/partial_dep.R @@ -104,11 +104,22 @@ partial_dep <- function(object, ...) { #' @describeIn partial_dep Default method. #' @export -partial_dep.default <- function(object, v, X, pred_fun = stats::predict, - BY = NULL, by_size = 4L, grid = NULL, grid_size = 49L, - trim = c(0.01, 0.99), - strategy = c("uniform", "quantile"), na.rm = TRUE, - n_max = 1000L, w = NULL, ...) { +partial_dep.default <- function( + object, + v, + X, + pred_fun = stats::predict, + BY = NULL, + by_size = 4L, + grid = NULL, + grid_size = 49L, + trim = c(0.01, 0.99), + strategy = c("uniform", "quantile"), + na.rm = TRUE, + n_max = 1000L, + w = NULL, + ... + ) { stopifnot( is.matrix(X) || is.data.frame(X), is.function(pred_fun), @@ -189,12 +200,29 @@ partial_dep.default <- function(object, v, X, pred_fun = stats::predict, #' @describeIn partial_dep Method for "ranger" models. #' @export -partial_dep.ranger <- function(object, v, X, - pred_fun = function(m, X, ...) stats::predict(m, X, ...)$predictions, - BY = NULL, by_size = 4L, grid = NULL, grid_size = 49L, - trim = c(0.01, 0.99), - strategy = c("uniform", "quantile"), na.rm = TRUE, - n_max = 1000L, w = NULL, ...) { +partial_dep.ranger <- function( + object, + v, + X, + pred_fun = NULL, + BY = NULL, + by_size = 4L, + grid = NULL, + grid_size = 49L, + trim = c(0.01, 0.99), + strategy = c("uniform", "quantile"), + na.rm = TRUE, + n_max = 1000L, + w = NULL, + survival = c("chf", "prob"), + ... + ) { + survival <- match.arg(survival) + + if (is.null(pred_fun)) { + pred_fun <- pred_ranger + } + partial_dep.default( object = object, v = v, @@ -209,18 +237,29 @@ partial_dep.ranger <- function(object, v, X, na.rm = na.rm, n_max = n_max, w = w, + survival = survival, ... ) } #' @describeIn partial_dep Method for DALEX "explainer". #' @export -partial_dep.explainer <- function(object, v, X = object[["data"]], - pred_fun = object[["predict_function"]], - BY = NULL, by_size = 4L, grid = NULL, grid_size = 49L, - trim = c(0.01, 0.99), - strategy = c("uniform", "quantile"), na.rm = TRUE, - n_max = 1000L, w = object[["weights"]], ...) { +partial_dep.explainer <- function( + object, + v, + X = object[["data"]], + pred_fun = object[["predict_function"]], + BY = NULL, + by_size = 4L, + grid = NULL, + grid_size = 49L, + trim = c(0.01, 0.99), + strategy = c("uniform", "quantile"), + na.rm = TRUE, + n_max = 1000L, + w = object[["weights"]], + ... + ) { partial_dep.default( object = object[["model"]], v = v, @@ -278,13 +317,17 @@ print.partial_dep <- function(x, n = 3L, ...) { #' @export #' @returns An object of class "ggplot". #' @seealso See [partial_dep()] for examples. -plot.partial_dep <- function(x, - color = getOption("hstats.color"), - swap_dim = FALSE, - viridis_args = getOption("hstats.viridis_args"), - facet_scales = "fixed", - rotate_x = FALSE, show_points = TRUE, - d2_geom = c("tile", "point", "line"), ...) { +plot.partial_dep <- function( + x, + color = getOption("hstats.color"), + swap_dim = FALSE, + viridis_args = getOption("hstats.viridis_args"), + facet_scales = "fixed", + rotate_x = FALSE, + show_points = TRUE, + d2_geom = c("tile", "point", "line"), + ... + ) { d2_geom <- match.arg(d2_geom) v <- x[["v"]] by_name <- x[["by_name"]] @@ -371,3 +414,4 @@ plot.partial_dep <- function(x, } p } + diff --git a/R/pd_importance.R b/R/pd_importance.R index 0a30d33..d144370 100644 --- a/R/pd_importance.R +++ b/R/pd_importance.R @@ -54,8 +54,9 @@ pd_importance.default <- function(object, ...) { #' @describeIn pd_importance PD based feature importance from "hstats" object. #' @export -pd_importance.hstats <- function(object, normalize = TRUE, squared = TRUE, - sort = TRUE, zero = TRUE, ...) { +pd_importance.hstats <- function( + object, normalize = TRUE, squared = TRUE, sort = TRUE, zero = TRUE, ... + ) { get_hstats_matrix( statistic = "pd_importance", object = object, @@ -84,3 +85,4 @@ pd_importance_raw <- function(x) { num <- .zap_small(num, eps = x[["eps"]]) # Numeric precision list(num = num, denom = x[["mean_f2"]]) } + diff --git a/R/pd_raw.R b/R/pd_raw.R index 2549e87..147708d 100644 --- a/R/pd_raw.R +++ b/R/pd_raw.R @@ -14,8 +14,17 @@ #' @returns #' A matrix of partial dependence values (one column per prediction dimension, #' one row per grid row, in the same order as `grid`). -pd_raw <- function(object, v, X, grid, pred_fun = stats::predict, - w = NULL, compress_X = TRUE, compress_grid = TRUE, ...) { +pd_raw <- function( + object, + v, + X, + grid, + pred_fun = stats::predict, + w = NULL, + compress_X = TRUE, + compress_grid = TRUE, + ... + ) { # Try different compressions if (compress_X && length(v) == ncol(X) - 1L) { # Removes duplicates in X[, not_v] and compensates via w @@ -56,8 +65,9 @@ pd_raw <- function(object, v, X, grid, pred_fun = stats::predict, #' but replicated over `X`). #' @returns #' Either a vector/matrix of predictions or a list with predictions and grid. -ice_raw <- function(object, v, X, grid, pred_fun = stats::predict, - pred_only = TRUE, ...) { +ice_raw <- function( + object, v, X, grid, pred_fun = stats::predict, pred_only = TRUE, ... + ) { D1 <- length(v) == 1L n <- nrow(X) n_grid <- NROW(grid) @@ -161,3 +171,4 @@ ice_raw <- function(object, v, X, grid, pred_fun = stats::predict, out[["reindex"]] <- match(grid, ugrid) out } + diff --git a/R/perm_importance.R b/R/perm_importance.R index 20fdcfb..6d5a798 100644 --- a/R/perm_importance.R +++ b/R/perm_importance.R @@ -56,12 +56,21 @@ perm_importance <- function(object, ...) { #' @describeIn perm_importance Default method. #' @export -perm_importance.default <- function(object, X, y, v = NULL, - pred_fun = stats::predict, - loss = "squared_error", - m_rep = 4L, agg_cols = FALSE, - normalize = FALSE, n_max = 10000L, - w = NULL, verbose = TRUE, ...) { +perm_importance.default <- function( + object, + X, + y, + v = NULL, + pred_fun = stats::predict, + loss = "squared_error", + m_rep = 4L, + agg_cols = FALSE, + normalize = FALSE, + n_max = 10000L, + w = NULL, + verbose = TRUE, + ... + ) { stopifnot( is.matrix(X) || is.data.frame(X), is.function(pred_fun), @@ -205,13 +214,21 @@ perm_importance.default <- function(object, X, y, v = NULL, #' @describeIn perm_importance Method for "ranger" models. #' @export -perm_importance.ranger <- function(object, X, y, v = NULL, - pred_fun = function(m, X, ...) - stats::predict(m, X, ...)$predictions, - loss = "squared_error", m_rep = 4L, - agg_cols = FALSE, - normalize = FALSE, n_max = 10000L, - w = NULL, verbose = TRUE, ...) { +perm_importance.ranger <- function( + object, + X, + y, + v = NULL, + pred_fun = function(m, X, ...) stats::predict(m, X, ...)$predictions, + loss = "squared_error", + m_rep = 4L, + agg_cols = FALSE, + normalize = FALSE, + n_max = 10000L, + w = NULL, + verbose = TRUE, + ... + ) { perm_importance.default( object = object, X = X, @@ -231,19 +248,21 @@ perm_importance.ranger <- function(object, X, y, v = NULL, #' @describeIn perm_importance Method for DALEX "explainer". #' @export -perm_importance.explainer <- function(object, - X = object[["data"]], - y = object[["y"]], - v = NULL, - pred_fun = object[["predict_function"]], - loss = "squared_error", - m_rep = 4L, - agg_cols = FALSE, - normalize = FALSE, - n_max = 10000L, - w = object[["weights"]], - verbose = TRUE, - ...) { +perm_importance.explainer <- function( + object, + X = object[["data"]], + y = object[["y"]], + v = NULL, + pred_fun = object[["predict_function"]], + loss = "squared_error", + m_rep = 4L, + agg_cols = FALSE, + normalize = FALSE, + n_max = 10000L, + w = object[["weights"]], + verbose = TRUE, + ... + ) { perm_importance.default( object = object[["model"]], X = X, @@ -260,3 +279,4 @@ perm_importance.explainer <- function(object, ... ) } + diff --git a/R/utils_calculate.R b/R/utils_calculate.R index 8ee5882..3cd67b0 100644 --- a/R/utils_calculate.R +++ b/R/utils_calculate.R @@ -204,3 +204,4 @@ rep_rows <- function(x, i) { class(out) <- "data.frame" out } + diff --git a/R/utils_grid.R b/R/utils_grid.R index 5abc60e..dd370ba 100644 --- a/R/utils_grid.R +++ b/R/utils_grid.R @@ -35,8 +35,13 @@ #' x <- iris$Sepal.Width #' univariate_grid(x, grid_size = 5) # Uniform binning #' univariate_grid(x, grid_size = 5, strategy = "quantile") # Quantile -univariate_grid <- function(z, grid_size = 49L, trim = c(0.01, 0.99), - strategy = c("uniform", "quantile"), na.rm = TRUE) { +univariate_grid <- function( + z, + grid_size = 49L, + trim = c(0.01, 0.99), + strategy = c("uniform", "quantile"), + na.rm = TRUE + ) { strategy <- match.arg(strategy) uni <- unique(z) if (!is.numeric(z) || length(uni) <= grid_size) { @@ -80,8 +85,13 @@ univariate_grid <- function(z, grid_size = 49L, trim = c(0.01, 0.99), #' multivariate_grid(iris[1:2], grid_size = 4) #' multivariate_grid(iris$Species) # Works also in the univariate case #' @export -multivariate_grid <- function(x, grid_size = 49L, trim = c(0.01, 0.99), - strategy = c("uniform", "quantile"), na.rm = TRUE) { +multivariate_grid <- function( + x, + grid_size = 49L, + trim = c(0.01, 0.99), + strategy = c("uniform", "quantile"), + na.rm = TRUE + ) { strategy <- match.arg(strategy) p <- NCOL(x) if (p == 1L) { @@ -201,3 +211,4 @@ approx_matrix_or_df <- function(X, v = colnames(X), m = 50L) { } return(X) } + diff --git a/R/utils_input.R b/R/utils_input.R index f46257c..be5f15d 100644 --- a/R/utils_input.R +++ b/R/utils_input.R @@ -124,3 +124,30 @@ prepare_y <- function(y, X) { } list(y = y, y_names = y_names) } + +#' Predict Function for Ranger +#' +#' Internal function that prepares the predictions of different types of ranger models. +#' +#' @noRd +#' @keywords internal +#' @param model Fitted ranger model. +#' @param newdata Data to predict on. +#' @param survival Cumulative hazards "chf" (default) or probabilities "prob" per time. +#' @param ... Additional arguments passed to ranger's predict function. +#' +#' @returns A vector or matrix with predictions. +pred_ranger <- function(model, newdata, survival = c("chf", "prob"), ...) { + survival <- match.arg(survival) + + pred <- stats::predict(model, newdata, ...) + + if (model$treetype == "Survival") { + out <- if (survival == "chf") pred$chf else pred$survival + colnames(out) <- paste0("t", pred$unique.death.times) + } else { + out <- pred$predictions + } + return(out) +} + diff --git a/R/utils_plot.R b/R/utils_plot.R index a545f86..7c30b1d 100644 --- a/R/utils_plot.R +++ b/R/utils_plot.R @@ -99,3 +99,4 @@ mat2df <- function(mat, id = "Overall") { out <- cbind.data.frame(id_ = id, variable_ = factor(nm, levels = nm), mat) poor_man_stack(out, to_stack = pred_names) } + diff --git a/R/utils_statistics.R b/R/utils_statistics.R index 73655b4..591bc48 100644 --- a/R/utils_statistics.R +++ b/R/utils_statistics.R @@ -73,9 +73,14 @@ init_numerator <- function(x, way = 1L) { #' @param num Matrix with numerator statistics. #' @param denom Vector or matrix with denominator statistics. #' @returns Matrix of statistics, or `NULL`. -postprocess <- function(num, denom = rep(1, times = NCOL(num)), - normalize = TRUE, squared = TRUE, - sort = TRUE, zero = TRUE) { +postprocess <- function( + num, + denom = rep(1, times = NCOL(num)), + normalize = TRUE, + squared = TRUE, + sort = TRUE, + zero = TRUE + ) { stopifnot( is.matrix(num), is.matrix(denom) || is.vector(denom), # already covered by the next condition @@ -113,8 +118,9 @@ postprocess <- function(num, denom = rep(1, times = NCOL(num)), #' @param statistic Name of statistic as stored in "hstats" object. #' @inheritParams h2_overall #' @returns A character string. -get_hstats_matrix <- function(statistic, object, normalize = TRUE, squared = TRUE, - sort = TRUE, zero = TRUE, ...) { +get_hstats_matrix <- function( + statistic, object, normalize = TRUE, squared = TRUE, sort = TRUE, zero = TRUE, ... + ) { s <- object[[statistic]] if (!is.null(s)) { M <- postprocess( @@ -306,13 +312,18 @@ dimnames.hstats_matrix <- function(x) { #' @param ... Passed to [ggplot2::geom_bar()]. #' @export #' @returns An object of class "ggplot". -plot.hstats_matrix <- function(x, top_m = 15L, - fill = getOption("hstats.fill"), - swap_dim = FALSE, - viridis_args = getOption("hstats.viridis_args"), - facet_scales = "fixed", - ncol = 2L, rotate_x = FALSE, - err_type = c("SE", "SD", "No"), ...) { +plot.hstats_matrix <- function( + x, + top_m = 15L, + fill = getOption("hstats.fill"), + swap_dim = FALSE, + viridis_args = getOption("hstats.viridis_args"), + facet_scales = "fixed", + ncol = 2L, + rotate_x = FALSE, + err_type = c("SE", "SD", "No"), + ... + ) { err_type <- match.arg(err_type) M <- x[["M"]] if (is.null(M)) { diff --git a/man/hstats.Rd b/man/hstats.Rd index 6cbba29..40b9d3b 100644 --- a/man/hstats.Rd +++ b/man/hstats.Rd @@ -29,7 +29,7 @@ hstats(object, ...) object, X, v = NULL, - pred_fun = function(m, X, ...) stats::predict(m, X, ...)$predictions, + pred_fun = NULL, pairwise_m = 5L, threeway_m = 0L, approx = FALSE, @@ -38,6 +38,7 @@ hstats(object, ...) eps = 1e-10, w = NULL, verbose = TRUE, + survival = c("chf", "prob"), ... ) @@ -103,6 +104,9 @@ selected from \code{X}. In this case, set a random seed for reproducibility.} \item{w}{Optional vector of case weights. Can also be a column name of \code{X}.} \item{verbose}{Should a progress bar be shown? The default is \code{TRUE}.} + +\item{survival}{Should cumulative hazards ("chf", default) or survival +probabilities ("prob") per time be predicted? Only in \code{ranger()} survival models.} } \value{ An object of class "hstats" containing these elements: diff --git a/man/ice.Rd b/man/ice.Rd index a27a158..d858c0c 100644 --- a/man/ice.Rd +++ b/man/ice.Rd @@ -28,7 +28,7 @@ ice(object, ...) object, v, X, - pred_fun = function(m, X, ...) stats::predict(m, X, ...)$predictions, + pred_fun = NULL, BY = NULL, grid = NULL, grid_size = 49L, @@ -36,6 +36,7 @@ ice(object, ...) strategy = c("uniform", "quantile"), na.rm = TRUE, n_max = 100L, + survival = c("chf", "prob"), ... ) @@ -95,6 +96,9 @@ Either "uniform" or "quantile", see description of \code{\link[=univariate_grid] \item{n_max}{If \code{X} has more than \code{n_max} rows, a random sample of \code{n_max} rows is selected from \code{X}. In this case, set a random seed for reproducibility.} + +\item{survival}{Should cumulative hazards ("chf", default) or survival +probabilities ("prob") per time be predicted? Only in \code{ranger()} survival models.} } \value{ An object of class "ice" containing these elements: diff --git a/man/partial_dep.Rd b/man/partial_dep.Rd index 779cca8..c93e922 100644 --- a/man/partial_dep.Rd +++ b/man/partial_dep.Rd @@ -30,7 +30,7 @@ partial_dep(object, ...) object, v, X, - pred_fun = function(m, X, ...) stats::predict(m, X, ...)$predictions, + pred_fun = NULL, BY = NULL, by_size = 4L, grid = NULL, @@ -40,6 +40,7 @@ partial_dep(object, ...) na.rm = TRUE, n_max = 1000L, w = NULL, + survival = c("chf", "prob"), ... ) @@ -109,6 +110,9 @@ Either "uniform" or "quantile", see description of \code{\link[=univariate_grid] selected from \code{X}. In this case, set a random seed for reproducibility.} \item{w}{Optional vector of case weights. Can also be a column name of \code{X}.} + +\item{survival}{Should cumulative hazards ("chf", default) or survival +probabilities ("prob") per time be predicted? Only in \code{ranger()} survival models.} } \value{ An object of class "partial_dep" containing these elements: