Merge pull request #125 from ModelOriented/interaction-heuristic
Interaction heuristic arguments (WIP, do not merge)
mayer79 authored Jan 1, 2024
2 parents c2cdb30 + 785d9ee commit a5f2fac
Showing 10 changed files with 389 additions and 84 deletions.
38 changes: 36 additions & 2 deletions NEWS.md
@@ -1,15 +1,49 @@
# shapviz 0.9.3

## User-visible changes
## `sv_dependence()`: Control over automatic color feature selection

### How is the color feature selected anyway?

If no SHAP interaction values are available, the color feature `v'` is selected by default via the heuristic `potential_interactions()`, which works as follows:

1. If the feature `v` (the one on the x-axis) is numeric, it is binned into `nbins` bins.
2. Per bin, the SHAP values of `v` are regressed onto `v'`, and the R-squared is calculated.
3. The R-squared values are averaged over bins, weighted by the bin size.

This measures how much variability in the SHAP values is explained by `v'`, after accounting for `v`.
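
The three steps can be sketched in base R for a single color feature. This is only an illustration under simplifying assumptions (no missing values, numeric `v`, plain R-squared). The function `sketch_heuristic()` and the toy data are hypothetical and not part of the package.

```r
# Illustrative sketch, not the package implementation.
# x: values of v, s: SHAP values of v, color: values of v' (all hypothetical)
sketch_heuristic <- function(x, s, color, nbins = 5) {
  # Step 1: bin numeric x into quantile bins
  breaks <- unique(stats::quantile(x, seq(0, 1, length.out = nbins + 1)))
  bins <- findInterval(x, breaks, rightmost.closed = TRUE)
  parts <- split(data.frame(s = s, color = color), bins)

  # Step 2: R-squared of the regression s ~ color within each bin
  r2 <- vapply(
    parts,
    function(d) summary(stats::lm(s ~ color, data = d))$r.squared,
    FUN.VALUE = numeric(1)
  )

  # Step 3: average over bins, weighted by bin size
  w <- vapply(parts, nrow, FUN.VALUE = integer(1))
  stats::weighted.mean(r2, w = w)
}

set.seed(1)
n <- 500
x <- runif(n)
z_inter <- runif(n)   # feature that interacts with x
z_noise <- runif(n)   # unrelated feature
s <- x + x * z_inter  # toy stand-in for the SHAP values of x

h_inter <- sketch_heuristic(x, s, z_inter)
h_noise <- sketch_heuristic(x, s, z_noise)
```

In this toy setup, `h_inter` should come out clearly larger than `h_noise`, so `z_inter` would be the preferred color feature.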

We have introduced four parameters to control the heuristic. Their defaults are in line with the old behaviour.

- `nbins = NULL`: Into how many quantile bins should a numeric `v` be binned? The default `NULL` equals the smaller of $n/20$ and $\sqrt n$ (rounded up), where $n$ is the sample size.
- `color_num = TRUE`: Should color features be converted to numeric, even if they are factors/characters? Default is `TRUE`.
- `scale = FALSE`: Should R-squared be multiplied with the sample variance of within-bin SHAP values? If `TRUE`, bins with stronger vertical scatter will get higher weight. The default is `FALSE`.
- `adjusted = FALSE`: Should *adjusted* R-squared be calculated?
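
As a small illustration of two of these parameters, the sketch below evaluates the default bin count for a few sample sizes and contrasts plain with adjusted R-squared for a factor with many levels. The helper `default_nbins()` is a hypothetical name introduced here, and the data are pure noise.

```r
# Default number of bins: the smaller of n / 20 and sqrt(n), rounded up
default_nbins <- function(n) ceiling(min(sqrt(n), n / 20))

default_nbins(100)   # n / 20 = 5 is smaller than sqrt(n) = 10
default_nbins(1000)  # sqrt(n) is about 31.6, which is smaller than n / 20 = 50

# Plain vs. adjusted R-squared for a noise factor with many categories
set.seed(2)
s <- rnorm(60)
color <- factor(sample(letters[1:15], 60, replace = TRUE))
fit <- summary(lm(s ~ color))

# Plain R-squared rewards the many dummy variables; the adjusted version
# penalizes them and is therefore smaller
c(plain = fit$r.squared, adjusted = fit$adj.r.squared)
```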

If SHAP interaction values are available, these parameters have no effect. In `sv_dependence()`, they are called `ih_nbins`, `ih_color_num`, `ih_scale`, and `ih_adjusted`.
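
A minimal usage sketch, assuming {shapviz} (>= 0.9.3) is installed. The SHAP matrix is a toy stand-in built by hand, not the output of a real model, and the variable names are made up for illustration.

```r
# Guarded so that the sketch only runs if {shapviz} is available
if (requireNamespace("shapviz", quietly = TRUE)) {
  library(shapviz)

  set.seed(3)
  n <- 200
  X <- data.frame(x1 = runif(n), x2 = runif(n), x3 = runif(n))

  # Toy matrix standing in for real SHAP values (assumption: no model involved)
  S <- cbind(x1 = X$x1 * X$x2, x2 = X$x2, x3 = 0.1 * X$x3)
  sv <- shapviz(S, X = X)

  # `sv` has no SHAP interaction values, so the ih_* arguments take effect
  p <- sv_dependence(sv, v = "x1", ih_nbins = 5, ih_adjusted = TRUE)
}
```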

This partly implements the ideas in [#119](https://github.com/ModelOriented/shapviz/issues/119) of Roel Verbelen, thanks a lot for your patient explanations!

### Further plans?

We will continue to experiment with the defaults, which might change in the future. A good alternative to the current (naive) defaults could be:

- `nbins = 7`: Smaller than now to not overfit too strongly with factor/character color features.
- `color_num = FALSE`: To not naively integer encode factors/characters.
- `scale = TRUE`: To account for non-equal spread in bins.
- `adjusted = TRUE`: To not put too much weight on factors with many categories.
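
To compare the current defaults with this alternative, one could call `potential_interactions()` twice, as in the following sketch. It assumes {shapviz} (>= 0.9.3) is installed and uses a hand-made toy SHAP matrix with a factor color feature. The data are hypothetical.

```r
# Guarded so that the sketch only runs if {shapviz} is available
if (requireNamespace("shapviz", quietly = TRUE)) {
  library(shapviz)

  set.seed(4)
  n <- 300
  X <- data.frame(
    x = runif(n),
    grp = factor(sample(c("a", "b", "c"), n, replace = TRUE)),
    noise = runif(n)
  )

  # Toy SHAP matrix: the effect of x depends on the level of grp
  S <- cbind(
    x = X$x * (X$grp == "a"),
    grp = 0.1 * (X$grp == "b"),
    noise = 0.01 * X$noise
  )
  sv <- shapviz(S, X = X)

  # Current defaults
  current <- potential_interactions(sv, v = "x")

  # Possible alternative defaults discussed above
  alternative <- potential_interactions(
    sv, v = "x", nbins = 7, color_num = FALSE, scale = TRUE, adjusted = TRUE
  )
}
```

Both vectors are sorted in decreasing order. In this toy example, `grp` should rank first under either setting, but the scores are on different scales because `scale = TRUE` multiplies by the within-bin variance.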

## Other user-visible changes

- `sv_dependence()`: If `color_var = "auto"` (default) and no color feature seems to be relevant (SHAP interaction is `NULL`, or heuristic returns no positive value), there won't be any color scale.
- `mshapviz()` objects can now be row-bound via `rbind()` or `+`. Implemented by [@jmaspons](https://github.com/jmaspons) in [#110](https://github.com/ModelOriented/shapviz/pull/110).
- `mshapviz()` is stricter when combining multiple "shapviz" objects. These now need to have identical column names, see [#114](https://github.com/ModelOriented/shapviz/pull/114).
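
The rule for `color_var = "auto"` in the first item can be sketched as follows: keep only positive, non-missing scores and take the first one, if any. The helper name `pick_color_var()` is hypothetical. The filtering logic mirrors the change to `sv_dependence()` in this commit.

```r
# Sketch: choose the color feature from a (sorted) vector of interaction scores
pick_color_var <- function(scores) {
  scores <- scores[!is.na(scores) & scores > 0]  # NULL stays NULL
  if (length(scores) >= 1L) names(scores)[1L]    # otherwise NULL: no color scale
}

pick_color_var(c(a = 0.1, b = 0, c = -0.01, d = NaN, e = NA))  # "a"
pick_color_var(c(a = 0, b = NaN))                              # NULL
pick_color_var(NULL)                                           # NULL
```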

## Small changes

- `print.shapviz()` now shows top two rows of SHAP matrix.
- Re-activate all unit tests.
- Setting `nthread = 1` in all calls to `xgb.DMatrix()` as suggested by [@jmaspons](https://github.com/jmaspons) in [issue #109](https://github.com/ModelOriented/shapviz/issues/109).
- Setting `nthread = 1` in all calls to `xgb.DMatrix()` as suggested by [@jmaspons](https://github.com/jmaspons) in [#109](https://github.com/ModelOriented/shapviz/issues/109).
- Added "How to contribute" to README.
- `permshap()` connector is now part of {kernelshap} [#122](https://github.com/ModelOriented/shapviz/pull/122).

153 changes: 153 additions & 0 deletions R/potential_interactions.R
@@ -0,0 +1,153 @@
#' Interaction Strength
#'
#' Returns vector of interaction strengths between variable `v` and all other variables,
#' see Details.
#'
#' If SHAP interaction values are available, the interaction strength
#' between feature `v` and another feature `v'` is measured by twice their
#' mean absolute SHAP interaction values.
#'
#' Otherwise, we use a heuristic to measure the interaction strength
#' between `v` and each other "color" feature `v'`:
#' 1. If `v` is numeric, it is binned into `nbins` bins.
#' 2. Per bin, the SHAP values of `v` are regressed onto `v'`, and the R-squared
#' is calculated.
#' 3. The R-squared values are averaged over bins, weighted by the bin size.
#'
#' Set `scale = TRUE` to multiply the R-squared by the within-bin variance
#' of the SHAP values. This will put higher weight to bins with larger scatter.
#'
#' Set `color_num = FALSE` to *not* turn the values of the "color" feature `v'`
#' to numeric.
#'
#' Finally, set `adjusted = TRUE` to use *adjusted* R-squared.
#'
#' @param obj An object of class "shapviz".
#' @param v Variable name to calculate potential SHAP interactions for.
#' @param nbins Into how many quantile bins should a numeric `v` be binned?
#' The default `NULL` equals the smaller of \eqn{n/20} and \eqn{\sqrt n} (rounded up),
#' where \eqn{n} is the sample size. Ignored if `obj` contains SHAP interactions.
#' @param color_num Should other ("color") features `v'` be converted to numeric,
#' even if they are factors/characters? Default is `TRUE`.
#' Ignored if `obj` contains SHAP interactions.
#' @param scale Should the R-squared (adjusted or not) be multiplied with the sample variance of
#' within-bin SHAP values? If `TRUE`, bins with stronger vertical scatter will get
#' higher weight. The default is `FALSE`. Ignored if `obj` contains SHAP interactions.
#' @param adjusted Should *adjusted* R-squared be used? Default is `FALSE`.
#' Ignored if `obj` contains SHAP interactions.
#' @returns A named vector of decreasing interaction strengths.
#' @export
#' @seealso [sv_dependence()]
potential_interactions <- function(obj, v, nbins = NULL, color_num = TRUE,
                                   scale = FALSE, adjusted = FALSE) {
  stopifnot(is.shapviz(obj))
  S <- get_shap_values(obj)
  S_inter <- get_shap_interactions(obj)
  X <- get_feature_values(obj)
  nms <- colnames(obj)
  v_other <- setdiff(nms, v)
  stopifnot(v %in% nms)
  if (ncol(obj) <= 1L) {
    return(NULL)
  }

  # Simple case: we have SHAP interaction values
  if (!is.null(S_inter)) {
    return(sort(2 * colMeans(abs(S_inter[, v, ]))[v_other], decreasing = TRUE))
  }

  # Complicated case: calculate heuristic per color variable
  if (is.null(nbins)) {
    nbins <- ceiling(min(sqrt(nrow(X)), nrow(X) / 20))
  }
  out <- vapply(
    X[v_other], # data.frame is a list
    FUN = heuristic,
    FUN.VALUE = 1.0,
    s = S[, v],
    bins = .fast_bin(X[[v]], nbins = nbins),
    color_num = color_num,
    scale = scale,
    adjusted = adjusted
  )
  sort(out, decreasing = TRUE, na.last = TRUE)
}

#' Interaction Heuristic
#'
#' Internal function used to calculate the interaction heuristics described in
#' [potential_interactions()]. It calls `heuristic_in_bin()` per bin and aggregates
#' the result.
#'
#' @noRd
#' @keywords internal
#'
#' @inheritParams potential_interactions
#' @param color Feature values of the "color" feature.
#' @param s SHAP values of `v`.
#' @returns A single number.
heuristic <- function(color, s, bins, color_num, scale, adjusted) {
  if (isTRUE(color_num)) {
    color <- .as_numeric(color)
  }
  color <- split(color, bins)
  s <- split(s, bins)
  M <- mapply(
    heuristic_in_bin,
    color = color,
    s = s,
    MoreArgs = list(scale = scale, adjusted = adjusted)
  )
  stats::weighted.mean(M[1L, ], M[2L, ], na.rm = TRUE)
}

#' Interaction Heuristic in Bin
#'
#' Internal function used to calculate the within-bin heuristic used in `heuristic()`.
#' See `heuristic()` for details.
#'
#' @noRd
#' @keywords internal
#'
#' @inheritParams heuristic
#' @returns
#' A (1x2) matrix with heuristic and number of observations.
heuristic_in_bin <- function(color, s, scale = FALSE, adjusted = FALSE) {
  suppressWarnings(
    tryCatch(
      {
        z <- stats::lm(s ~ color)
        r <- z$residuals
        n <- length(r)
        # Variance of the response values actually used in the fit
        var_y <- stats::var(z$fitted.values + r)
        # Denominator n - 1 gives plain R-squared, df.residual the adjusted one
        denom <- if (adjusted) z$df.residual else n - 1
        var_r <- sum(r^2) / denom
        stat <- 1 - var_r / var_y
        if (scale) {
          stat <- stat * var_y
        }
        cbind(stat = stat, n = n)
      },
      # A failed fit contributes nothing: NA statistic with weight 0
      error = function(e) return(cbind(stat = NA, n = 0))
    )
  )
}

# Like as.numeric(), but can deal with factor variables
.as_numeric <- function(z) {
  if (is.numeric(z)) {
    return(z)
  }
  if (is.character(z)) {
    z <- factor(z)
  }
  as.numeric(z)
}

# Bins z into integer-valued quantile bins; a discrete z is returned as is
.fast_bin <- function(z, nbins) {
  if (.is_discrete(z, n_unique = nbins)) {
    return(z)
  }
  q <- stats::quantile(z, seq(0, 1, length.out = nbins + 1L), na.rm = TRUE)
  findInterval(z, unique(q), rightmost.closed = TRUE)
}
99 changes: 33 additions & 66 deletions R/sv_dependence.R
@@ -3,6 +3,8 @@
#' Scatterplot of the SHAP values of a feature against its feature values.
#' If SHAP interaction values are available, setting `interactions = TRUE` allows
#' to focus on pure interaction effects (multiplied by two) or on pure main effects.
#' By default, the feature on the color scale is selected via SHAP interactions
#' (if available) or an interaction heuristic, see [potential_interactions()].
#'
#' @importFrom rlang .data
#'
@@ -31,6 +33,9 @@
#' Requires SHAP interaction values. If `color_var = NULL` (or it is equal to `v`),
#' the pure main effect of `v` is visualized. Otherwise, twice the SHAP interaction
#' values between `v` and the `color_var` are plotted.
#' @param ih_nbins,ih_color_num,ih_scale,ih_adjusted Interaction heuristic (ih)
#' parameters used to select the color variable, see [potential_interactions()].
#' Only used if `color_var = "auto"` and if there are no SHAP interaction values.
#' @param ... Arguments passed to [ggplot2::geom_jitter()].
#' @returns An object of class "ggplot" (or "patchwork") representing a dependence plot.
#' @examples
@@ -71,7 +76,9 @@ sv_dependence.default <- function(object, ...) {
#' @export
sv_dependence.shapviz <- function(object, v, color_var = "auto", color = "#3b528b",
viridis_args = getOption("shapviz.viridis_args"),
jitter_width = NULL, interactions = FALSE, ...) {
jitter_width = NULL, interactions = FALSE,
ih_nbins = NULL, ih_color_num = TRUE,
ih_scale = FALSE, ih_adjusted = FALSE, ...) {
p <- length(v)
if (p > 1L || length(color_var) > 1L) {
if (is.null(color_var)) {
@@ -90,6 +97,10 @@ sv_dependence.shapviz <- function(object, v, color_var = "auto", color = "#3b528
object = object,
viridis_args = viridis_args,
interactions = interactions,
ih_nbins = ih_nbins,
ih_color_num = ih_color_num,
ih_scale = ih_scale,
ih_adjusted = ih_adjusted,
...
),
SIMPLIFY = FALSE
@@ -116,10 +127,20 @@ sv_dependence.shapviz <- function(object, v, color_var = "auto", color = "#3b528
jitter_width <- 0.2 * .is_discrete(X[[v]], n_unique = 7L)
}

# Set color value
# Set color value if "auto"
if (!is.null(color_var) && color_var == "auto" && !("auto" %in% nms)) {
scores <- potential_interactions(object, v)
color_var <- names(scores)[1L] # NULL if p = 1L
scores <- potential_interactions(
object,
v,
nbins = ih_nbins,
color_num = ih_color_num,
scale = ih_scale,
adjusted = ih_adjusted
)
# 'scores' can be NULL, or a numeric vector like c(0.1, 0, -0.01, NaN, NA)
# Thus, let's take the first positive one (or none)
scores <- scores[!is.na(scores) & scores > 0] # NULL stays NULL
color_var <- if (length(scores) >= 1L) names(scores)[1L]
}
if (isTRUE(interactions)) {
if (is.null(color_var)) {
@@ -169,7 +190,9 @@ sv_dependence.shapviz <- function(object, v, color_var = "auto", color = "#3b528
#' @export
sv_dependence.mshapviz <- function(object, v, color_var = "auto", color = "#3b528b",
viridis_args = getOption("shapviz.viridis_args"),
jitter_width = NULL, interactions = FALSE, ...) {
jitter_width = NULL, interactions = FALSE,
ih_nbins = NULL, ih_color_num = TRUE,
ih_scale = FALSE, ih_adjusted = FALSE, ...) {
stopifnot(
length(v) == 1L,
length(color_var) <= 1L
@@ -184,75 +207,19 @@ sv_dependence.mshapviz <- function(object, v, color_var = "auto", color = "#3b52
viridis_args = viridis_args,
jitter_width = jitter_width,
interactions = interactions,
ih_nbins = ih_nbins,
ih_color_num = ih_color_num,
ih_scale = ih_scale,
ih_adjusted = ih_adjusted,
...
)
plot_list <- add_titles(plot_list, nms = names(object)) # see sv_waterfall()
patchwork::wrap_plots(plot_list)
}

#' Interaction Strength
#'
#' Returns vector of interaction strengths between variable `v` and all other variables.
#'
#' If SHAP interaction values are available, interaction strength
#' between feature `v` and another feature `v'` is measured by twice their
#' mean absolute SHAP interaction values. Otherwise, we use as heuristic the
#' squared correlation between feature values of `v'` and
#' SHAP values of `v`, averaged over (binned) values of `v`.
#' A numeric `v` with more than `n_bins` unique values is binned into quantile bins.
#' Currently `n_bins` equals the smaller of \eqn{n/20} and \eqn{\sqrt n}, where \eqn{n}
#' is the sample size.
#' The average squared correlation is weighted by the number of non-missing feature
#' values in the bin. Note that non-numeric color features are turned to numeric
#' by calling [data.matrix()], which does not necessarily make sense.
#'
#' @param obj An object of class "shapviz".
#' @param v Variable name.
#' @returns A named vector of decreasing interaction strengths.
#' @export
#' @seealso [sv_dependence()]
potential_interactions <- function(obj, v) {
stopifnot(is.shapviz(obj))
S <- get_shap_values(obj)
S_inter <- get_shap_interactions(obj)
X <- get_feature_values(obj)
nms <- colnames(obj)
v_other <- setdiff(nms, v)
stopifnot(v %in% nms)

if (ncol(obj) <= 1L) {
return(NULL)
}

# Simple case: we have SHAP interaction values
if (!is.null(S_inter)) {
return(sort(2 * colMeans(abs(S_inter[, v, ]))[v_other], decreasing = TRUE))
}

# Complicated case: we need to rely on correlation based heuristic
r_sq <- function(s, x) {
suppressWarnings(stats::cor(s, data.matrix(x), use = "p")^2)
}
n_bins <- ceiling(min(sqrt(nrow(X)), nrow(X) / 20))
v_bin <- .fast_bin(X[[v]], n_bins = n_bins)
s_bin <- split(S[, v], v_bin)
X_bin <- split(X[v_other], v_bin)
w <- do.call(rbind, lapply(X_bin, function(z) colSums(!is.na(z))))
cor_squared <- do.call(rbind, mapply(r_sq, s_bin, X_bin, SIMPLIFY = FALSE))
sort(colSums(w * cor_squared, na.rm = TRUE) / colSums(w), decreasing = TRUE)
}

# Helper functions

# Checks if z is discrete
.is_discrete <- function(z, n_unique) {
is.factor(z) || is.character(z) || is.logical(z) || (length(unique(z)) <= n_unique)
}

# Bins z into integer valued bins, but only if discrete
.fast_bin <- function(z, n_bins) {
if (.is_discrete(z, n_unique = n_bins)) {
return(z)
}
q <- stats::quantile(z, seq(0, 1, length.out = n_bins + 1L), na.rm = TRUE)
findInterval(z, unique(q), rightmost.closed = TRUE)
}
Binary file modified man/figures/VIGNETTE-dep-ranger.png
Binary file modified man/figures/VIGNETTE-dep.png