From 14cfac1f9e7d5d2865b208610f5235f2fa560fe8 Mon Sep 17 00:00:00 2001 From: Michael Mayer Date: Thu, 28 Dec 2023 18:06:16 +0100 Subject: [PATCH] Add argument color_numeric to potential_interactions() --- R/sv_dependence.R | 57 +++++++++++++++++--- tests/testthat/test-potential_interactions.R | 41 +++++++++++++- 2 files changed, 89 insertions(+), 9 deletions(-) diff --git a/R/sv_dependence.R b/R/sv_dependence.R index 5d66265..465eb37 100644 --- a/R/sv_dependence.R +++ b/R/sv_dependence.R @@ -208,10 +208,13 @@ sv_dependence.mshapviz <- function(object, v, color_var = "auto", color = "#3b52 #' @param n_bins A numeric `v` with more than `n_bins` unique values is binned #' into that many quantile bins. If `NULL` (default), `n_bins` equals the smaller #' of \eqn{n/20} and \eqn{\sqrt n} (rounded up), where \eqn{n} is the sample size. +#' Ignored if `obj` contains SHAP interactions. +#' @param color_numeric Should color feature values be converted to numeric? Default is +#' `TRUE`. Ignored if `obj` contains SHAP interactions. #' @returns A named vector of decreasing interaction strengths. #' @export #' @seealso [sv_dependence()] -potential_interactions <- function(obj, v, n_bins = NULL) { +potential_interactions <- function(obj, v, n_bins = NULL, color_numeric = TRUE) { stopifnot( is.shapviz(obj), is.null(n_bins) || (length(n_bins) == 1L && n_bins >= 1L) @@ -232,19 +235,20 @@ potential_interactions <- function(obj, v, n_bins = NULL) { return(sort(2 * colMeans(abs(S_inter[, v, ]))[v_other], decreasing = TRUE)) } - # Complicated case: we need to rely on correlation based heuristic - r_sq <- function(s, x) { - suppressWarnings(stats::cor(s, data.matrix(x), use = "p")^2) + # Complicated case: we need to rely on R-squared adjusted based heuristic + COL <- X[v_other] + if (isTRUE(color_numeric)) { + COL <- as.data.frame(data.matrix(COL)) } if (is.null(n_bins)) { n_bins <- ceiling(min(sqrt(nrow(X)), nrow(X) / 20)) } v_bin <- .fast_bin(X[[v]], n_bins = n_bins) s_bin <- split(S[, v], v_bin) - X_bin <- split(X[v_other], v_bin) - w <- do.call(rbind, lapply(X_bin, function(z) colSums(!is.na(z)))) - cor_squared <- do.call(rbind, mapply(r_sq, s_bin, X_bin, SIMPLIFY = FALSE)) - sort(colSums(w * cor_squared, na.rm = TRUE) / colSums(w), decreasing = TRUE) + COL_bin <- split(COL, v_bin) + w <- do.call(rbind, lapply(COL_bin, function(z) colSums(!is.na(z)))) + r2 <- do.call(rbind, mapply(r2_adj, s_bin, COL_bin, SIMPLIFY = FALSE)) + sort(colSums(w * r2, na.rm = TRUE) / colSums(w), decreasing = TRUE) } # Helper functions @@ -261,3 +265,40 @@ potential_interactions <- function(obj, v, n_bins = NULL) { q <- stats::quantile(z, seq(0, 1, length.out = n_bins + 1L), na.rm = TRUE) findInterval(z, unique(q), rightmost.closed = TRUE) } + +#' R-squared adjusted +#' +#' Internal function used to calculate the R-squared adjusted for a simple linear +#' regression with the SHAP values of the feature on the x-axis as response and +#' the (potential) color feature as single feature. +#' +#' @noRd +#' @keywords internal +#' +#' @param s Vector of within-bin SHAP values of the feature on the x-axis. +#' @param color Feature values of the color feature. +#' @returns The R-squared adjusted from regressing `s` onto `color`. If the calculation +#' fails (e.g., too many factor levels in `color`), the function returns `NA`. +r2_adj_uni <- function(s, color) { + tryCatch( + summary(stats::lm(s ~ color))[["adj.r.squared"]], + error = function(e) return(NA) + ) +} + +#' R-squared adjusted for multiple v +#' +#' Internal function that calls `r2_adj_uni()` for multiple color features. +#' +#' @noRd +#' @keywords internal +#' +#' @param s Vector of SHAP values of the feature on the x-axis. +#' @param df Data frame of (multiple) color features. +#' @returns Vector of R-squared adjusted (one per column in `df`). For columns in `df` +#' where the calculations fail, the value is `NA`. +r2_adj <- function(s, df) { + suppressWarnings( + vapply(df, FUN = r2_adj_uni, FUN.VALUE = 1.0, s = s, USE.NAMES = FALSE) + ) +} diff --git a/tests/testthat/test-potential_interactions.R b/tests/testthat/test-potential_interactions.R index 03da32e..27b6d4d 100644 --- a/tests/testthat/test-potential_interactions.R +++ b/tests/testthat/test-potential_interactions.R @@ -20,7 +20,14 @@ test_that("n_bins has an effect for numeric v", { ) }) -# Now with SHAP interactions +test_that("color_numeric has an effect", { + p1 <- potential_interactions(x, "Sepal.Width", color_numeric = TRUE) + p2 <- potential_interactions(x, "Sepal.Width", color_numeric = FALSE) + num <- c("Petal.Width", "Petal.Length") + expect_equal(p1[num], p2[num]) + expect_false(p1["Species"] == p2["Species"]) +}) + test_that("potential_interactions respects true SHAP interactions", { xi <- shapviz(fit, X_pred = dtrain, X = iris[, -1L], interactions = TRUE) i1 <- potential_interactions(xi, "Species") @@ -28,3 +35,35 @@ test_that("potential_interactions respects true SHAP interactions", { expect_equal(i1, i2, tolerance = 1e-5) }) +test_that("r2_adj_uni() returns R-squared adjusted", { + fit_lm <- lm(Sepal.Length ~ Species, data = iris) + expect_equal( + r2_adj_uni(iris$Sepal.Length, iris$Species), + summary(fit_lm)[["adj.r.squared"]] + ) +}) + +test_that("r2_adj_uni() fails with NA", { + expect_equal(r2_adj_uni(0, 1:2), NA) +}) + +test_that("r2_adj() returns R-squared adjusted per column in df", { + fit_lm1 <- lm(Sepal.Length ~ Species, data = iris) + fit_lm2 <- lm(Sepal.Length ~ Sepal.Width, data = iris) + + expect_equal( + r2_adj(iris$Sepal.Length, iris[c("Species", "Sepal.Width")]), + c(summary(fit_lm1)[["adj.r.squared"]], summary(fit_lm2)[["adj.r.squared"]]) + ) +}) + +test_that("r2_adj() can fail with NA", { + expect_equal(r2_adj(0, df = data.frame(x = c(1, 1), y = 1:2)), c(NA_real_, NA_real_)) +}) + + + +# r_sq <- function(s, x) { +# suppressWarnings(stats::cor(s, data.matrix(x), use = "p")^2) +# } +