From 3ab0048ca178d10d705a2617d3c908a06da34090 Mon Sep 17 00:00:00 2001 From: Michael Mayer Date: Tue, 13 Aug 2024 18:57:23 +0200 Subject: [PATCH] Handle single-row predictions for XGBoost --- NEWS.md | 4 ++++ R/shapviz.R | 30 +++++++++++++++++++++++++++++- tests/testthat/test-interface.R | 11 +++++++++++ 3 files changed, 44 insertions(+), 1 deletion(-) diff --git a/NEWS.md b/NEWS.md index 4d07f96..595ebf4 100644 --- a/NEWS.md +++ b/NEWS.md @@ -4,6 +4,10 @@ - New argument `sort_features = TRUE` in `sv_importance()` and `sv_interaction()`. Set to `FALSE` to show the features as they appear in your SHAP matrix. In that case, the plots will show the *first* `max_display` features, not the *most important* features. Implements #136. +### Bug fixes + +- `shapviz.xgboost()` would fail if a single row is passed. This has been fixed. Thanks, @sebsilas, for reporting. + # shapviz 0.9.3 ## `sv_dependence()`: Control over automatic color feature selection diff --git a/R/shapviz.R b/R/shapviz.R index 3d07af5..d7f087c 100644 --- a/R/shapviz.R +++ b/R/shapviz.R @@ -196,7 +196,25 @@ shapviz.xgb.Booster = function(object, X_pred, X = X_pred, which_class = NULL, S <- stats::predict(object, newdata = X_pred, predcontrib = TRUE, ...) if (interactions) { - S_inter <- stats::predict(object, newdata = X_pred, predinteraction = TRUE, ...) + S_inter <- stats::predict( + object, newdata = X_pred, predinteraction = TRUE, ... + ) + } + + # Handle problem that S and S_inter lack a dimension if X_pred has only one row + # This might be fixed later directly in XGBoost. + if (nrow(X_pred) == 1L) { + if (is.list(S)) { # multiclass + S <- lapply(S, rbind) + if (interactions) { + S_inter <- lapply(S_inter, .add_dim) + } + } else { + S <- rbind(S) + if (interactions) { + S_inter <-.add_dim(S_inter) + } + } } # Multiclass @@ -538,3 +556,13 @@ mshapviz <- function(object, ...) { ) } } + +# Turns matrix into 3D-array with one "row". solving a problem with XGBoost and one row. +.add_dim <- function(x) { + if (is.matrix(x)) { # Problematic case: interactions is not 3D array + out <- array(dim = c(1L, dim(x)), dimnames = c(list(NULL), dimnames(x))) + out[1L, , ] <- x + } + return(out) +} + diff --git a/tests/testthat/test-interface.R b/tests/testthat/test-interface.R index 3435ee0..6ad45fb 100644 --- a/tests/testthat/test-interface.R +++ b/tests/testthat/test-interface.R @@ -288,3 +288,14 @@ test_that("combining shapviz on classes 1, 2, 3 equal mshapviz", { expect_equal(mshp, mshapviz(list(Class_1 = shp1, Class_2 = shp2, Class_3 = shp3))) }) +test_that("single row predictions work with shapviz", { + expect_no_error( + shp1 <- shapviz(fit, X_pred = X_pred[1L, , drop = FALSE], interactions = TRUE) + ) + shp2 <- shapviz(fit, X_pred = X_pred, interactions = TRUE) + for (j in names(shp1)) { + expect_equal(shp1[[j]], shp2[[j]][1L, ]) + } +}) + +