Skip to content

Commit

Permalink
Add argument color_numeric to potential_interactions()
Browse files Browse the repository at this point in the history
  • Loading branch information
mayer79 committed Dec 28, 2023
1 parent aa1d2d9 commit 14cfac1
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 9 deletions.
57 changes: 49 additions & 8 deletions R/sv_dependence.R
Original file line number Diff line number Diff line change
Expand Up @@ -208,10 +208,13 @@ sv_dependence.mshapviz <- function(object, v, color_var = "auto", color = "#3b52
#' @param n_bins A numeric `v` with more than `n_bins` unique values is binned
#' into that many quantile bins. If `NULL` (default), `n_bins` equals the smaller
#' of \eqn{n/20} and \eqn{\sqrt n} (rounded up), where \eqn{n} is the sample size.
#' Ignored if `obj` contains SHAP interactions.
#' @param color_numeric Should color feature values be converted to numeric? Default is
#' `TRUE`. Ignored if `obj` contains SHAP interactions.
#' @returns A named vector of decreasing interaction strengths.
#' @export
#' @seealso [sv_dependence()]
potential_interactions <- function(obj, v, n_bins = NULL) {
potential_interactions <- function(obj, v, n_bins = NULL, color_numeric = TRUE) {
stopifnot(
is.shapviz(obj),
is.null(n_bins) || (length(n_bins) == 1L && n_bins >= 1L)
Expand All @@ -232,19 +235,20 @@ potential_interactions <- function(obj, v, n_bins = NULL) {
return(sort(2 * colMeans(abs(S_inter[, v, ]))[v_other], decreasing = TRUE))
}

# Complicated case: we need to rely on correlation based heuristic
r_sq <- function(s, x) {
suppressWarnings(stats::cor(s, data.matrix(x), use = "p")^2)
# Complicated case: we need to rely on R-squared adjusted based heuristic
COL <- X[v_other]
if (isTRUE(color_numeric)) {
COL <- as.data.frame(data.matrix(COL))
}
if (is.null(n_bins)) {
n_bins <- ceiling(min(sqrt(nrow(X)), nrow(X) / 20))
}
v_bin <- .fast_bin(X[[v]], n_bins = n_bins)
s_bin <- split(S[, v], v_bin)
X_bin <- split(X[v_other], v_bin)
w <- do.call(rbind, lapply(X_bin, function(z) colSums(!is.na(z))))
cor_squared <- do.call(rbind, mapply(r_sq, s_bin, X_bin, SIMPLIFY = FALSE))
sort(colSums(w * cor_squared, na.rm = TRUE) / colSums(w), decreasing = TRUE)
COL_bin <- split(COL, v_bin)
w <- do.call(rbind, lapply(COL_bin, function(z) colSums(!is.na(z))))
r2 <- do.call(rbind, mapply(r2_adj, s_bin, COL_bin, SIMPLIFY = FALSE))
sort(colSums(w * r2, na.rm = TRUE) / colSums(w), decreasing = TRUE)
}

# Helper functions
Expand All @@ -261,3 +265,40 @@ potential_interactions <- function(obj, v, n_bins = NULL) {
q <- stats::quantile(z, seq(0, 1, length.out = n_bins + 1L), na.rm = TRUE)
findInterval(z, unique(q), rightmost.closed = TRUE)
}

#' R-squared adjusted
#'
#' Internal function used to calculate the R-squared adjusted for a simple linear
#' regression with the SHAP values of the feature on the x-axis as response and
#' the (potential) color feature as single feature.
#'
#' @noRd
#' @keywords internal
#'
#' @param s Vector of within-bin SHAP values of the feature on the x-axis.
#' @param color Feature values of the color feature.
#' @returns The R-squared adjusted from regressing `s` onto `color`. If the calculation
#' fails (e.g., too many factor levels in `color`), the function returns `NA`.
r2_adj_uni <- function(s, color) {
tryCatch(
summary(stats::lm(s ~ color))[["adj.r.squared"]],
error = function(e) return(NA)
)
}

#' R-squared adjusted for multiple v
#'
#' Internal function that calls `r2_adj_uni()` for multiple color features.
#'
#' @noRd
#' @keywords internal
#'
#' @param s Vector of SHAP values of the feature on the x-axis.
#' @param df Data frame of (multiple) color features.
#' @returns Vector of R-squared adjusted (one per column in `df`). For columns in `df`
#' where the calculations fail, the value is `NA`.
r2_adj <- function(s, df) {
suppressWarnings(
vapply(df, FUN = r2_adj_uni, FUN.VALUE = 1.0, s = s, USE.NAMES = FALSE)
)
}
41 changes: 40 additions & 1 deletion tests/testthat/test-potential_interactions.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,50 @@ test_that("n_bins has an effect for numeric v", {
)
})

# Now with SHAP interactions
test_that("color_numeric has an effect", {
p1 <- potential_interactions(x, "Sepal.Width", color_numeric = TRUE)
p2 <- potential_interactions(x, "Sepal.Width", color_numeric = FALSE)
num <- c("Petal.Width", "Petal.Length")
expect_equal(p1[num], p2[num])
expect_false(p1["Species"] == p2["Species"])
})

test_that("potential_interactions respects true SHAP interactions", {
xi <- shapviz(fit, X_pred = dtrain, X = iris[, -1L], interactions = TRUE)
i1 <- potential_interactions(xi, "Species")
i2 <- sv_interaction(xi, kind = "no")[names(i1), "Species"]
expect_equal(i1, i2, tolerance = 1e-5)
})

test_that("r2_adj_uni() returns R-squared adjusted", {
fit_lm <- lm(Sepal.Length ~ Species, data = iris)
expect_equal(
r2_adj_uni(iris$Sepal.Length, iris$Species),
summary(fit_lm)[["adj.r.squared"]]
)
})

test_that("r2_adj_uni() fails with NA", {
expect_equal(r2_adj_uni(0, 1:2), NA)
})

test_that("r2_adj() returns R-squared adjusted per column in df", {
fit_lm1 <- lm(Sepal.Length ~ Species, data = iris)
fit_lm2 <- lm(Sepal.Length ~ Sepal.Width, data = iris)

expect_equal(
r2_adj(iris$Sepal.Length, iris[c("Species", "Sepal.Width")]),
c(summary(fit_lm1)[["adj.r.squared"]], summary(fit_lm2)[["adj.r.squared"]])
)
})

test_that("r2_adj() can fail with NA", {
expect_equal(r2_adj(0, df = data.frame(x = c(1, 1), y = 1:2)), c(NA_real_, NA_real_))
})



# r_sq <- function(s, x) {
# suppressWarnings(stats::cor(s, data.matrix(x), use = "p")^2)
# }

0 comments on commit 14cfac1

Please sign in to comment.