Skip to content

Commit

Permalink
Lars/add MSEv eval criterion (#357)
Browse files Browse the repository at this point in the history
* Bugfix in `prepare_data()` related to vector of approaches. When using several approaches the old version only used the first approach. Verified this by adding a print in each prepare_data.approach() function and saw that only the first approach in internal$parameters$approach was used.
Can maybe remove code comments before pull request is accepted. Maybe a better method to get the approach?

Also updated roxygen2 for the function, as it seemed that it was reflecting the old version of shapr(?) due to arguments which are no longer present.

However, one then get a warning when creating the roxygen2 documentation. Discuss some solutions as comments below. Discuss with Martin.

* Implemented function that computes the MSEv evaluation criterion for all approaches as long as `internal$parameters$output_size == 1`.

Need to think about if it is applicable for vector of outputs.

* Plot the MSEv evaluation criterion.

This function is a separate plot function and is not part of the `shapr.plot()` function. It would maybe nice to make it a part of it and using, e.g., `plot_type = "MSEv". However, `make_MSEv_evaluation_criterion_plots()` handles list of explanation objects while `shapr.plot()` is restricted to a single shapr explanation object. Thus, one would need to rewrite the `shapr.plot()` to also handle mulitple objects.

* Added a section about the MSEv criterion in the vignette. This is s adraft and might need some polishing.

* # Lars have added `n_combinations` - 1 as a possibility, as the function `check_n_batches` threw an error for the vignette with gaussian approach with `n_combinations` = 8 and `n_batches = NULL`, as this function here then set `n_batches` = 10, which was too large. We subtract 1 as `check_n_batches` function specifies that `n_batches` must be strictly less than `n_combinations`.

* Samll typo.

* Fixed bug. All messages says "n_combinations is larger than or equal to 2^m", but all the test only tested for "larger than". I.e., if the user specified n_combinations = 2^m in the call to shapr::explain, the function would not treat it as exact.

* Changed to exclude the coalitions which are not effected by the approach by default.
Futhermore, there was a logical error.

* Added example to the roxygen. This is technically not needed as it is an internal function.
Could make a test out of the examples too.

* Added some parameters to the plot function to make it more flexiable and to make it easier for the user to controll what figures that are made

* Updated code as I had only changed the legend for the fill (which works with bars), but not the col (which works for lines/points).

* Added the updated namespace file. Maybe that is why the tests on github did not work beforehand.

* Added script demonstrating the bug that shapr does not enter the exact mode when `n_combinations = 2^m`, before the bugfix.

* Added (tentative) test that checks that shapr enters exact mode when `n_combinations >= 2^m`. Remove the large comment after discussing that with Martin.

* Added script that demonstrates the bug before the bugfix, and added test checking that we do not get an error when runing the code after the bugfix has been applied.

* Fixed lint warnings in `approach.R`.

* Added `ctree` in the example in roxygen

* Add manuals for the two funtions realted to MSEv criterion.

* Updated plot function (fixed some inconsistencies) and added detailed examples. Plan to make tests out of them.

* Started to make tests, but ran into a shapr bug.

* Added two parameters to the `internal$parameters` list which contains the number of approaches and the number of unique approaches.

This is for example useful to check that the provided `n_batches` is a valid value. (see next commits)

* Added test to check that `n_batches` must be larger than or equal to the number of unique approaches. Before the user could, e.g., set `n_batches = 2`, but use 4 approaches and then shapr would use 4 but not update `n_batches` and without giwing a warning to the user.

* Updated `get_default_n_batches` to take into consideration the number of unique approaches that is used. This was not done before and gave inconsistency in what number shapr would reccomend and use when `n_batches` was set to `null` by the user.

* Changed where seed is set such that it applies for both regular and combined approaches.
Furthermore, added if test, because previous version resulted in not reproducible code, as setting seed to `null` ruins that we set seed in `explain()`.

Just consider this small example:
# Set seed to get same values twice
set.seed(123)
rnorm(1)

# Seting the same seed gives the same value
set.seed(123)
rnorm(1)

# If we also include null then the seed is removed and we do not get the same value
set.seed(123)
set.seed(NULL)
rnorm(1)

# Setining seed to null actually gives a new "random" number each time.
set.seed(123)
set.seed(NULL)
rnorm(1)

* Typo

* Added test to check that setting the seed works for combined approaches.

* typo in test function

* Added file to demonstrate the bugs (before the bugfix)

* Added new test

* Updated tests by removing n_samples

* Added a bugfix to shapr not using the correct number of batches. Maybe not the most elegant solution.

* Updated the demonstration script

* Added last test and fixed lintr

* Lint again.

* styler

* minor edits to tests

* simplifies comment

* comb files ok

* Updated bug in independence approach related to categorical features which caused shapr to crash later. Added comments when I debuged to understand what was going on. I have added some comments about some stuff I did no understand/agree with. Discuss with Martin and correct this before merge.

* Updated bug in independence approach related to categorical features which caused shapr to crash later. Added comments when I debuged to understand what was going on. I have added some comments about some stuff I did no understand/agree with. Discuss with Martin and correct this before merge.

* lint warning

* Lint

* lint

* Updated roxygen

* Added plot test functions for MSEv criterion

* Ran styler

* Updated most lintr. I disagree that variable names should be less than 30 characters. Discuss with Martin. And it was Stylr that made the changes that caused the brace_linter warnings.

* Previous version would not test the output but rather that check that shapr would stop as `n_batches` was less than the length of `approch`.

* Added some extra parameters to the one test function.

* Updated some parameters in the MSEv test plots. Looked at all of them and they look to do what I want.

* SupressMessages to run `testthat::test_file` without messages, and changed tha `bar_plot_MSEv` such that we do not get printed `NULL` to the console when running the test.

* Add test images (get the same each time I run `testthat::test_file()`).

* Fixed lintr and styler

* Updated vignette with MSEv. Fixed lintr and ensured that the figures look like what they are supposed to.

* updated test files after accepting new values

* adjustments to comments and Lars' TODO-comments

* update snapshot file after weight adjustment

* cleaned up doc

* rerun doc

* style

* Changed to `n_batches = 10` in the combined approaches, as the previous value (`n_batches = 1`) is not allowed anymore as it is lower than the number of unique used approaches.

* Updated some messages in plot

* Minor updates to `make_MSEv_evaluation_criterion_plots`

* Update the manuals

* Updated the MSEv text in the vignette.

* accept OK test changes

* additonal Ok test files

* change batches in test files

* accept new files

* handle issue with a breaking change update in the testthat package

* + these

* removing last (unused) input of approach

* updating tests

* + update setup tests/snaps

* correcting unique length

* update linting and vignette

* update docs

* fix example issue

* temporary disable tests on older R systems

* remove unecessary if-else test

* data.table style on Lars's batch adjustment suggestion

* del comment

* lint

* snaps + test adjustment

* update plotting snaps to make tests pass

* update vignette

* .

* Renamed `compute_MSEv_evaluation_criterion()` to `compute_MSEv_eval_crit()`.

* Renamed `MSEv_evaluation_criterion_for_each_coalition` to `MSEv_eval_crit_each_comb`.

* Shorted function/variable names to less than 30 char.

* Renamed `MSEv_eval_crit_comb` to `MSEv_eval_crit_combination` for consistency.

* Added option for using Shapley weights in `compute_MSEv_eval_crit()`

* update snaps after name change

* update function names in tests

* new figure snaps

* update function names in vignette

* man + example bug

* Fixed places where variable names where not updated

* Manual updates

* Added `library(shapr)` in examples.

* Updated combined approach in `make_MSEv_eval_crit_plots` example to support new version.

* Export `compute_MSEv_eval_crit` to make `devtools::run_examples()` run.

* Added support for MSEv criterion in `explain()`.

* Updated MSEv criterion based on feedback.
Now internal.

* Added checks for MSEv criterion

* Refactored the plot function for MSEv based on feedback

* Updated vignette to reflect changes in MSEv criterion code.

* Fixed MSEv_combination. Now taking mean of it gives MSEv and we see the weighted importance of the combination on the overall precission when plotting.

* Removed default palette and refactored the code.

* Updated the vignette to reflect the changes in MSEv criterion.

* Set default MSEv values in `setup()` as `explain_forecast()` do not specify them.

* Added tests for `MSEv_skip_empty_full_comb` and `MSEv_uniform_comb_weights`. Matched the setup for other logical arguments.

* Added test for when MSEv criterion use the Shapley kernel weights

* Fixed bug in MSEv_plot function when specifying geom_col_width

* Fixed tests for MSEv plot function

* Added MSEv references

* Added Lars as Author

* Removed `MSEv_skip_empty_full_comb` form `explain` and all other functions/tests.

* Fixed inncorrect standard deviations of the mean. Now divides by sqrt(n_explain). And better documentation.

* MSEv_plots: introduced approximate CI, reduced example, improved documentation

* Updated MSEv_plot documentation

* Added example file for how to change the MSEv plots

* Inserted `level = 0.95` the places it where missing.

* Added CI explanation in vignette.

* Fix linting and stylr.

* Updated manuals (done autmatically by `devtools::run_examples()`)

* Added snaps plot svg files

* æ ø

* update test files

* del unused snaps

* typo

* man

* styler

* bugfix name in testfile

* fixing check warning of missing ... description

* fix check note on NSE

* Renamed `make_MSEv_eval_crit_plots` into `plot_MSEv_eval_crit`

* Changed default confidence level to check for the number of explicands

* Renamed the parameter `level` into `CI_level` in the `plot_MSEv_eval_crit` function.

* Updated plot_MSEv_eval_crit.Rd and namespace

* Introduced `plot_type` parameter in `plot_MSEv_eval_crit()` and updated title

* styler

* Updated manual

* Updated from ifelse to if else to support NULL return

* swap * + space with ~ in bquote to fix spacing in svg files

* update svg files

* man

* bugfix vignette + remove unecessary plot arguments

* fix NSE check issues

---------

Co-authored-by: Martin <[email protected]>
  • Loading branch information
LHBO and martinju authored Dec 11, 2023
1 parent 90afa9a commit 579724b
Show file tree
Hide file tree
Showing 69 changed files with 3,863 additions and 49 deletions.
4 changes: 2 additions & 2 deletions .Rprofile
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#' @param ... Additional arguments passed to [waldo::compare()]
#' Gives the relative path to the test files to review
#'
snapshot_review_man <- function(path, ...) {
snapshot_review_man <- function(path, tolerance = NULL, ...) {
changed <- testthat:::snapshot_meta(path)
these_rds <- (tools::file_ext(changed$name) == "rds")
if (any(these_rds)) {
Expand All @@ -16,7 +16,7 @@ snapshot_review_man <- function(path, ...) {
new <- readRDS(changed[i, "new"])

cat(paste0("Difference for check ", changed[i, "name"], " in test ", changed[i, "test"], "\n"))
print(waldo::compare(old, new, max_diffs = 50, ...))
print(waldo::compare(old, new, max_diffs = 50, tolerance = tolerance, ...))
browser()
}
}
Expand Down
4 changes: 4 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ export(get_supported_approaches)
export(hat_matrix_cpp)
export(mahalanobis_distance_cpp)
export(observation_impute_cpp)
export(plot_MSEv_eval_crit)
export(predict_model)
export(prepare_data)
export(rss_cpp)
Expand Down Expand Up @@ -93,6 +94,9 @@ importFrom(stats,formula)
importFrom(stats,model.frame)
importFrom(stats,model.matrix)
importFrom(stats,predict)
importFrom(stats,pt)
importFrom(stats,qt)
importFrom(stats,sd)
importFrom(stats,setNames)
importFrom(utils,head)
importFrom(utils,methods)
Expand Down
11 changes: 10 additions & 1 deletion R/explain.R
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,14 @@
#' disabled for unsupported model classes.
#' Can also be used to override the default function for natively supported model classes.
#'
#' @param MSEv_uniform_comb_weights Logical. If `TRUE` (default), then the function weights the combinations
#' uniformly when computing the MSEv criterion. If `FALSE`, then the function use the Shapley kernel weights to
#' weight the combinations when computing the MSEv criterion. Note that the Shapley kernel weights are replaced by the
#' sampling frequency when not all combinations are considered.
#'
#' @param timing Logical.
#' Whether the timing of the different parts of the `explain()` should saved in the model object.
#' @param ... Further arguments passed to specific approaches
#'
#' @inheritDotParams setup_approach.empirical
#' @inheritDotParams setup_approach.independence
Expand Down Expand Up @@ -117,7 +123,8 @@
#' \describe{
#' \item{shapley_values}{data.table with the estimated Shapley values}
#' \item{internal}{List with the different parameters, data and functions used internally}
#' \item{pred_explain}{Numeric vector with the predictions for the explained observations.}
#' \item{pred_explain}{Numeric vector with the predictions for the explained observations}
#' \item{MSEv}{List with the values of the MSEv evaluation criterion for the approach.}
#' }
#'
#' `shapley_values` is a data.table where the number of rows equals
Expand Down Expand Up @@ -257,6 +264,7 @@ explain <- function(model,
keep_samp_for_vS = FALSE,
predict_model = NULL,
get_model_specs = NULL,
MSEv_uniform_comb_weights = TRUE,
timing = TRUE,
...) { # ... is further arguments passed to specific approaches

Expand Down Expand Up @@ -285,6 +293,7 @@ explain <- function(model,
seed = seed,
keep_samp_for_vS = keep_samp_for_vS,
feature_specs = feature_specs,
MSEv_uniform_comb_weights = MSEv_uniform_comb_weights,
timing = timing,
...
)
Expand Down
129 changes: 125 additions & 4 deletions R/finalize_explanation.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#' @export
finalize_explanation <- function(vS_list, internal) {
keep_samp_for_vS <- internal$parameters$keep_samp_for_vS
MSEv_uniform_comb_weights <- internal$parameters$MSEv_uniform_comb_weights

processed_vS_list <- postprocess_vS_list(
vS_list = vS_list,
Expand All @@ -24,20 +25,28 @@ finalize_explanation <- function(vS_list, internal) {

# internal$timing$shapley_computation <- Sys.time()


# Clearnig out the tmp list with model and predict_model (only added for AICc-types of empirical approach)
# Clearing out the tmp list with model and predict_model (only added for AICc-types of empirical approach)
internal$tmp <- NULL

internal$output <- processed_vS_list


output <- list(
shapley_values = dt_shapley,
internal = internal,
pred_explain = p
)
attr(output, "class") <- c("shapr", "list")

# Compute the MSEv evaluation criterion if the output of the predictive model is a scalar.
# TODO: check if it makes sense for output_size > 1.
if (internal$parameters$output_size == 1) {
output$MSEv <- compute_MSEv_eval_crit(
internal = internal,
dt_vS = processed_vS_list$dt_vS,
MSEv_uniform_comb_weights = MSEv_uniform_comb_weights
)
}

return(output)
}

Expand Down Expand Up @@ -104,7 +113,7 @@ get_p <- function(dt_vS, internal) {
#' Compute shapley values
#' @param explainer An `explain` object.
#' @param dt_vS The contribution matrix.
#' @return A `data.table` with shapley values for each test observation.
#' @return A `data.table` with Shapley values for each test observation.
#' @export
#' @keywords internal
compute_shapley_new <- function(internal, dt_vS) {
Expand Down Expand Up @@ -153,3 +162,115 @@ compute_shapley_new <- function(internal, dt_vS) {

return(dt_kshap)
}

#' Mean Squared Error of the Contribution Function `v(S)`
#'
#' @inheritParams explain
#' @inheritParams default_doc
#' @param dt_vS Data.table of dimension `n_combinations` times `n_explain + 1` containing the contribution function
#' estimates. The first column is assumed to be named `id_combination` and containing the ids of the combinations.
#' The last row is assumed to be the full combination, i.e., it contains the predicted responses for the observations
#' which are to be explained.
#' @param MSEv_skip_empty_full_comb Logical. If `TRUE` (default), we exclude the empty and grand
#' combinations/coalitions when computing the MSEv evaluation criterion. This is reasonable as they are identical
#' for all methods, i.e., their contribution function is independent of the used method as they are special cases not
#' effected by the used method. If `FALSE`, we include the empty and grand combinations/coalitions. In this situation,
#' we also recommend setting `MSEv_uniform_comb_weights = TRUE`, as otherwise the large weights for the empty and
#' grand combinations/coalitions will outweigh all other combinations and make the MSEv criterion uninformative.
#'
#' @return
#' List containing:
#' \describe{
#' \item{`MSEv`}{A \code{\link[data.table]{data.table}} with the overall MSEv evaluation criterion averaged
#' over both the combinations/coalitions and observations/explicands. The \code{\link[data.table]{data.table}}
#' also contains the standard deviation of the MSEv values for each explicand (only averaged over the combinations)
#' divided by the square root of the number of explicands.}
#' \item{`MSEv_explicand`}{A \code{\link[data.table]{data.table}} with the mean squared error for each
#' explicand, i.e., only averaged over the combinations/coalitions.}
#' \item{`MSEv_combination`}{A \code{\link[data.table]{data.table}} with the mean squared error for each
#' combination/coalition, i.e., only averaged over the explicands/observations.
#' The \code{\link[data.table]{data.table}} also contains the standard deviation of the MSEv values for
#' each combination divided by the square root of the number of explicands.}
#' }
#'
#' @description Function that computes the Mean Squared Error (MSEv) of the contribution function
#' v(s) as proposed by \href{https://arxiv.org/pdf/2006.01272.pdf}{Frye et al. (2019)} and used by
#' \href{https://www.jmlr.org/papers/volume23/21-1413/21-1413.pdf}{Olsen et al. (2022)}.
#'
#' @details
#' The MSEv evaluation criterion does not rely on access to the true contribution functions nor the
#' true Shapley values to be computed. A lower value indicates better approximations, however, the
#' scale and magnitude of the MSEv criterion is not directly interpretable in regard to the precision
#' of the final estimated Shapley values. \href{https://arxiv.org/pdf/2305.09536.pdf}{Olsen et al. (2022)}
#' illustrates in Figure 11 a fairly strong linear relationship between the MSEv criterion and the
#' MAE between the estimated and true Shapley values in a simulation study. Note that explicands
#' refer to the observations whose predictions we are to explain.
#'
#' @keywords internal
#' @author Lars Henry Berge Olsen
compute_MSEv_eval_crit <- function(internal,
dt_vS,
MSEv_uniform_comb_weights,
MSEv_skip_empty_full_comb = TRUE) {
n_explain <- internal$parameters$n_explain
n_combinations <- internal$parameters$n_combinations
id_combination_indices <- if (MSEv_skip_empty_full_comb) seq(2, n_combinations - 1) else seq(1, n_combinations)
n_combinations_used <- length(id_combination_indices)
features <- internal$objects$X$features[id_combination_indices]

# Extract the predicted responses f(x)
p <- unlist(dt_vS[id_combination == n_combinations, -"id_combination"])

# Create contribution matrix
vS <- as.matrix(dt_vS[id_combination_indices, -"id_combination"])

# Square the difference between the v(S) and f(x)
dt_squared_diff_original <- sweep(vS, 2, p)^2

# Get the weights
averaging_weights <- if (MSEv_uniform_comb_weights) rep(1, n_combinations) else internal$objects$X$shapley_weight
averaging_weights <- averaging_weights[id_combination_indices]
averaging_weights_scaled <- averaging_weights / sum(averaging_weights)

# Apply the `averaging_weights_scaled` to each column (i.e., each explicand)
dt_squared_diff <- dt_squared_diff_original * averaging_weights_scaled

# Compute the mean squared error for each observation, i.e., only averaged over the coalitions.
# We take the sum as the weights sum to 1, so denominator is 1.
MSEv_explicand <- colSums(dt_squared_diff)

# The MSEv criterion for each coalition, i.e., only averaged over the explicands.
MSEv_combination <- rowMeans(dt_squared_diff * n_combinations_used)
MSEv_combination_sd <- apply(dt_squared_diff * n_combinations_used, 1, sd) / sqrt(n_explain)

# The MSEv criterion averaged over both the coalitions and explicands.
MSEv <- mean(MSEv_explicand)
MSEv_sd <- sd(MSEv_explicand) / sqrt(n_explain)

# Set the name entries in the arrays
names(MSEv_explicand) <- paste0("id_", seq(n_explain))
names(MSEv_combination) <- paste0("id_combination_", id_combination_indices)
names(MSEv_combination_sd) <- paste0("id_combination_", id_combination_indices)

# Convert the results to data.table
MSEv <- data.table(
"MSEv" = MSEv,
"MSEv_sd" = MSEv_sd
)
MSEv_explicand <- data.table(
"id" = seq(n_explain),
"MSEv" = MSEv_explicand
)
MSEv_combination <- data.table(
"id_combination" = id_combination_indices,
"features" = features,
"MSEv" = MSEv_combination,
"MSEv_sd" = MSEv_combination_sd
)

return(list(
MSEv = MSEv,
MSEv_explicand = MSEv_explicand,
MSEv_combination = MSEv_combination
))
}
Loading

0 comments on commit 579724b

Please sign in to comment.