Commit

Merge pull request #83 from mayer79/split_quant_approx
split quant_approx into two arguments
mayer79 authored Oct 20, 2023
2 parents 7a530da + d605707 commit aba883f
Showing 10 changed files with 168 additions and 144 deletions.
2 changes: 1 addition & 1 deletion NEWS.md
@@ -2,7 +2,7 @@

## Major changes

- `hstats()` has received an argument `quant_approx` to speed-up calculations by quantile binning. Dense numeric variables are replaced by midpoints of `quant_approx + 1` uniform quantiles. By default, the value is `NULL` (no approximation). Even relatively high values like 50 will bring a massive speed-up for dense features, mainly for the one-way calculations. Use this option when calculations are slow, or when you want to increase `n_max`.
- Quantile approximation: `hstats()` now has the option `approx = FALSE`. Set to `TRUE` to replace values of dense numeric columns by `grid_size = 50` quantile midpoints. This will bring a massive speed-up for one-way calculations. Use this option when one-way calculations are slow, or when you want to increase `n_max`.
- `hstats()`: `n_max` has been increased from 300 to 500 rows. This will make estimates of H statistics more stable at the price of longer run time. Reduce to 300 for the old behaviour.
- `hstats()`: Three-way interactions are no longer calculated by default. Set `threeway_m` to 5 for the old behaviour.
- Revised plots: The colors and color palettes have changed and can now also be controlled via global options. For instance, to change the fill color of all bars, set `options(hstats.fill = new value)`. Value labels are clearer, and there are more options. Varying color/fill scales now use viridis (inferno). This can be modified on the fly or via `options(hstats.viridis_args = list(...))`.
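
To make the quantile approximation entry above concrete, here is a minimal base-R sketch of what replacing a dense numeric column by quantile midpoints amounts to. It follows the logic of the internal `approx_vector()` helper in this commit; the simulated vector `x` and the name `x_approx` are illustrative assumptions, not part of the package.

```r
set.seed(1)
x <- rnorm(1000)    # a dense numeric feature (illustrative data)
grid_size <- 50L    # same default as the new `grid_size` argument

# grid_size + 1 evenly spaced probabilities delimit grid_size quantile bins
p <- seq(0, 1, length.out = grid_size + 1L)
q <- unique(stats::quantile(x, probs = p, names = FALSE, na.rm = TRUE))

# Midpoint of each bin; every value is then replaced by the midpoint of its bin
mids <- (q[-length(q)] + q[-1L]) / 2
x_approx <- mids[findInterval(x, q, rightmost.closed = TRUE)]

# The approximated feature has at most grid_size distinct values
length(unique(x))
length(unique(x_approx))
```

Fewer distinct values mean fewer distinct grid points at which one-way partial dependence has to be evaluated, which is presumably where the speed-up described above comes from.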
97 changes: 31 additions & 66 deletions R/hstats.R
@@ -41,14 +41,16 @@
#' @param threeway_m Like `pairwise_m`, but controls the feature count for
#' three-way interactions. Cannot be larger than `pairwise_m`.
#' To save computation time, the default is 0.
#' @param quant_approx Integer. Dense numeric variables in `X` are replaced by midpoints
#' of `quant_approx + 1` uniform quantiles. By default, the value is `NULL`
#' (no approximation). Even relatively high values like 50 will bring a massive
#' speed-up for dense features, mainly for one-way statistics.
#' Note that the quantiles are calculated after subsampling to `n_max` rows.
#' @param eps Threshold below which numerator values are set to 0. Default is 1e-10.
#' @param approx Should quantile approximation be applied to dense numeric features?
#' The default is `FALSE`. Setting this option to `TRUE` brings a massive speed-up
#' for one-way calculations. It can, e.g., be used when the number of features is
#' very large.
#' @param grid_size Integer controlling the number of quantile midpoints used to
#' approximate dense numerics. The quantile midpoints are calculated after
#' subsampling via `n_max`. Only relevant if `approx = TRUE`.
#' @param n_max If `X` has more than `n_max` rows, a random sample of `n_max` rows is
#' selected from `X`. In this case, set a random seed for reproducibility.
#' @param eps Threshold below which numerator values are set to 0. Default is 1e-10.
#' @param w Optional vector of case weights. Can also be a column name of `X`.
#' @param verbose Should a progress bar be shown? The default is `TRUE`.
#' @param ... Additional arguments passed to `pred_fun(object, X, ...)`,
@@ -141,8 +143,9 @@ hstats <- function(object, ...) {
hstats.default <- function(object, X, v = NULL,
pred_fun = stats::predict,
pairwise_m = 5L, threeway_m = 0L,
quant_approx = NULL, eps = 1e-10,
n_max = 500L, w = NULL, verbose = TRUE, ...) {
approx = FALSE, grid_size = 50L,
n_max = 500L, eps = 1e-10,
w = NULL, verbose = TRUE, ...) {
stopifnot(
is.matrix(X) || is.data.frame(X),
is.function(pred_fun)
@@ -180,8 +183,8 @@ hstats.default <- function(object, X, v = NULL,
}

# Quantile approximation to speed things up for dense features
if (!is.null(quant_approx)) {
X <- approx_matrix_or_df(X = X, v = v, m = quant_approx)
if (isTRUE(approx)) {
X <- approx_matrix_or_df(X = X, v = v, m = grid_size)
}

# Predictions ("F" in Friedman and Popescu) always calculated (cheap)
@@ -277,18 +280,20 @@ hstats.default <- function(object, X, v = NULL,
hstats.ranger <- function(object, X, v = NULL,
pred_fun = function(m, X, ...) stats::predict(m, X, ...)$predictions,
pairwise_m = 5L, threeway_m = 0L,
quant_approx = NULL, eps = 1e-10,
n_max = 500L, w = NULL, verbose = TRUE, ...) {
approx = FALSE, grid_size = 50L,
n_max = 500L, eps = 1e-10,
w = NULL, verbose = TRUE, ...) {
hstats.default(
object = object,
X = X,
v = v,
pred_fun = pred_fun,
pairwise_m = pairwise_m,
threeway_m = threeway_m,
quant_approx = quant_approx,
eps = eps,
approx = approx,
grid_size = grid_size,
n_max = n_max,
eps = eps,
w = w,
verbose = verbose,
...
@@ -300,8 +305,9 @@ hstats.ranger <- function(object, X, v = NULL,
hstats.Learner <- function(object, X, v = NULL,
pred_fun = NULL,
pairwise_m = 5L, threeway_m = 0L,
quant_approx = NULL, eps = 1e-10,
n_max = 500L, w = NULL, verbose = TRUE, ...) {
approx = FALSE, grid_size = 50L,
n_max = 500L, eps = 1e-10,
w = NULL, verbose = TRUE, ...) {
if (is.null(pred_fun)) {
pred_fun <- mlr3_pred_fun(object, X = X)
}
@@ -312,9 +318,10 @@ hstats.Learner <- function(object, X, v = NULL,
pred_fun = pred_fun,
pairwise_m = pairwise_m,
threeway_m = threeway_m,
quant_approx = quant_approx,
eps = eps,
approx = approx,
grid_size = grid_size,
n_max = n_max,
eps = eps,
w = w,
verbose = verbose,
...
@@ -327,19 +334,20 @@ hstats.explainer <- function(object, X = object[["data"]],
v = NULL,
pred_fun = object[["predict_function"]],
pairwise_m = 5L, threeway_m = 0L,
quant_approx = NULL, eps = 1e-10,
n_max = 500L, w = object[["weights"]],
verbose = TRUE, ...) {
approx = FALSE, grid_size = 50L,
n_max = 500L, eps = 1e-10,
w = object[["weights"]], verbose = TRUE, ...) {
hstats.default(
object = object[["model"]],
X = X,
v = v,
pred_fun = pred_fun,
pairwise_m = pairwise_m,
threeway_m = threeway_m,
quant_approx = quant_approx,
eps = eps,
approx = approx,
grid_size = grid_size,
n_max = n_max,
eps = eps,
w = w,
verbose = verbose,
...
@@ -548,46 +556,3 @@ get_v <- function(H, m) {
}
v[v %in% v_cand]
}

#' Approximate Vector
#'
#' Internal function. Approximates values by the average of the two closest quantiles.
#'
#' @noRd
#' @keywords internal
#'
#' @param x A vector or factor.
#' @param m Number of unique values.
#' @returns An approximation of `x` (or `x` if non-numeric or discrete).
approx_vector <- function(x, m = 25L) {
if (!is.numeric(x) || length(unique(x)) <= m) {
return(x)
}
p <- seq(0, 1, length.out = m + 1L)
q <- unique(stats::quantile(x, probs = p, names = FALSE, na.rm = TRUE))
mids <- (q[-length(q)] + q[-1L]) / 2
return(mids[findInterval(x, q, rightmost.closed = TRUE)])
}

#' Approximate df or Matrix
#'
#' Internal function. Calls `approx_vector()` to each column in matrix or data.frame.
#'
#' @noRd
#' @keywords internal
#'
#' @param X A matrix or data.frame.
#' @param m Number of unique values.
#' @returns An approximation of `X` (or `X` if non-numeric or discrete).
approx_matrix_or_df <- function(X, v = colnames(X), m = 25L) {
stopifnot(
m >= 2L,
is.data.frame(X) || is.matrix(X)
)
if (is.data.frame(X)) {
X[v] <- lapply(X[v], FUN = approx_vector, m = m)
} else { # Matrix
X[, v] <- apply(X[, v, drop = FALSE], MARGIN = 2L, FUN = approx_vector, m = m)
}
return(X)
}
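
For callers, the signature change in `R/hstats.R` means that a former `quant_approx = m` becomes `approx = TRUE, grid_size = m`, since both end up as the `m` argument of `approx_matrix_or_df()`. The following is a hedged usage sketch, assuming an installed version of the hstats package that already contains this commit; the model, the data, and the choice `grid_size = 10L` are illustrative only.

```r
# Assumption: the installed hstats version includes the `approx` / `grid_size` arguments.
if (requireNamespace("hstats", quietly = TRUE)) {
  fit <- stats::lm(Sepal.Length ~ . + Petal.Width:Species, data = iris)

  # Before this commit, the equivalent call would have used `quant_approx = 10L`.
  s <- hstats::hstats(
    fit,
    X = iris[-1],
    approx = TRUE,     # switch the quantile approximation on
    grid_size = 10L,   # number of quantile midpoints per dense numeric column
    verbose = FALSE
  )
  print(s)
}
```

A small `grid_size` is used here because the numeric columns of `iris` have fewer than 50 distinct values each, so the default of 50 would leave them unchanged.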
2 changes: 1 addition & 1 deletion R/partial_dep.R
@@ -29,8 +29,8 @@
#' A partial dependence plot (PDP) plots the values of \eqn{\hat F_s(\mathbf{x}_s)}
#' over a grid of evaluation points \eqn{\mathbf{x}_s}.
#'
#' @inheritParams hstats
#' @inheritParams multivariate_grid
#' @inheritParams hstats
#' @param v One or more column names over which you want to calculate the partial
#' dependence.
#' @param grid Evaluation grid. A vector (if `length(v) == 1L`), or a matrix/data.frame
59 changes: 59 additions & 0 deletions R/utils_calculate.R
@@ -112,3 +112,62 @@ wcenter <- function(x, w = NULL) {
# sweep(x, MARGIN = 2L, STATS = wcolMeans(x, w = w)) # Slower
x - matrix(wcolMeans(x, w = w), nrow = nrow(x), ncol = ncol(x), byrow = TRUE)
}

#' Bin into Quantiles
#'
#' Internal function. Applies [cut()] to quantile breaks.
#'
#' @noRd
#' @keywords internal
#'
#' @param x A numeric vector.
#' @param m Number of intervals.
#' @returns A factor, representing binned `x`.
qcut <- function(x, m) {
p <- seq(0, 1, length.out = m + 1L)
g <- stats::quantile(x, probs = p, names = FALSE, type = 1L, na.rm = TRUE)
cut(x, breaks = unique(g), include.lowest = TRUE)
}

#' Approximate Vector
#'
#' Internal function. Approximates values by the average of the two closest quantiles.
#'
#' @noRd
#' @keywords internal
#'
#' @param x A vector or factor.
#' @param m Number of unique values.
#' @returns An approximation of `x` (or `x` if non-numeric or discrete).
approx_vector <- function(x, m = 50L) {
if (!is.numeric(x) || length(unique(x)) <= m) {
return(x)
}
p <- seq(0, 1, length.out = m + 1L)
q <- unique(stats::quantile(x, probs = p, names = FALSE, na.rm = TRUE))
mids <- (q[-length(q)] + q[-1L]) / 2
return(mids[findInterval(x, q, rightmost.closed = TRUE)])
}

#' Approximate df or Matrix
#'
#' Internal function. Applies `approx_vector()` to each column of a matrix or data.frame.
#'
#' @noRd
#' @keywords internal
#'
#' @param X A matrix or data.frame.
#' @param v Names of the columns of `X` to approximate. Defaults to all columns.
#' @param m Number of unique values.
#' @returns An approximation of `X` (or `X` if non-numeric or discrete).
approx_matrix_or_df <- function(X, v = colnames(X), m = 50L) {
stopifnot(
m >= 2L,
is.data.frame(X) || is.matrix(X)
)
if (is.data.frame(X)) {
X[v] <- lapply(X[v], FUN = approx_vector, m = m)
} else { # Matrix
X[, v] <- apply(X[, v, drop = FALSE], MARGIN = 2L, FUN = approx_vector, m = m)
}
return(X)
}
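
The behaviour of the two relocated helpers can be checked on a small mixed-type example. The sketch below copies `approx_vector()` and `approx_matrix_or_df()` from the diff above so that it runs on its own; the data frame `X`, the matrix `M`, and the choice `m = 10L` are assumptions made for illustration.

```r
# Helper definitions copied from the diff above
approx_vector <- function(x, m = 50L) {
  if (!is.numeric(x) || length(unique(x)) <= m) {
    return(x)
  }
  p <- seq(0, 1, length.out = m + 1L)
  q <- unique(stats::quantile(x, probs = p, names = FALSE, na.rm = TRUE))
  mids <- (q[-length(q)] + q[-1L]) / 2
  return(mids[findInterval(x, q, rightmost.closed = TRUE)])
}

approx_matrix_or_df <- function(X, v = colnames(X), m = 50L) {
  stopifnot(
    m >= 2L,
    is.data.frame(X) || is.matrix(X)
  )
  if (is.data.frame(X)) {
    X[v] <- lapply(X[v], FUN = approx_vector, m = m)
  } else {  # Matrix
    X[, v] <- apply(X[, v, drop = FALSE], MARGIN = 2L, FUN = approx_vector, m = m)
  }
  return(X)
}

# Illustrative data: one dense numeric, one discrete numeric, one factor column
set.seed(1)
X <- data.frame(
  dense = runif(200),
  discrete = rep(1:4, times = 50),
  group = factor(rep(c("a", "b"), each = 100))
)

X_approx <- approx_matrix_or_df(X, m = 10L)

# Only the dense column is binned; discrete and non-numeric columns pass through
sapply(X_approx, function(z) length(unique(z)))

# The matrix branch handles numeric matrices column by column
M <- as.matrix(X[c("dense", "discrete")])
M_approx <- approx_matrix_or_df(M, m = 10L)
```

Columns with at most `m` distinct values are returned unchanged, which is why the approximation only affects dense features.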
16 changes: 0 additions & 16 deletions R/utils_input.R
@@ -1,19 +1,3 @@
#' Bin into Quantiles
#'
#' Internal function. Applies [cut()] to quantile breaks.
#'
#' @noRd
#' @keywords internal
#'
#' @param x A numeric vector.
#' @param m Number of intervals.
#' @returns A factor, representing binned `x`.
qcut <- function(x, m) {
p <- seq(0, 1, length.out = m + 1L)
g <- stats::quantile(x, probs = p, names = FALSE, type = 1L, na.rm = TRUE)
cut(x, breaks = unique(g), include.lowest = TRUE)
}

#' Prepares Group BY Variable
#'
#' Internal function that prepares a BY variable or BY column name.
35 changes: 21 additions & 14 deletions man/hstats.Rd

Some generated files are not rendered by default.

1 change: 0 additions & 1 deletion packaging.R
@@ -83,7 +83,6 @@ build()
# build(binary = TRUE)
install(upgrade = FALSE)


# Run only if package is public(!) and should go to CRAN
if (FALSE) {
check_win_devel()