Handle single-row predictions for XGBoost
mayer79 committed Aug 13, 2024
1 parent bc368cd commit 3ab0048
Showing 3 changed files with 44 additions and 1 deletion.
4 changes: 4 additions & 0 deletions NEWS.md
@@ -4,6 +4,10 @@

 - New argument `sort_features = TRUE` in `sv_importance()` and `sv_interaction()`. Set to `FALSE` to show the features as they appear in your SHAP matrix. In that case, the plots will show the *first* `max_display` features, not the *most important* features. Implements #136.

+### Bug fixes
+
+- `shapviz.xgboost()` would fail if a single row is passed. This has been fixed. Thanks, @sebsilas, for reporting.
+
 # shapviz 0.9.3

 ## `sv_dependence()`: Control over automatic color feature selection
30 changes: 29 additions & 1 deletion R/shapviz.R
@@ -196,7 +196,25 @@ shapviz.xgb.Booster = function(object, X_pred, X = X_pred, which_class = NULL,
   S <- stats::predict(object, newdata = X_pred, predcontrib = TRUE, ...)

   if (interactions) {
-    S_inter <- stats::predict(object, newdata = X_pred, predinteraction = TRUE, ...)
+    S_inter <- stats::predict(
+      object, newdata = X_pred, predinteraction = TRUE, ...
+    )
   }
+
+  # Handle problem that S and S_inter lack a dimension if X_pred has only one row.
+  # This might be fixed later directly in XGBoost.
+  if (nrow(X_pred) == 1L) {
+    if (is.list(S)) {  # multiclass
+      S <- lapply(S, rbind)
+      if (interactions) {
+        S_inter <- lapply(S_inter, .add_dim)
+      }
+    } else {
+      S <- rbind(S)
+      if (interactions) {
+        S_inter <- .add_dim(S_inter)
+      }
+    }
+  }

   # Multiclass
@@ -538,3 +556,13 @@ mshapviz <- function(object, ...) {
     )
   }
 }
+
+# Turns a matrix into a 3D array with one "row", solving a problem with XGBoost and one row.
+.add_dim <- function(x) {
+  # Anything that is not a matrix (e.g., already a 3D array) is returned unchanged
+  if (!is.matrix(x)) return(x)
+  out <- array(dim = c(1L, dim(x)), dimnames = c(list(NULL), dimnames(x)))
+  out[1L, , ] <- x
+  return(out)
+}

11 changes: 11 additions & 0 deletions tests/testthat/test-interface.R
@@ -288,3 +288,14 @@ test_that("combining shapviz on classes 1, 2, 3 equal mshapviz", {
   expect_equal(mshp, mshapviz(list(Class_1 = shp1, Class_2 = shp2, Class_3 = shp3)))
 })

+test_that("single row predictions work with shapviz", {
+  expect_no_error(
+    shp1 <- shapviz(fit, X_pred = X_pred[1L, , drop = FALSE], interactions = TRUE)
+  )
+  shp2 <- shapviz(fit, X_pred = X_pred, interactions = TRUE)
+  for (j in names(shp1)) {
+    expect_equal(shp1[[j]], shp2[[j]][1L, ])
+  }
+})

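The dimension problem this commit works around can be illustrated in base R, without XGBoost. The sketch below is only an illustration: `s_vec` and `s_inter_mat` are made-up stand-ins for what a single-row prediction might return, and `add_row_dim()` is a hypothetical helper name mirroring the idea of `.add_dim()`, not the package's code.

```r
# Made-up stand-ins (assumption): SHAP values arrive as a named vector instead
# of a 1 x 3 matrix, and interactions as a 3 x 3 matrix instead of a
# 1 x 3 x 3 array.
nms <- c("x1", "x2", "BIAS")
s_vec <- c(x1 = 0.2, x2 = -0.1, BIAS = 0.5)
s_inter_mat <- matrix(1:9 / 10, nrow = 3L, dimnames = list(nms, nms))

# rbind() on a single vector gives a one-row matrix and keeps the column names
S <- rbind(s_vec)

# Hypothetical helper: prepend a dimension of length 1 to a matrix and
# return any other input unchanged
add_row_dim <- function(x) {
  if (!is.matrix(x)) return(x)
  out <- array(dim = c(1L, dim(x)), dimnames = c(list(NULL), dimnames(x)))
  out[1L, , ] <- x
  out
}

S_inter <- add_row_dim(s_inter_mat)

# Both objects now have the leading "row" dimension that downstream code expects
stopifnot(
  identical(dim(S), c(1L, 3L)),
  identical(colnames(S), nms),
  identical(dim(S_inter), c(1L, 3L, 3L)),
  identical(S_inter[1L, , ], s_inter_mat)
)
```

Applying `add_row_dim()` a second time leaves `S_inter` untouched, because a 3D array is not a matrix. This matters if a future XGBoost version returns the full array itself.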
