Merge pull request #154 from ModelOriented/update-vignettes
Switch to log_carat/log_price in the docs
mayer79 authored Sep 11, 2024
2 parents ea85499 + 85349ef commit 92f8153
Showing 17 changed files with 171 additions and 369 deletions.
2 changes: 1 addition & 1 deletion NEWS.md
@@ -7,7 +7,7 @@
### Documentation

- Add vignette for Tidymodels.
- Update "basic_use" vignette.
- Update vignettes.
- Update README.

# shapviz 0.9.4
23 changes: 15 additions & 8 deletions README.md
@@ -49,20 +49,27 @@ library(shapviz)
library(ggplot2)
library(xgboost)

set.seed(10)
set.seed(1)

# Build model
x <- c("carat", "cut", "color", "clarity")
dtrain <- xgb.DMatrix(data.matrix(diamonds[x]), label = diamonds$price)
fit <- xgb.train(params = list(learning_rate = 0.1), data = dtrain, nrounds = 65)
xvars <- c("log_carat", "cut", "color", "clarity")
X <- diamonds |>
transform(log_carat = log(carat)) |>
subset(select = xvars)

# Fit (untuned) model
fit <- xgb.train(
params = list(learning_rate = 0.1),
data = xgb.DMatrix(data.matrix(X), label = log(diamonds$price)),
nrounds = 65
)

# SHAP analysis: X can even contain factors
dia_2000 <- diamonds[sample(nrow(diamonds), 2000), x]
shp <- shapviz(fit, X_pred = data.matrix(dia_2000), X = dia_2000)
X_explain <- X[sample(nrow(X), 2000), ]
shp <- shapviz(fit, X_pred = data.matrix(X_explain), X = X_explain)

sv_importance(shp, show_numbers = TRUE)
sv_importance(shp, kind = "bee")
sv_dependence(shp, v = x) # patchwork
sv_dependence(shp, v = xvars) # patchwork
```
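Since the model now predicts log(price), the SHAP values live on the log scale as well. Not part of the diff, but a quick sanity check, assuming the `shp` object from the block above:

```r
# SHAP values plus baseline reproduce the model's log-price prediction;
# exp() maps it back to the price scale
pred_log <- rowSums(get_shap_values(shp)) + get_baseline(shp)
head(exp(pred_log))  # predicted prices
```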

![](man/figures/README-imp.svg)
Binary file modified man/figures/README-bee.png
Binary file modified man/figures/README-dep.png
363 changes: 71 additions & 292 deletions man/figures/README-imp.svg
Binary file modified man/figures/VIGNETTE-dep-ranger.png
Binary file modified man/figures/VIGNETTE-dep.png
Binary file modified man/figures/VIGNETTE-tidy-lgb-dep.png
Binary file modified man/figures/VIGNETTE-tidy-lgb-imp.png
Binary file modified man/figures/VIGNETTE-tidy-rf-dep.png
Binary file modified man/figures/VIGNETTE-tidy-rf-imp.png
Binary file modified man/figures/VIGNETTE-tidy-xgb-dep.png
Binary file modified man/figures/VIGNETTE-tidy-xgb-imp.png
Binary file modified man/figures/VIGNETTE-tidy-xgb-inter.png
27 changes: 16 additions & 11 deletions vignettes/basic_use.Rmd
@@ -62,26 +62,31 @@ Shiny diamonds... let's use XGBoost to model their prices by the four "C" variables
library(shapviz)
library(ggplot2)
library(xgboost)
library(patchwork) # We will need its "&" operator
set.seed(1)
# Build model
x <- c("carat", "cut", "color", "clarity")
dtrain <- xgb.DMatrix(data.matrix(diamonds[x]), label = diamonds$price, nthread = 1)
xvars <- c("log_carat", "cut", "color", "clarity")
X <- diamonds |>
transform(log_carat = log(carat)) |>
subset(select = xvars)
head(X)
# Fit (untuned) model
fit <- xgb.train(
params = list(learning_rate = 0.1, nthread = 1), data = dtrain, nrounds = 65
params = list(learning_rate = 0.1, nthread = 1),
data = xgb.DMatrix(data.matrix(X), label = log(diamonds$price), nthread = 1),
nrounds = 65
)
# SHAP analysis: X can even contain factors
dia_2000 <- diamonds[sample(nrow(diamonds), 2000), x]
shp <- shapviz(fit, X_pred = data.matrix(dia_2000), X = dia_2000)
X_explain <- X[sample(nrow(X), 2000), ]
shp <- shapviz(fit, X_pred = data.matrix(X_explain), X = X_explain)
sv_importance(shp, show_numbers = TRUE)
sv_importance(shp, kind = "beeswarm") # kind = "both" combines bar and bee
sv_importance(shp, kind = "beeswarm")
```
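The comment dropped from the old line still applies; not part of the diff, but for reference:

```r
# kind = "both" combines the importance bar plot with the beeswarm
sv_importance(shp, kind = "both")
```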
```{r, fig.width=8.5, fig.height=5.5}
sv_dependence(shp, v = x) # patchwork object
sv_dependence(shp, v = xvars) # patchwork object
```

### Decompose single predictions
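Not part of the diff (this subsection's code is collapsed here), but a minimal sketch of what it does, assuming the `shp` object from above:

```r
# Additive decomposition of a single prediction
sv_waterfall(shp, row_id = 1)
# The same decomposition shown as a force plot
sv_force(shp, row_id = 1)
```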
@@ -112,9 +117,9 @@ Note that SHAP interaction values are multiplied by two (except main effects).

```{r, fig.width=8.5, fig.height=5.5}
shp_i <- shapviz(
fit, X_pred = data.matrix(dia_2000[x]), X = dia_2000, interactions = TRUE
fit, X_pred = data.matrix(X_explain), X = X_explain, interactions = TRUE
)
sv_dependence(shp_i, v = "carat", color_var = x, interactions = TRUE)
sv_dependence(shp_i, v = "log_carat", color_var = xvars, interactions = TRUE)
sv_interaction(shp_i) +
theme(axis.text.x = element_text(angle = 45, vjust = 1, hjust = 1))
```
70 changes: 33 additions & 37 deletions vignettes/geographic.Rmd
@@ -56,38 +56,38 @@ library(xgboost)
library(ggplot2)
library(shapviz)
head(miami)
miami <- miami |>
transform(
log_living = log(TOT_LVG_AREA),
log_land = log(LND_SQFOOT),
log_price = log(SALE_PRC)
)
x_coord <- c("LATITUDE", "LONGITUDE")
x_nongeo <- c("TOT_LVG_AREA", "LND_SQFOOT", "structure_quality", "age")
x <- c(x_coord, x_nongeo)
x_nongeo <- c("log_living", "log_land", "structure_quality", "age")
xvars <- c(x_coord, x_nongeo)
# Train/valid split
# Select training data
set.seed(1)
ix <- sample(nrow(miami), 0.8 * nrow(miami))
X_train <- data.matrix(miami[ix, x])
X_valid <- data.matrix(miami[-ix, x])
y_train <- log(miami$SALE_PRC[ix])
y_valid <- log(miami$SALE_PRC[-ix])
# Fit XGBoost model with early stopping
dtrain <- xgb.DMatrix(X_train, label = y_train, nthread = 1)
dvalid <- xgb.DMatrix(X_valid, label = y_valid, nthread = 1)
params <- list(
learning_rate = 0.2, objective = "reg:squarederror", max_depth = 5, nthread = 1
)
fit <- xgb.train(params = params, data = dtrain, nrounds = 200)
train <- miami[ix, ]
X_train <- train[xvars]
y_train <- train$log_price
# Fit XGBoost model
params <- list(learning_rate = 0.2, nthread = 1)
dtrain <- xgb.DMatrix(data.matrix(X_train), label = y_train, nthread = 1)
fit <- xgb.train(params, dtrain, nrounds = 200)
```

Let's first study selected SHAP dependence plots, evaluated on the validation dataset with around 2800 observations. Note that we could as well use (a subset of) the training data for this purpose.
Let's first study selected SHAP dependence plots for an explanation dataset of size 2000.

```{r}
sv <- shapviz(fit, X_pred = X_valid)
X_explain <- X_train[1:2000, ]
sv <- shapviz(fit, X_pred = data.matrix(X_explain))
sv_dependence(
sv,
v = c("TOT_LVG_AREA", "structure_quality", "LONGITUDE", "LATITUDE"),
v = c("log_living", "structure_quality", "LONGITUDE", "LATITUDE"),
alpha = 0.2
)
@@ -115,34 +115,30 @@ The second step leads to a model that is additive in each non-geographic component
```{r}
# Extend the feature set
more_geo <- c("CNTR_DIST", "OCEAN_DIST", "RAIL_DIST", "HWY_DIST")
x2 <- c(x, more_geo)
xvars <- c(xvars, more_geo)
X_train <- train[xvars]
dtrain <- xgb.DMatrix(data.matrix(X_train), label = y_train, nthread = 1)
X_train2 <- data.matrix(miami[ix, x2])
X_valid2 <- data.matrix(miami[-ix, x2])
dtrain2 <- xgb.DMatrix(X_train2, label = y_train, nthread = 1)
dvalid2 <- xgb.DMatrix(X_valid2, label = y_valid, nthread = 1)
# Build interaction constraint vector
# Build interaction constraint vector and add it to params
ic <- c(
list(which(x2 %in% c(x_coord, more_geo)) - 1),
as.list(which(x2 %in% x_nongeo) - 1)
list(which(xvars %in% c(x_coord, more_geo)) - 1),
as.list(which(xvars %in% x_nongeo) - 1)
)
# Modify parameters
params$interaction_constraints <- ic
fit2 <- xgb.train(params = params, data = dtrain2, nrounds = 200)
# Fit XGBoost model
fit <- xgb.train(params, dtrain, nrounds = 200)
# SHAP analysis
sv2 <- shapviz(fit2, X_pred = X_valid2)
X_explain <- X_train[1:2000, ]
sv <- shapviz(fit, X_pred = data.matrix(X_explain))
# Two selected features: Thanks to additivity, structure_quality can be read as
# Ceteris Paribus
sv_dependence(sv2, v = c("structure_quality", "LONGITUDE"), alpha = 0.2)
sv_dependence(sv, v = c("structure_quality", "LONGITUDE"), alpha = 0.2)
# Total geographic effect (Ceteris Paribus thanks to additivity)
sv_dependence2D(sv2, x = "LONGITUDE", y = "LATITUDE", add_vars = more_geo) +
sv_dependence2D(sv, x = "LONGITUDE", y = "LATITUDE", add_vars = more_geo) +
coord_equal()
```
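To make the constraint structure concrete (not part of the diff; derived from the code above): all geographic features form one interaction group, while each non-geographic feature is restricted to a pure main effect. XGBoost expects 0-based column indices:

```r
str(ic)
# List of 5
#  $ : num [1:6] 0 1 6 7 8 9  # LATITUDE, LONGITUDE, CNTR_DIST, OCEAN_DIST, RAIL_DIST, HWY_DIST
#  $ : num 2                  # log_living (main effect only)
#  $ : num 3                  # log_land
#  $ : num 4                  # structure_quality
#  $ : num 5                  # age
```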

55 changes: 35 additions & 20 deletions vignettes/tidymodels.Rmd
@@ -24,11 +24,16 @@ library(shapviz)

set.seed(10)

splits <- initial_split(diamonds)
splits <- diamonds |>
transform(
log_price = log(price),
log_carat = log(carat)
) |>
initial_split()
df_train <- training(splits)

dia_recipe <- df_train |>
recipe(price ~ carat + color + clarity + cut)
recipe(log_price ~ log_carat + color + clarity + cut)

rf <- rand_forest(mode = "regression") |>
set_engine("ranger")
@@ -41,10 +46,10 @@ fit <- rf_wf |>
fit(df_train)

# SHAP analysis
xvars <- c("carat", "color", "clarity", "cut")
xvars <- c("log_carat", "color", "clarity", "cut")
X_explain <- df_train[1:1000, xvars] # Use only feature columns

# 90 seconds on laptop
# 1.5 minutes on laptop
# Note: If you have more than p=8 features, use kernelshap() instead of permshap()
system.time(
shap_values <- fit |>
@@ -78,18 +83,23 @@ of course, you don't *have* to work with SHAP interactions, especially if your model...
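The note in the collapsed code above recommends `kernelshap()` over `permshap()` once you have more than eight features. A hypothetical drop-in, assuming the objects from the ranger workflow (`fit`, `X_explain`) and the {kernelshap} package:

```r
library(kernelshap)

# Kernel SHAP scales better with the number of features than permutation SHAP;
# by default, a background sample is drawn from X
shap_values <- fit |>
  kernelshap(X = X_explain) |>
  shapviz()
```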

**Remark:** Don't use 1:m transforms such as One-Hot-Encodings. They are usually not necessary and make the workflow more complicated. If you can't avoid this, check the `collapse` argument in `shapviz()`.
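A hypothetical illustration of `collapse` (the column names are made up): it sums the SHAP values of the dummy columns back into one feature, so the plots show a single row/panel per original variable.

```r
# X_ohe: prediction data with one-hot columns color_D, color_E, color_F (hypothetical)
# X:     display data keeping the original `color` factor
shp <- shapviz(
  fit,
  X_pred = X_ohe,
  X = X,
  collapse = list(color = c("color_D", "color_E", "color_F"))
)
```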

```
```r
library(tidymodels)
library(shapviz)
library(patchwork)

set.seed(10)

splits <- initial_split(diamonds)
splits <- diamonds |>
transform(
log_price = log(price),
log_carat = log(carat)
) |>
initial_split()
df_train <- training(splits)

dia_recipe <- df_train |>
recipe(price ~ carat + color + clarity + cut) |>
recipe(log_price ~ log_carat + color + clarity + cut) |>
step_integer(all_ordered())

# Should be tuned in practice
@@ -126,22 +136,22 @@ shap_values |>
# Absolute average SHAP interactions (off-diagonals already multiplied by 2)
shap_values |>
sv_interaction(kind = "no")
# carat clarity color cut
# carat 2998.30769 591.8859 425.63902 99.11383
# clarity 591.88589 632.2544 192.14847 25.47713
# color 425.63906 192.1484 424.91991 20.15823
# cut 99.11392 25.4771 20.15823 109.26374
# log_carat clarity color cut
# log_carat 0.87400688 0.067567245 0.032599394 0.024273852
# clarity 0.06756720 0.143393109 0.028236784 0.004910905
# color 0.03259941 0.028236796 0.095656042 0.004804729
# cut 0.02427382 0.004910904 0.004804732 0.031114735

# Usual dependence plot
xvars <- c("carat", "color", "clarity", "cut")
xvars <- c("log_carat", "color", "clarity", "cut")

shap_values |>
sv_dependence(xvars) &
plot_annotation("SHAP dependence plots") # patchwork magic

# SHAP interactions for carat
shap_values |>
sv_dependence("carat", color_var = xvars, interactions = TRUE) &
sv_dependence("log_carat", color_var = xvars, interactions = TRUE) &
plot_annotation("SHAP interactions for carat")
```
![](../man/figures/VIGNETTE-tidy-xgb-imp.png)
@@ -164,11 +174,16 @@ library(shapviz)

set.seed(10)

splits <- initial_split(diamonds)
splits <- diamonds |>
transform(
log_price = log(price),
log_carat = log(carat)
) |>
initial_split()
df_train <- training(splits)

dia_recipe <- df_train |>
recipe(price ~ carat + color + clarity + cut) |>
recipe(log_price ~ log_carat + color + clarity + cut) |>
step_integer(color, clarity) # we keep cut a factor (for illustration only)

# Should be tuned in practice
@@ -193,9 +208,9 @@ X_pred <- bake( # Goes to lightgbm:::predict.lgb.Booster()
bonsai:::prepare_df_lgbm()

head(X_pred, 2)
# carat color clarity cut
# [1,] 1.37 5 5 3
# [2,] 0.55 2 3 4
# log_carat color clarity cut
# [1,] 0.3148107 5 5 3
# [2,] -0.5978370 2 3 4

stopifnot(colnames(X_pred) %in% colnames(df_explain))

@@ -206,7 +221,7 @@ shap_values |>
sv_importance(show_numbers = TRUE)

shap_values |>
sv_dependence(c("carat", "color", "clarity", "cut"))
sv_dependence(c("log_carat", "color", "clarity", "cut"))
```
![](../man/figures/VIGNETTE-tidy-lgb-imp.png)
![](../man/figures/VIGNETTE-tidy-lgb-dep.png)

