Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sv_importance and sv_interaction receive a sort_features option. #137

Merged
merged 1 commit into from
Feb 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: shapviz
Title: SHAP Visualizations
Version: 0.9.3
Version: 0.9.4
Authors@R: c(
person("Michael", "Mayer", , "[email protected]", role = c("aut", "cre")),
person("Adrian", "Stando", , "[email protected]", role = "ctb")
Expand All @@ -21,7 +21,7 @@ Depends:
R (>= 3.6.0)
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.3
RoxygenNote: 7.3.1
Imports:
ggfittext (>= 0.8.0),
gggenes,
Expand Down
6 changes: 6 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# shapviz 0.9.4

## Improvements

- New argument `sort_features = TRUE` in `sv_importance()` and `sv_interaction()`. Set to `FALSE` to show the features as they appear in your SHAP matrix. In that case, the plots will show the *first* `max_display` features, not the *most important* features. Implements #136.

# shapviz 0.9.3

## `sv_dependence()`: Control over automatic color feature selection
Expand Down
25 changes: 17 additions & 8 deletions R/sv_importance.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#' @param kind Should a "bar" plot (the default), a "beeswarm" plot, or "both" be shown?
#' Set to "no" in order to suppress plotting. In that case, the sorted
#' SHAP feature importances of all variables are returned.
#' @param max_display Maximum number of features (with highest importance) to plot.
#' @param max_display How many features should be plotted?
#' Set to `Inf` to show all features. Has no effect if `kind = "no"`.
#' @param fill Color used to fill the bars (only used if bars are shown).
#' @param bar_width Relative width of the bars (only used if bars are shown).
Expand All @@ -38,6 +38,7 @@
#' (only if `show_numbers = TRUE`). To change to scientific notation, use
#' `function(x) prettyNum(x, scientific = TRUE)`.
#' @param number_size Text size of the numbers (if `show_numbers = TRUE`).
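#' @param sort_features Should features be sorted by decreasing importance? The
#' default is `TRUE`. If `FALSE`, features are shown in the order in which they
#' appear in the SHAP matrix, and the *first* `max_display` features are plotted
#' instead of the *most important* ones.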
#' @param ... Arguments passed to [ggplot2::geom_bar()] (if `kind = "bar"`) or to
#' [ggplot2::geom_point()] otherwise. For instance, passing `alpha = 0.2` will produce
#' semi-transparent beeswarms, and setting `size = 3` will produce larger dots.
Expand Down Expand Up @@ -75,10 +76,10 @@ sv_importance.shapviz <- function(object, kind = c("bar", "beeswarm", "both", "n
viridis_args = getOption("shapviz.viridis_args"),
color_bar_title = "Feature value",
show_numbers = FALSE, format_fun = format_max,
number_size = 3.2, ...) {
number_size = 3.2, sort_features = TRUE, ...) {
stopifnot("format_fun must be a function" = is.function(format_fun))
kind <- match.arg(kind)
imp <- .get_imp(get_shap_values(object))
imp <- .get_imp(get_shap_values(object), sort_features = sort_features)

if (kind == "no") {
return(imp)
Expand Down Expand Up @@ -162,13 +163,13 @@ sv_importance.mshapviz <- function(object, kind = c("bar", "beeswarm", "both", "
viridis_args = getOption("shapviz.viridis_args"),
color_bar_title = "Feature value",
show_numbers = FALSE, format_fun = format_max,
number_size = 3.2, ...) {
number_size = 3.2, sort_features = TRUE, ...) {
kind <- match.arg(kind)
bar_type <- match.arg(bar_type)

# All other cases are done via {patchwork}
if (kind %in% c("bar", "no") && bar_type != "separate") {
imp <- .get_imp(get_shap_values(object))
imp <- .get_imp(get_shap_values(object), sort_features = sort_features)
if (kind == "no") {
return(imp)
}
Expand Down Expand Up @@ -223,6 +224,7 @@ sv_importance.mshapviz <- function(object, kind = c("bar", "beeswarm", "both", "
show_numbers = show_numbers,
format_fun = format_fun,
number_size = number_size,
sort_features = sort_features,
...
)
if (kind == "no") {
Expand All @@ -243,13 +245,20 @@ sv_importance.mshapviz <- function(object, kind = c("bar", "beeswarm", "both", "
(z - r[1L]) /(r[2L] - r[1L])
}

# Mean absolute value per column, i.e., the SHAP feature importance when `z`
# holds SHAP values.
#
# Arguments:
#   z             Either a numeric matrix (one column per feature), or a list
#                 of such matrices (one per model).
#                 NOTE(review): the list branch assumes that all matrices share
#                 the same columns in the same order - confirm against callers.
#   sort_features If TRUE (default), sort by decreasing importance. For a list,
#                 rows are sorted by the importance summed over all models.
#                 If FALSE, the original column order of `z` is kept.
#
# Returns:
#   matrix input: a named numeric vector with one element per column of `z`.
#   list input:   a numeric matrix with one row per feature and one column per
#                 list element. The result stays a matrix even for a list of
#                 length one.
.get_imp <- function(z, sort_features = TRUE) {
  if (is.matrix(z)) {
    imp <- colMeans(abs(z))
    if (sort_features) {
      imp <- sort(imp, decreasing = TRUE)
    }
    return(imp)
  }

  # list/mshapviz: cbind() over lapply() always yields a matrix, whereas
  # sapply() silently changes its return type depending on the input
  imp <- do.call(cbind, lapply(z, function(x) colMeans(abs(x))))
  if (sort_features) {
    # drop = FALSE keeps the matrix shape when there is only a single model
    imp <- imp[order(-rowSums(imp)), , drop = FALSE]
  }
  imp
}

.scale_X <- function(X) {
Expand Down
9 changes: 6 additions & 3 deletions R/sv_interaction.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,13 @@ sv_interaction.shapviz <- function(object, kind = c("beeswarm", "no"),
max_display = 7L, alpha = 0.3,
bee_width = 0.3, bee_adjust = 0.5,
viridis_args = getOption("shapviz.viridis_args"),
color_bar_title = "Row feature value", ...) {
color_bar_title = "Row feature value",
sort_features = TRUE, ...) {
kind <- match.arg(kind)
if (is.null(get_shap_interactions(object))) {
stop("No SHAP interaction values available.")
}
ord <- names(.get_imp(get_shap_values(object)))
ord <- names(.get_imp(get_shap_values(object), sort_features = sort_features))
object <- object[, ord]

if (kind == "no") {
Expand Down Expand Up @@ -112,7 +113,8 @@ sv_interaction.mshapviz <- function(object, kind = c("beeswarm", "no"),
max_display = 7L, alpha = 0.3,
bee_width = 0.3, bee_adjust = 0.5,
viridis_args = getOption("shapviz.viridis_args"),
color_bar_title = "Row feature value", ...) {
color_bar_title = "Row feature value",
sort_features = TRUE, ...) {
kind <- match.arg(kind)

plot_list <- lapply(
Expand All @@ -126,6 +128,7 @@ sv_interaction.mshapviz <- function(object, kind = c("beeswarm", "no"),
bee_adjust = bee_adjust,
viridis_args = viridis_args,
color_bar_title = color_bar_title,
sort_features = sort_features,
...
)
if (kind == "no") {
Expand Down
2 changes: 1 addition & 1 deletion man/shapviz-package.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 5 additions & 1 deletion man/sv_importance.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 5 additions & 1 deletion man/sv_interaction.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion packaging.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ library(usethis)
use_description(
fields = list(
Title = "SHAP Visualizations",
Version = "0.9.3",
Version = "0.9.4",
Description = "Visualizations for SHAP (SHapley Additive exPlanations),
such as waterfall plots, force plots, various types of importance plots,
dependence plots, and interaction plots.
Expand Down
23 changes: 17 additions & 6 deletions tests/testthat/test-plots-mshapviz.R
Original file line number Diff line number Diff line change
Expand Up @@ -108,20 +108,31 @@ test_that("plots work for non-syntactic column names", {
)
})

test_that("sv_importance() and sv_interaction() and kind = 'no' gives matrix", {
X_pred <- data.matrix(iris[, -1L])
dtrain <- xgboost::xgb.DMatrix(X_pred, label = iris[, 1L], nthread = 1)
fit <- xgboost::xgb.train(params = list(nthread = 1L), data = dtrain, nrounds = 1L)
x <- shapviz(fit, X_pred = X_pred, interactions = TRUE)
x <- c(m1 = x, m2 = x)
# Shared fixture for the tests below: an xgboost model with a single boosting
# round is fitted on iris (first column as label, remaining columns as
# features). SHAP values including interactions are extracted via shapviz(), and
# two copies of the resulting object are combined under the names "m1" and "m2".
X_pred <- data.matrix(iris[, -1L])
dtrain <- xgboost::xgb.DMatrix(X_pred, label = iris[, 1L], nthread = 1)
fit <- xgboost::xgb.train(params = list(nthread = 1L), data = dtrain, nrounds = 1L)
x <- shapviz(fit, X_pred = X_pred, interactions = TRUE)
x <- c(m1 = x, m2 = x)

test_that("sv_importance() and sv_interaction() and kind = 'no' gives matrix", {
  # Importance: expected to be a matrix with four rows and one column per model
  dim_imp <- c(4L, length(x))
  imp <- sv_importance(x, kind = "no")
  imp_ok <- is.matrix(imp) && all(dim(imp) == dim_imp)
  expect_true(imp_ok)

  # Interactions: a list whose first element is square in the number of features
  dim_inter <- rep(ncol(X_pred), 2L)
  inter <- sv_interaction(x, kind = "no")
  inter_ok <- is.list(inter) && all(dim(inter[[1L]]) == dim_inter)
  expect_true(inter_ok)
})


test_that("sv_importance() and sv_interaction() respect sort_features = FALSE", {
  # expect_equal() instead of expect_true(all(a == b)): the latter passes
  # vacuously when one side is NULL or of length zero, and silently recycles
  # when the lengths differ.
  feature_order <- colnames(x$m1)

  # Without sorting, the rows of the importance matrix follow the SHAP matrix
  imp <- sv_importance(x, kind = "no", sort_features = FALSE)
  expect_equal(rownames(imp), feature_order)

  # Same for the per-model interaction matrices
  inter <- sv_interaction(x, kind = "no", sort_features = FALSE)
  expect_equal(rownames(inter$m1), feature_order)
})



test_that("sv_dependence() does not work with multiple v", {
X_pred <- data.matrix(iris[, -1L])
dtrain <- xgboost::xgb.DMatrix(X_pred, label = iris[, 1L], nthread = 1)
Expand Down
13 changes: 13 additions & 0 deletions tests/testthat/test-plots-shapviz.R
Original file line number Diff line number Diff line change
Expand Up @@ -173,3 +173,16 @@ test_that("sv_importance() and sv_interaction() and kind = 'no' gives numeric ou
expect_true(is.numeric(inter) && all(dim(inter) == rep(ncol(X_pred), 2L)))
})

test_that("sv_importance() and sv_interaction() respect sort_features = FALSE", {
  X_pred <- data.matrix(iris[, -1L])
  dtrain <- xgboost::xgb.DMatrix(X_pred, label = iris[, 1L], nthread = 1)
  fit <- xgboost::xgb.train(params = list(nthread = 1L), data = dtrain, nrounds = 1L)
  x <- shapviz(fit, X_pred = X_pred, interactions = TRUE)

  # Importances are a named vector: names must follow the SHAP matrix columns.
  # expect_equal() also fails if one side is NULL, unlike all(a == b).
  imp <- sv_importance(x, kind = "no", sort_features = FALSE)
  expect_equal(names(imp), colnames(x))

  # The interaction strengths are a matrix, so names() would be NULL and
  # all(NULL == ...) would be trivially TRUE. Compare the row names instead.
  inter <- sv_interaction(x, kind = "no", sort_features = FALSE)
  expect_equal(rownames(inter), colnames(x))
})

Loading