Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Update python examples ++ #416

Merged
merged 5 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions inst/code_paper/code_sec_3.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ library(xgboost)
library(data.table)
library(shapr)

path <- "inst/code_paper/"
path0 <- "https://raw.githubusercontent.com/NorskRegnesentral/shapr/refs/heads/"
path <- paste0(path0,"master/inst/code_paper/")
x_explain <- fread(paste0(path, "x_explain.csv"))
x_train <- fread(paste0(path, "x_train.csv"))
y_train <- unlist(fread(paste0(path, "y_train.csv")))
model <- readRDS(paste0(path, "model.rds"))
model <- readRDS(file(paste0(path, "model.rds")))


# We compute the SHAP values for the test data.
Expand Down Expand Up @@ -51,8 +52,7 @@ exp_20_ctree$MSEv$MSEv
#<num> <num>
# 1: 1224818 101680.4

exp_20_ctree

print(exp_20_ctree)
### Continued estimation

exp_iter_ctree <- explain(model = model,
Expand All @@ -71,7 +71,7 @@ library(ggplot2)

plot(exp_iter_ctree, plot_type = "scatter",scatter_features = c("atemp","windspeed"))

ggplot2::ggsave("inst/code_paper/scatter_ctree.pdf",width = 7, height = 4)
ggplot2::ggsave("inst/code_paper/scatter_ctree.pdf",width = 7, height = 3)

### Grouping

Expand Down Expand Up @@ -125,7 +125,10 @@ exp_g_reg_tuned$MSEv$MSEv

# Plot the best one

plot(exp_group_reg_sep_xgb_tuned,index_x_explain = 6,plot_type="waterfall")
exp_g_reg_tuned$shapley_values_est[6,]
x_explain[6,]

plot(exp_g_reg_tuned,index_x_explain = 6,plot_type="waterfall")

ggplot2::ggsave("inst/code_paper/waterfall_group.pdf",width = 7, height = 4)

Expand Down
3 changes: 2 additions & 1 deletion inst/code_paper/code_sec_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import pandas as pd
from shaprpy import explain

path = "inst/code_paper/"
path0 = "https://raw.githubusercontent.com/NorskRegnesentral/shapr/refs/heads/"
path = path0 + "master/inst/code_paper/"

# Read data
x_train = pd.read_csv(path + "x_train.csv")
Expand Down
12 changes: 7 additions & 5 deletions inst/code_paper/code_sec_6.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ library(xgboost)
library(data.table)
library(shapr)

path <- "inst/code_paper/"
path0 <- "https://raw.githubusercontent.com/NorskRegnesentral/shapr/refs/heads/"
path <- paste0(path0,"master/inst/code_paper/")
x_full <- fread(paste0(path, "x_full.csv"))

data_fit <- x_full[seq_len(729), ]

model_ar <- ar(x_full$temp, order = 2)
model_ar <- ar(data_fit$temp, order = 2)

phi0_ar <- rep(mean(x_full$temp), 3)
phi0_ar <- rep(mean(data_fit$temp), 3)

explain_forecast(
model = model_ar,
Expand All @@ -30,8 +32,8 @@ phi0_arimax <- rep(mean(data_fit$temp), 2)

explain_forecast(
model = model_arimax,
y = data_fit[, "temp"],
xreg = bike[, "windspeed"],
y = x_full[, "temp"],
xreg = x_full[, "windspeed"],
train_idx = 2:728,
explain_idx = 729,
explain_y_lags = 2,
Expand Down
Binary file modified inst/code_paper/scatter_ctree.pdf
Binary file not shown.
4 changes: 2 additions & 2 deletions python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,14 @@ model = RandomForestRegressor()
model.fit(dfx_train, dfy_train.values.flatten())

## Shapr
df_shapley, pred_explain, internal, timing = explain(
explanation = explain(
model = model,
x_train = dfx_train,
x_explain = dfx_test,
approach = 'empirical',
phi0 = dfy_train.mean().item(),
)
print(df_shapley)
print(explanation["shapley_values_est"])
```

`shaprpy` knows how to explain predictions from models from `sklearn`, `keras` and `xgboost`.
Expand Down
35 changes: 17 additions & 18 deletions python/examples/keras_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,34 +25,33 @@
epochs=10,
validation_data=(dfx_test, dfy_test))
## Shapr
df_shapley, pred_explain, internal, timing, MSEv = explain(
explanation = explain(
model = model,
x_train = dfx_train,
x_explain = dfx_test,
approach = 'empirical',
phi0 = dfy_train.mean().item(),
phi0 = dfy_train.mean().item()
)
print(df_shapley)
print(explanation["shapley_values_est"])

"""
none sepal length (cm) sepal width (cm) petal length (cm) \
1 0.494737 0.042263 0.037911 0.059232
2 0.494737 0.034217 0.029183 0.045027
3 0.494737 0.045776 0.031752 0.058278
4 0.494737 0.014977 0.032691 0.014280
5 0.494737 0.022742 0.025851 0.027427

petal width (cm)
1 0.058412
2 0.053639
3 0.070650
4 0.018697
5 0.026814

explain_id none sepal length (cm) sepal width (cm) \
1 1 0.494737 0.041518 0.037129
2 2 0.494737 0.033541 0.028414
3 3 0.494737 0.045033 0.031092
4 4 0.494737 0.014281 0.031831
5 5 0.494737 0.022155 0.025154

petal length (cm) petal width (cm)
1 0.058252 0.057664
2 0.044242 0.052839
3 0.057368 0.069891
4 0.013667 0.018016
5 0.026672 0.026181
"""

# Look at the (overall) MSEv
MSEv["MSEv"]
explanation["MSEv"]["MSEv"]

"""
MSEv MSEv_sd
Expand Down
52 changes: 35 additions & 17 deletions python/examples/pytorch_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,33 +36,51 @@ def forward(self, x):
optim.zero_grad()

## Shapr
df_shapley, pred_explain, internal, timing, MSEv = explain(
explanation = explain(
model = model,
x_train = dfx_train,
x_explain = dfx_test,
approach = 'empirical',
predict_model = lambda m, x: m(torch.from_numpy(x.values).float()).cpu().detach().numpy(),
phi0 = dfy_train.mean().item(),
)
print(df_shapley)
print(explanation["shapley_values_est"])
"""
none MedInc HouseAge AveRooms AveBedrms Population AveOccup \
1 2.205947 2.313935 5.774470 5.425240 4.194669 1.712164 3.546001
2 2.205947 4.477620 5.467266 2.904239 3.046492 1.484807 5.631292
3 2.205946 4.028013 1.168401 5.229893 1.719724 2.134012 3.426378
4 2.205948 4.230376 8.639265 1.138520 3.776463 3.786978 4.253034
5 2.205947 3.923747 1.483737 1.113199 4.963213 -3.645875 4.950775
explain_id none MedInc HouseAge AveRooms AveBedrms Population \
1 1 2.205951 3.531437 7.746453 6.985043 5.454877 3.287326
2 2 2.205951 6.004403 7.041080 4.254553 4.118677 3.162567
3 3 2.205950 5.497648 1.538680 6.750968 2.806428 3.687014
4 4 2.205951 5.761901 11.378609 2.112351 5.013451 5.754630
5 5 2.205951 5.325281 2.585713 2.224409 6.418153 -2.848570

Latitude Longitude
1 1.102239 2.906469
2 4.966465 2.178510
3 3.503413 2.909760
4 3.413727 3.795563
5 3.011126 4.016985
AveOccup Latitude Longitude
1 4.774873 2.273699 4.314784
2 7.386783 6.473623 3.318631
3 5.193341 4.875864 4.290797
4 5.866562 4.564957 5.139962
5 6.428984 4.280456 5.509226
"""

MSEv["MSEv"]
print(explanation["shapley_values_sd"])

"""
explain_id none MedInc HouseAge AveRooms AveBedrms \
1 1 3.523652e-08 0.122568 0.124885 0.163694 0.134910
2 2 3.501778e-08 0.125286 0.113064 0.123057 0.129869
3 3 1.805247e-08 0.098208 0.095959 0.115399 0.102265
4 4 3.227380e-08 0.110442 0.118524 0.124688 0.101476
5 5 3.650380e-08 0.125538 0.130427 0.136797 0.131515

Population AveOccup Latitude Longitude
1 0.133510 0.149141 0.132394 0.121605
2 0.113429 0.124539 0.122773 0.100871
3 0.092633 0.110790 0.090657 0.090542
4 0.114721 0.122266 0.103081 0.105613
5 0.113853 0.139291 0.135377 0.132476
"""

explanation["MSEv"]["MSEv"]
"""
MSEv MSEv_sd
1 27.046126 7.253933
MSEv MSEv_sd
1 33.143896 7.986808
"""
55 changes: 11 additions & 44 deletions python/examples/regression_paradigm.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
x_train=dfx_train,
x_explain=dfx_test,
approach='empirical',
iterative = False,
phi0=dfy_train.mean().item()
)

Expand All @@ -38,8 +39,6 @@
x_explain=dfx_test,
approach='regression_separate',
phi0=dfy_train.mean().item(),
verbose=2,
n_batches=1,
regression_model='parsnip::linear_reg()'
)

Expand All @@ -50,8 +49,6 @@
x_explain=dfx_test,
approach='regression_separate',
phi0=dfy_train.mean().item(),
verbose=2,
n_batches=1,
regression_model='parsnip::linear_reg()',
regression_recipe_func='''function(regression_recipe) {
return(recipes::step_ns(regression_recipe, recipes::all_numeric_predictors(), deg_free = 3))
Expand All @@ -65,8 +62,6 @@
x_explain=dfx_test,
approach='regression_separate',
phi0=dfy_train.mean().item(),
verbose=2,
n_batches=1,
regression_model='parsnip::linear_reg()',
regression_recipe_func='''function(regression_recipe) {
return(recipes::step_ns(regression_recipe, recipes::all_numeric_predictors(), deg_free = 3))
Expand All @@ -80,8 +75,6 @@
x_explain=dfx_test,
approach='regression_separate',
phi0=dfy_train.mean().item(),
verbose=2,
n_batches=1,
regression_model="parsnip::decision_tree(tree_depth = hardhat::tune(), engine = 'rpart', mode = 'regression')",
regression_tune_values='dials::grid_regular(dials::tree_depth(), levels = 4)',
regression_vfold_cv_para={'v': 5}
Expand All @@ -94,8 +87,6 @@
x_explain=dfx_test,
approach='regression_separate',
phi0=dfy_train.mean().item(),
verbose=2,
n_batches=1,
regression_model="parsnip::boost_tree(engine = 'xgboost', mode = 'regression')"
)

Expand All @@ -106,8 +97,6 @@
x_explain=dfx_test,
approach='regression_separate',
phi0=dfy_train.mean().item(),
verbose=2,
n_batches=1,
regression_model="parsnip::boost_tree(trees = hardhat::tune(), engine = 'xgboost', mode = 'regression')",
regression_tune_values='expand.grid(trees = c(10, 15, 25, 50, 100, 500))',
regression_vfold_cv_para={'v': 5}
Expand All @@ -121,8 +110,6 @@
x_explain=dfx_test,
approach='regression_surrogate',
phi0=dfy_train.mean().item(),
verbose=2,
n_batches=1,
regression_model='parsnip::linear_reg()'
)

Expand All @@ -133,8 +120,6 @@
x_explain=dfx_test,
approach='regression_surrogate',
phi0=dfy_train.mean().item(),
verbose=2,
n_batches=1,
regression_model="parsnip::rand_forest(engine = 'ranger', mode = 'regression')"
)

Expand All @@ -145,8 +130,6 @@
x_explain=dfx_test,
approach='regression_surrogate',
phi0=dfy_train.mean().item(),
verbose=2,
n_batches=1,
regression_model="""parsnip::rand_forest(
mtry = hardhat::tune(), trees = hardhat::tune(), engine = 'ranger', mode = 'regression'
)""",
Expand All @@ -161,34 +144,18 @@
# Print the MSEv evaluation criterion scores
print("Method", "MSEv", "Elapsed time (seconds)")
for i, (method, explanation) in enumerate(explanation_list.items()):
print(method, round(explanation[4]["MSEv"]["MSEv"].iloc[0], 3), round(explanation[3]["total_time_secs"], 3))
print(method, round(explanation["MSEv"]["MSEv"].iloc[0].iloc[0], 3), round(explanation["timing"]["total_time_secs"][0], 3))



"""
Method MSEv Time
empirical 0.826 1.096
sep_lm 1.623 12.093
sep_pca 1.626 16.435
sep_splines 1.626 15.072
sep_tree_cv 1.436 275.002
sep_xgboost 0.769 13.870
sep_xgboost_cv 0.802 312.758
sur_lm 1.772 0.548
sur_rf 0.886 41.250
"""

explanation_list["sep_xgboost"][0]
explanation_list["sep_xgboost"]["shapley_values_est"]

"""
none MedInc HouseAge AveRooms AveBedrms Population AveOccup \
1 2.205937 -0.496421 0.195272 -0.077923 0.010124 -0.219369 -0.316029
2 2.205938 -0.163246 0.014565 -0.415945 -0.114073 0.084315 0.144754
3 2.205938 0.574157 0.258926 0.090818 -0.665126 0.354005 0.869530
4 2.205938 0.311416 -0.105142 0.211300 0.031939 -0.180331 -0.059839
5 2.205938 0.077537 -0.150997 -0.117875 0.087118 -0.085118 0.414764
Latitude Longitude
1 -0.434240 -0.361774
2 -0.483618 -0.324016
3 0.276002 0.957242
4 0.028560 0.049815
5 -0.242943 0.006815
explain_id none MedInc HouseAge AveRooms AveBedrms Population AveOccup Latitude Longitude
1 1 2.205937 -0.498764 0.193443 -0.073068 0.005078 -0.216733 -0.313781 -0.433844 -0.362689
2 2 2.205938 -0.160032 0.014564 -0.417670 -0.117127 0.084102 0.151612 -0.486576 -0.326138
3 3 2.205938 0.585638 0.239399 0.103826 -0.656533 0.349671 0.859701 0.275356 0.958495
4 4 2.205938 0.311038 -0.114403 0.206639 0.041748 -0.178090 -0.061004 0.036681 0.045110
5 5 2.205938 0.079439 -0.156861 -0.118913 0.093746 -0.097861 0.433192 -0.239588 -0.003852
"""
Loading
Loading