Skip to content

Commit

Permalink
Harmonize batch distribution ++ (#359)
Browse files Browse the repository at this point in the history
* Bugfix in `prepare_data()` related to vector of approaches. When using several approaches the old version only used the first approach. Verified this by adding a print in each prepare_data.approach() function and saw that only the first approach in internal$parameters$approach was used.
Can maybe remove code comments before pull request is accepted. Maybe a better method to get the approach?

Also updated roxygen2 for the function, as it seemed that it was reflecting the old version of shapr(?) due to arguments which are no longer present.

However, one then get a warning when creating the roxygen2 documentation. Discuss some solutions as comments below. Discuss with Martin.

* # Lars have added `n_combinations` - 1 as a possibility, as the function `check_n_batches` threw an error for the vignette with gaussian approach with `n_combinations` = 8 and `n_batches = NULL`, as this function here then set `n_batches` = 10, which was too large. We subtract 1 as `check_n_batches` function specifies that `n_batches` must be strictly less than `n_combinations`.

* Samll typo.

* Fixed bug. All messages says "n_combinations is larger than or equal to 2^m", but all the test only tested for "larger than". I.e., if the user specified n_combinations = 2^m in the call to shapr::explain, the function would not treat it as exact.

* Added script demonstrating the bug that shapr does not enter the exact mode when `n_combinations = 2^m`, before the bugfix.

* Added (tentative) test that checks that shapr enters exact mode when `n_combinations >= 2^m`. Remove the large comment after discussing that with Martin.

* Added script that demonstrates the bug before the bugfix, and added test checking that we do not get an error when runing the code after the bugfix has been applied.

* Fixed lint warnings in `approach.R`.

* Added two parameters to the `internal$parameters` list which contains the number of approaches and the number of unique approaches.

This is for example useful to check that the provided `n_batches` is a valid value. (see next commits)

* Added test to check that `n_batches` must be larger than or equal to the number of unique approaches. Before the user could, e.g., set `n_batches = 2`, but use 4 approaches and then shapr would use 4 but not update `n_batches` and without giwing a warning to the user.

* Updated `get_default_n_batches` to take into consideration the number of unique approaches that is used. This was not done before and gave inconsistency in what number shapr would reccomend and use when `n_batches` was set to `null` by the user.

* Changed where seed is set such that it applies for both regular and combined approaches.
Furthermore, added if test, because previous version resulted in not reproducible code, as setting seed to `null` ruins that we set seed in `explain()`.

Just consider this small example:
# Set seed to get same values twice
set.seed(123)
rnorm(1)

# Seting the same seed gives the same value
set.seed(123)
rnorm(1)

# If we also include null then the seed is removed and we do not get the same value
set.seed(123)
set.seed(NULL)
rnorm(1)

# Setining seed to null actually gives a new "random" number each time.
set.seed(123)
set.seed(NULL)
rnorm(1)

* Typo

* Added test to check that setting the seed works for combined approaches.

* typo in test function

* Added file to demonstrate the bugs (before the bugfix)

* Added new test

* Updated tests by removing n_samples

* Added a bugfix to shapr not using the correct number of batches. Maybe not the most elegant solution.

* Updated the demonstration script

* Added last test and fixed lintr

* Lint again.

* styler

* minor edits to tests

* simplifies comment

* comb files ok

* Updated bug in independence approach related to categorical features which caused shapr to crash later. Added comments when I debuged to understand what was going on. I have added some comments about some stuff I did no understand/agree with. Discuss with Martin and correct this before merge.

* Updated bug in independence approach related to categorical features which caused shapr to crash later. Added comments when I debuged to understand what was going on. I have added some comments about some stuff I did no understand/agree with. Discuss with Martin and correct this before merge.

* lint warning

* Lint

* lint

* updated test files after accepting new values

* adjustments to comments and Lars' TODO-comments

* update snapshot file after weight adjustment

* cleaned up doc

* rerun doc

* style

* Changed to `n_batches = 10` in the combined approaches, as the previous value (`n_batches = 1`) is not allowed anymore as it is lower than the number of unique used approaches.

* accept OK test changes

* additonal Ok test files

* change batches in test files

* accept new files

* handle issue with a breaking change update in the testthat package

* + these

* removing last (unused) input of approach

* updating tests

* + update setup tests/snaps

* correcting unique length

* update linting and vignette

* update docs

* fix example issue

* temporary disable tests on older R systems

* remove unecessary if-else test

* data.table style on Lars's batch adjustment suggestion

* del comment

* lint

---------

Co-authored-by: Martin <[email protected]>
  • Loading branch information
LHBO and martinju authored Nov 20, 2023
1 parent a62005f commit 90afa9a
Show file tree
Hide file tree
Showing 49 changed files with 2,521 additions and 238 deletions.
6 changes: 4 additions & 2 deletions .github/workflows/R-CMD-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,10 @@ jobs:
- {os: windows-latest, r: 'release'}
- {os: ubuntu-20.04, r: 'devel', http-user-agent: 'release'}
- {os: ubuntu-20.04, r: 'release'}
- {os: ubuntu-20.04, r: 'oldrel-1'}
- {os: ubuntu-20.04, r: 'oldrel-2'}
# Temporary disable the below check plattforms as they fail due to a change in how R reports error from R<4.3 to R>=4.3,
# which gives a different output in the snapshots produced by testthat>=3.2.0
# - {os: ubuntu-20.04, r: 'oldrel-1'}
# - {os: ubuntu-20.04, r: 'oldrel-2'}

env:
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
Expand Down
11 changes: 6 additions & 5 deletions R/explain.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
#' can still be explained by passing `predict_model` and (optionally) `get_model_specs`,
#' see details for more information.
#'
#' @param approach Character vector of length `1` or `n_features`.
#' `n_features` equals the total number of features in the model. All elements should,
#' @param approach Character vector of length `1` or one less than the number of features.
#' All elements should,
#' either be `"gaussian"`, `"copula"`, `"empirical"`, `"ctree"`, `"categorical"`, `"timeseries"`, or `"independence"`.
#' See details for more information.
#'
Expand Down Expand Up @@ -101,9 +101,10 @@
#' and you'd like to use the `"gaussian"` approach when you condition on a single feature,
#' the `"empirical"` approach if you condition on 2-5 features, and `"copula"` version
#' if you condition on more than 5 features this can be done by simply passing
#' `approach = c("gaussian", rep("empirical", 4), rep("copula", 5))`. If
#' `approach = c("gaussian", rep("empirical", 4), rep("copula", 4))`. If
#' `"approach[i]" = "gaussian"` means that you'd like to use the `"gaussian"` approach
#' when conditioning on `i` features.
#' when conditioning on `i` features. Conditioning on all features needs no approach as that is given
#' by the complete prediction itself, and should thus not be part of the vector.
#'
#' For `approach="ctree"`, `n_samples` corresponds to the number of samples
#' from the leaf node (see an exception related to the `sample` argument).
Expand Down Expand Up @@ -203,7 +204,7 @@
#' )
#'
#' # Combined approach
#' approach <- c("gaussian", "gaussian", "empirical", "empirical")
#' approach <- c("gaussian", "gaussian", "empirical")
#' explain5 <- explain(
#' model = model,
#' x_explain = x_explain,
Expand Down
31 changes: 22 additions & 9 deletions R/setup.R
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ check_n_batches <- function(internal) {
n_combinations <- internal$parameters$n_combinations
is_groupwise <- internal$parameters$is_groupwise
n_groups <- internal$parameters$n_groups
n_unique_approaches <- internal$parameters$n_unique_approaches

if (!is_groupwise) {
actual_n_combinations <- ifelse(is.null(n_combinations), 2^n_features, n_combinations)
Expand All @@ -217,10 +218,17 @@ check_n_batches <- function(internal) {

if (n_batches >= actual_n_combinations) {
stop(paste0(
"`n_batches` (", n_batches, ") must be smaller than the number feature combinations/`n_combinations` (",
"`n_batches` (", n_batches, ") must be smaller than the number of feature combinations/`n_combinations` (",
actual_n_combinations, ")"
))
}

if (n_batches < n_unique_approaches) {
stop(paste0(
"`n_batches` (", n_batches, ") must be larger than the number of unique approaches in `approach` (",
n_unique_approaches, ")."
))
}
}


Expand Down Expand Up @@ -368,6 +376,10 @@ get_extra_parameters <- function(internal) {
internal$parameters$n_groups <- NULL
}

# Get the number of unique approaches
internal$parameters$n_approaches <- length(internal$parameters$approach)
internal$parameters$n_unique_approaches <- length(unique(internal$parameters$approach))

return(internal)
}

Expand Down Expand Up @@ -658,13 +670,14 @@ check_approach <- function(internal) {
supported_approaches <- get_supported_approaches()

if (!(is.character(approach) &&
(length(approach) == 1 || length(approach) == n_features) &&
(length(approach) == 1 || length(approach) == n_features - 1) &&
all(is.element(approach, supported_approaches)))
) {
stop(
paste(
"`approach` must be one of the following: \n", paste0(supported_approaches, collapse = ", "), "\n",
"or a vector of length equal to the number of features (", n_features, ") with only the above strings."
"or a vector of length one less than the number of features (", n_features - 1, "),",
"with only the above strings."
)
)
}
Expand All @@ -675,33 +688,33 @@ set_defaults <- function(internal) {
# Set defaults for certain arguments (based on other input)

approach <- internal$parameters$approach
n_unique_approaches <- internal$parameters$n_unique_approaches
used_n_combinations <- internal$parameters$used_n_combinations
n_batches <- internal$parameters$n_batches

# n_batches
if (is.null(n_batches)) {
internal$parameters$n_batches <- get_default_n_batches(approach, used_n_combinations)
internal$parameters$n_batches <- get_default_n_batches(approach, n_unique_approaches, used_n_combinations)
}

return(internal)
}

#' @keywords internal
get_default_n_batches <- function(approach, n_combinations) {
get_default_n_batches <- function(approach, n_unique_approaches, n_combinations) {
used_approach <- names(sort(table(approach), decreasing = TRUE))[1] # Most frequent used approach (when more present)

if (used_approach %in% c("ctree", "gaussian", "copula")) {
suggestion <- ceiling(n_combinations / 10)
this_min <- 10
this_max <- 1000
min_checked <- max(c(this_min, suggestion))
ret <- min(c(this_max, min_checked))
} else {
suggestion <- ceiling(n_combinations / 100)
this_min <- 2
this_max <- 100
min_checked <- max(c(this_min, suggestion))
ret <- min(c(this_max, min_checked))
}
min_checked <- max(c(this_min, suggestion, n_unique_approaches))
ret <- min(c(this_max, min_checked, n_combinations - 1))
message(
paste0(
"Setting parameter 'n_batches' to ", ret, " as a fair trade-off between memory consumption and ",
Expand Down
20 changes: 19 additions & 1 deletion R/setup_computation.R
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,7 @@ create_S_batch_new <- function(internal, seed = NULL) {

X <- internal$objects$X

if (!is.null(seed)) set.seed(seed)

if (length(approach0) > 1) {
X[!(n_features %in% c(0, n_features0)), approach := approach0[n_features]]
Expand All @@ -632,6 +633,24 @@ create_S_batch_new <- function(internal, seed = NULL) {
pmax(1, round(.N / (n_combinations - 2) * n_batches)),
n_S_per_approach = .N
), by = approach]

# Ensures that the number of batches corresponds to `n_batches`
if (sum(batch_count_dt$n_batches_per_approach) != n_batches) {
# Ensure that the number of batches is not larger than `n_batches`.
# Remove one batch from the approach with the most batches.
while (sum(batch_count_dt$n_batches_per_approach) > n_batches) {
batch_count_dt[which.max(n_batches_per_approach),
n_batches_per_approach := n_batches_per_approach - 1]
}

# Ensure that the number of batches is not lower than `n_batches`.
# Add one batch to the approach with most coalitions per batch
while (sum(batch_count_dt$n_batches_per_approach) < n_batches) {
batch_count_dt[which.max(n_S_per_approach / n_batches_per_approach),
n_batches_per_approach := n_batches_per_approach + 1]
}
}

batch_count_dt[, n_leftover_first_batch := n_S_per_approach %% n_batches_per_approach]
data.table::setorder(batch_count_dt, -n_leftover_first_batch)

Expand All @@ -640,7 +659,6 @@ create_S_batch_new <- function(internal, seed = NULL) {

# Randomize order before ordering spreading the batches on the different approaches as evenly as possible
# with respect to shapley_weight
set.seed(seed)
X[, randomorder := sample(.N)]
data.table::setorder(X, randomorder) # To avoid smaller id_combinations always proceeding large ones
data.table::setorder(X, shapley_weight)
Expand Down
139 changes: 139 additions & 0 deletions inst/scripts/devel/demonstrate_combined_approaches_bugs.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# Use the data objects from the helper-lm.R file.
# Here we want to illustrate three bugs related to combined approaches (before the bugfix)


# First we see that setting `n_batches` lower than the number of unique approaches
# produce some inconsistencies in shapr.
# After the bugfix, we force the user to choose a valid value for `n_batches`.
explanation_1 = explain(
model = model_lm_numeric,
x_explain = x_explain_numeric,
x_train = x_train_numeric,
approach = c("independence", "empirical", "gaussian", "copula", "empirical"),
prediction_zero = p0,
n_batches = 3,
timing = FALSE,
seed = 1)

# It says shapr is using 3 batches
explanation_1$internal$parameters$n_batches

# But shapr has actually used 4.
# This is because shapr can only handle one type of approach for each batch.
# Hence, the number of batches must be at least as large as the number of unique approaches.
# (excluding the last approach which is not used, as we then condition on all features)
length(explanation_1$internal$objects$S_batch)

# Note that after the bugfix, we give an error if `n_batches` < # unique approaches.





# Second we look at at another situation where # unique approaches is two and we set `n_batches` = 2,
# but shapr still use three batches. This is due to how shapr decides how many batches each approach
# should get. Right now it decided based on the proportion of the number of coalitions each approach
# is responsible. In this setting, independence is responsible for 5 coalitions and ctree for 25 coalitions,
# So, initially shapr sets that ctree should get the two batches while independence gets 0, but this
# is than changed to 1 without considering that it now breaks the consistency with the `n_batches`.
# This is done in the function `create_S_batch_new()` in setup_computation.R.
explanation_2 = explain(
model = model_lm_numeric,
x_explain = x_explain_numeric,
x_train = x_train_numeric,
approach = c("independence", "ctree", "ctree", "ctree" ,"ctree"),
prediction_zero = p0,
n_batches = 2,
timing = FALSE,
seed = 1)

# It says shapr is using 2 batches
explanation_2$internal$parameters$n_batches

# But shapr has actually used 3
length(explanation_2$internal$objects$S_batch)

# These are equal after the bugfix


# Same type of bug but in the opposite direction
explanation_3 = explain(
model = model_lm_numeric,
x_explain = x_explain_numeric,
x_train = x_train_numeric,
approach = c("independence", "ctree", "ctree", "ctree" ,"ctree"),
prediction_zero = p0,
n_batches = 15,
timing = FALSE,
seed = 1)

# It says shapr is using 15 batches
explanation_3$internal$parameters$n_batches

# It says shapr is using 14 batches
length(explanation_3$internal$objects$S_batch)

# These are equal after the bugfix






# Bug number three caused shapr to not to be reproducible as seting the seed did not work for combined approaches.
# This was due to a `set.seed(NULL)` which ruins all of the earlier set.seed procedures.


# Check that setting the seed works for a combination of approaches
# Here `n_batches` is set to `4`, so one batch for each method,
# i.e., no randomness.
# In the first example we get no bug as there is no randomness in assigning the batches.
explanation_combined_1 = explain(
model = model_lm_numeric,
x_explain = x_explain_numeric,
x_train = x_train_numeric,
approach = c("independence", "empirical", "gaussian", "copula", "empirical"),
prediction_zero = p0,
timing = FALSE,
seed = 1)

explanation_combined_2 = explain(
model = model_lm_numeric,
x_explain = x_explain_numeric,
x_train = x_train_numeric,
approach = c("independence", "empirical", "gaussian", "copula", "empirical"),
prediction_zero = p0,
timing = FALSE,
seed = 1)

# Check that they are equal
all.equal(explanation_combined_1, explanation_combined_2)


# Here `n_batches` is set to `10`, so NOT one batch for each method,
# i.e., randomness in assigning the batches.
explanation_combined_3 = explain(
model = model_lm_numeric,
x_explain = x_explain_numeric,
x_train = x_train_numeric,
approach = c("independence", "empirical", "gaussian", "copula", "ctree"),
prediction_zero = p0,
timing = FALSE,
seed = 1)

explanation_combined_4 = explain(
model = model_lm_numeric,
x_explain = x_explain_numeric,
x_train = x_train_numeric,
approach = c("independence", "empirical", "gaussian", "copula", "ctree"),
prediction_zero = p0,
timing = FALSE,
seed = 1)

# Check that they are not equal
all.equal(explanation_combined_3, explanation_combined_4)
explanation_combined_3$internal$objects$X
explanation_combined_4$internal$objects$X

# These are equal after the bugfix

54 changes: 54 additions & 0 deletions inst/scripts/devel/testing_for_valid_defualt_n_batches.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# In this code we demonstrate that (before the bugfix) the `explain()` function
# does not enter the exact mode when n_combinations is larger than or equal to 2^m.
# The mode is only changed if n_combinations is strictly larger than 2^m.
# This means that we end up with using all coalitions when n_combinations is 2^m,
# but use not the exact Shapley kernel weights.
# Bugfix replaces `>` with `=>`in the places where the code tests if
# n_combinations is larger than or equal to 2^m. Then the text/messages printed by
# shapr and the code correspond.

library(xgboost)
library(data.table)

data("airquality")
data <- data.table::as.data.table(airquality)
data <- data[complete.cases(data), ]

x_var <- c("Solar.R", "Wind", "Temp", "Month")
y_var <- "Ozone"

ind_x_explain <- 1:6
x_train <- data[-ind_x_explain, ..x_var]
y_train <- data[-ind_x_explain, get(y_var)]
x_explain <- data[ind_x_explain, ..x_var]

# Fitting a basic xgboost model to the training data
model <- xgboost::xgboost(
data = as.matrix(x_train),
label = y_train,
nround = 20,
verbose = FALSE
)

# Specifying the phi_0, i.e. the expected prediction without any features
p0 <- mean(y_train)

# Shapr sets the default number of batches to be 10 for this dataset for the
# "ctree", "gaussian", and "copula" approaches. Thus, setting `n_combinations`
# to any value lower of equal to 10 causes the error.
any_number_equal_or_below_10 = 8

# Before the bugfix, shapr:::check_n_batches() throws the error:
# Error in check_n_batches(internal) :
# `n_batches` (10) must be smaller than the number feature combinations/`n_combinations` (8)
# Bug only occures for "ctree", "gaussian", and "copula" as they are treated different in
# `get_default_n_batches()`, I am not certain why. Ask Martin about the logic behind that.
explanation <- explain(
model = model,
x_explain = x_explain,
x_train = x_train,
n_samples = 2, # Low value for fast computations
approach = "gaussian",
prediction_zero = p0,
n_combinations = any_number_equal_or_below_10
)
Loading

0 comments on commit 90afa9a

Please sign in to comment.