Skip to content

Commit

Permalink
Merge pull request #144 from ModelOriented/xgb-2
Browse files Browse the repository at this point in the history
XGBoost 2 is approaching
  • Loading branch information
mayer79 authored Aug 19, 2024
2 parents bda883b + e1fdf15 commit e6559f3
Show file tree
Hide file tree
Showing 11 changed files with 940 additions and 426 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ Depends:
R (>= 3.6.0)
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.1
RoxygenNote: 7.3.2
Imports:
ggfittext (>= 0.8.0),
gggenes,
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# shapviz 0.9.4

## API improvements

- Support both XGBoost 1.x.x and XGBoost 2.x.x, implemented in #144.

## Improvements

- New argument `sort_features = TRUE` in `sv_importance()` and `sv_interaction()`. Set to `FALSE` to show the features as they appear in your SHAP matrix. In that case, the plots will show the *first* `max_display` features, not the *most important* features. Implements #136.
Expand Down
48 changes: 29 additions & 19 deletions R/shapviz.R
Original file line number Diff line number Diff line change
Expand Up @@ -201,36 +201,46 @@ shapviz.xgb.Booster = function(object, X_pred, X = X_pred, which_class = NULL,
)
}

# Handle problem that S and S_inter lack a dimension if X_pred has only one row
# This might be fixed later directly in XGBoost.
if (nrow(X_pred) == 1L) {
if (is.list(S)) { # multiclass
S <- lapply(S, rbind)
if (utils::packageVersion("xgboost") >= "2") {
# Turn result of multi-output model into list of lower dim arrays
if (length(dim(S)) == 3L) {
S <- asplit(S, MARGIN = 2L)
if (interactions) {
S_inter <- lapply(S_inter, .add_dim)
S_inter <- asplit(S_inter, MARGIN = 2L)
}
} else {
S <- rbind(S)
if (interactions) {
S_inter <-.add_dim(S_inter)
}
} else {
# Handle problem that S and S_inter lack a dimension if X_pred has only one row
# This only applies to XGBoost < 2
if (nrow(X_pred) == 1L) {
if (is.list(S)) { # multiclass
S <- lapply(S, rbind)
if (interactions) {
S_inter <- lapply(S_inter, .add_dim)
}
} else {
S <- rbind(S)
if (interactions) {
S_inter <-.add_dim(S_inter)
}
}
}
}

# Multiclass
# Multi-class (or some other multi-output situation)
if (is.list(S)) {
if (is.null(which_class)) {
nms <- setdiff(colnames(S[[1L]]), "BIAS")
pp <- ncol(S[[1L]]) # = ncol(X_pred) + 1. The last column is the baseline
if (interactions) {
S_inter <- lapply(S_inter, function(s) s[, nms, nms, drop = FALSE])
S_inter <- lapply(S_inter, function(s) s[, -pp, -pp, drop = FALSE])
} else {
# mapply() does not want to see a length 0 object like NULL
S_inter <- replicate(length(S), NULL)
}
shapviz_list <- mapply(
FUN = shapviz.matrix,
object = lapply(S, function(s) s[, nms, drop = FALSE]),
baseline = lapply(S, function(s) unname(s[1L, "BIAS"])),
object = lapply(S, function(s) s[, -pp, drop = FALSE]),
baseline = lapply(S, function(s) unname(s[1L, pp])),
S_inter = S_inter,
MoreArgs = list(X = X, collapse = collapse),
SIMPLIFY = FALSE
Expand All @@ -246,12 +256,12 @@ shapviz.xgb.Booster = function(object, X_pred, X = X_pred, which_class = NULL,
}

# Call matrix method
nms <- setdiff(colnames(S), "BIAS")
pp <- ncol(S)
shapviz.matrix(
object = S[, nms, drop = FALSE],
object = S[, -pp, drop = FALSE],
X = X,
baseline = unname(S[1L, "BIAS"]),
S_inter = if (interactions) S_inter[, nms, nms, drop = FALSE],
baseline = unname(S[1L, pp]),
S_inter = if (interactions) S_inter[, -pp, -pp, drop = FALSE],
collapse = collapse
)
}
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ library(shapviz)
library(ggplot2)
library(xgboost)

set.seed(1)
set.seed(10)

# Build model
x <- c("carat", "cut", "color", "clarity")
Expand Down
Binary file modified man/figures/README-dep.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
378 changes: 315 additions & 63 deletions man/figures/README-force.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
560 changes: 274 additions & 286 deletions man/figures/README-imp.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
370 changes: 315 additions & 55 deletions man/figures/README-waterfall.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified man/figures/VIGNETTE-dep-ranger.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified man/figures/VIGNETTE-dep.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion vignettes/multiple_output.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ sv_importance(shp)

sv_dependence(shp, v = "Sepal.Width") +
plot_layout(ncol = 2) &
ylim(-0.03, 0.035)
ylim(-0.06, 0.06)
```

![](../man/figures/VIGNETTE-dep-ranger.png)
Expand Down

0 comments on commit e6559f3

Please sign in to comment.