Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Catch some edge-case in heuristic_in_bin() #126

Merged
merged 2 commits into from
Jan 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

## `sv_dependence()`: Control over automatic color feature selection

### How is the color feature selected anyway?
### How is the color feature selected, anyway?

If no SHAP interaction values are available, the color feature `v'` is selected by default via the heuristic `potential_interactions()`, which works as follows:

Expand Down
42 changes: 20 additions & 22 deletions R/potential_interactions.R
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,10 @@ heuristic <- function(color, s, bins, color_num, scale, adjusted) {
if (isTRUE(color_num)) {
color <- .as_numeric(color)
}
color <- split(color, bins)
s <- split(s, bins)
M <- mapply(
heuristic_in_bin,
color = color,
s = s,
color = split(color, bins),
s = split(s, bins),
MoreArgs = list(scale = scale, adjusted = adjusted)
)
stats::weighted.mean(M[1L, ], M[2L, ], na.rm = TRUE)
Expand All @@ -112,24 +110,24 @@ heuristic <- function(color, s, bins, color_num, scale, adjusted) {
#' @returns
#' A (1x2) matrix with heuristic and number of observations.
heuristic_in_bin <- function(color, s, scale = FALSE, adjusted = FALSE) {
  # Drop observations with a missing color value.
  ok <- !is.na(color)
  color <- color[ok]
  s <- s[ok]

  # Also drop missing responses. Otherwise var(s) is NA and the checks below
  # would fail with "missing value where TRUE/FALSE needed".
  if (anyNA(s)) {
    ok <- !is.na(s)
    color <- color[ok]
    s <- s[ok]
  }
  n <- length(s)

  # No fit is possible with fewer than two observations or a constant color.
  if (n < 2L || length(unique(color)) < 2L) {
    return(cbind(stat = NA, n = n))
  }

  # A (numerically) constant or non-finite response has no usable variance,
  # so the R-squared type statistic is undefined.
  var_s <- stats::var(s)
  if (!is.finite(var_s) || var_s < .Machine$double.eps) {
    return(cbind(stat = NA, n = n))
  }

  # 1 - residual variance / response variance of the linear fit s ~ color.
  # With adjusted = TRUE, the residual degrees of freedom are used as
  # denominator instead of n - 1.
  z <- stats::lm(s ~ color)
  denom <- if (adjusted) z$df.residual else n - 1
  var_r <- sum(z$residuals^2) / denom
  stat <- 1 - var_r / var_s

  # Optionally multiply by the response variance.
  if (scale) {
    stat <- stat * var_s
  }

  # E.g. 0 / 0 for a saturated fit with adjusted = TRUE.
  if (!is.finite(stat)) {
    stat <- NA
  }
  cbind(stat = stat, n = n)
}

# Like as.numeric(), but can deal with factor variables
Expand Down
4 changes: 2 additions & 2 deletions R/sv_dependence.R
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,8 @@ sv_dependence.shapviz <- function(object, v, color_var = "auto", color = "#3b528
scale = ih_scale,
adjusted = ih_adjusted
)
# 'scores' can be NULL, or a numeric vector like c(0.1, 0, -0.01, NaN, NA)
# Thus, let's take the first positive one (or none)
# 'scores' can be NULL, or a sorted vector like c(0.1, 0, -0.01, NA)
# Thus, let's take the first positive one (or NULL)
scores <- scores[!is.na(scores) & scores > 0] # NULL stays NULL
color_var <- if (length(scores) >= 1L) names(scores)[1L]
}
Expand Down
Binary file removed man/figures/VIGNETTE-imp.png
Binary file not shown.
99 changes: 98 additions & 1 deletion tests/testthat/test-potential_interactions.R
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,104 @@ test_that("heuristic_in_bin() returns R-squared", {
})

test_that("Failing heuristic_in_bin() returns NA", {
  # A single color value is constant, so no statistic is available. Both
  # responses are kept and counted.
  expect_equal(heuristic_in_bin(0, 1:2), cbind(stat = NA, n = 2L))
  # All color values are missing: no observation remains.
  expect_equal(heuristic_in_bin(c(NA, NA), 1:2), cbind(stat = NA, n = 0))
})

test_that("heuristic_in_bin() returns NA for constant response", {
  # The statistic is NA for every combination of 'scale' and 'adjusted',
  # while n still counts all three observations.
  expect_equal(
    heuristic_in_bin(color = 1:3, s = c(1, 1, 1)),
    cbind(stat = NA, n = 3L)
  )
  expect_equal(
    heuristic_in_bin(color = 1:3, s = c(1, 1, 1), scale = TRUE),
    cbind(stat = NA, n = 3L)
  )
  expect_equal(
    heuristic_in_bin(color = 1:3, s = c(1, 1, 1), adjusted = TRUE),
    cbind(stat = NA, n = 3L)
  )
  expect_equal(
    heuristic_in_bin(color = 1:3, s = c(1, 1, 1), adjusted = TRUE, scale = TRUE),
    cbind(stat = NA, n = 3L)
  )
})

test_that("heuristic_in_bin() returns NA for constant color", {
  # The statistic is NA for every combination of 'scale' and 'adjusted',
  # while n still counts all three observations.
  expect_equal(
    heuristic_in_bin(s = 1:3, color = c(1, 1, 1)),
    cbind(stat = NA, n = 3L)
  )
  expect_equal(
    heuristic_in_bin(s = 1:3, color = c(1, 1, 1), scale = TRUE),
    cbind(stat = NA, n = 3L)
  )
  expect_equal(
    heuristic_in_bin(s = 1:3, color = c(1, 1, 1), adjusted = TRUE),
    cbind(stat = NA, n = 3L)
  )
  expect_equal(
    heuristic_in_bin(s = 1:3, color = c(1, 1, 1), adjusted = TRUE, scale = TRUE),
    cbind(stat = NA, n = 3L)
  )
})

test_that("heuristic_in_bin() returns NA if response and color are constant", {
  # Both inputs are constant: the statistic is NA (not 0), n is the bin size.
  z <- c(1, 1)
  expect_equal(
    heuristic_in_bin(color = z, s = z),
    cbind(stat = NA, n = 2L)
  )
  expect_equal(
    heuristic_in_bin(color = z, s = z, scale = TRUE),
    cbind(stat = NA, n = 2L)
  )
  expect_equal(
    heuristic_in_bin(color = z, s = z, adjusted = TRUE),
    cbind(stat = NA, n = 2L)
  )
  expect_equal(
    heuristic_in_bin(color = z, s = z, adjusted = TRUE, scale = TRUE),
    cbind(stat = NA, n = 2L)
  )
})

test_that("heuristic_in_bin() returns NA for single obs", {
  # One observation: the statistic is NA, n is 1.
  expect_equal(
    heuristic_in_bin(color = 2, s = 2),
    cbind(stat = NA, n = 1L)
  )
  expect_equal(
    heuristic_in_bin(color = 2, s = 2, scale = TRUE),
    cbind(stat = NA, n = 1L)
  )
  expect_equal(
    heuristic_in_bin(color = 2, s = 2, adjusted = TRUE),
    cbind(stat = NA, n = 1L)
  )
  expect_equal(
    heuristic_in_bin(color = 2, s = 2, adjusted = TRUE, scale = TRUE),
    cbind(stat = NA, n = 1L)
  )
})

test_that("heuristic_in_bin() handles factor color with one level per obs", {
  # One factor level per observation. The unadjusted statistic is 1 (or the
  # response variance 4 when scaled); the adjusted statistic is NA.
  cc <- factor(LETTERS[1:3])
  expect_equal(
    heuristic_in_bin(color = cc, s = 1:3),
    cbind(stat = 1, n = 3L)
  )
  expect_equal(
    heuristic_in_bin(color = cc, s = 2 * (1:3), scale = TRUE),
    cbind(stat = 4, n = 3L)
  )
  expect_equal(
    heuristic_in_bin(color = cc, s = 1:3, adjusted = TRUE),
    cbind(stat = NA, n = 3L)
  )
  expect_equal(
    heuristic_in_bin(color = cc, s = 2 * (1:3), adjusted = TRUE, scale = TRUE),
    cbind(stat = NA, n = 3L)
  )
})

test_that("heuristic() returns average R-squared", {
Expand Down
1 change: 1 addition & 0 deletions vignettes/basic_use.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ The above example uses XGBoost to calculate SHAP values. In the following sectio

```r
library(lightgbm)

dtrain <- lgb.Dataset(data.matrix(diamonds[x]), label = diamonds$price)

fit <- lgb.train(
Expand Down