Skip to content

Commit

Permalink
Fixed bug with "keep_samp_for_vS"… (#349)
Browse files Browse the repository at this point in the history
* Fixed bug with "keep_samp_for_vS" in the internal function postprocess_vS_list() which occured if keep_samp_for_vS = TRUE in the explain() function. Previously, the code did not take into consideration that vS_list became a list of lists when keep_samp_for_vS = TRUE and therfore extracted incorrect names.

* add test

---------

Co-authored-by: Martin <[email protected]>
  • Loading branch information
LHBO and martinju authored Jun 28, 2023
1 parent f94b2e7 commit 111053f
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 1 deletion.
5 changes: 4 additions & 1 deletion R/finalize_explanation.R
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ postprocess_vS_list <- function(vS_list, internal) {

# Appending the zero-prediction to the list
dt_vS0 <- as.data.table(rbind(c(1, rep(prediction_zero, n_explain))))
names(dt_vS0) <- names(vS_list[[1]])

# Extracting/merging the data tables from the batch running
# TODO: Need a memory and speed optimized way to transform the output form dt_vS_list to two different lists,
Expand All @@ -61,13 +60,17 @@ postprocess_vS_list <- function(vS_list, internal) {
# then there is only one copy, but there are two if keep_samp_for_vS=TRUE. This might be OK since the
# latter is used rarely
if (keep_samp_for_vS) {
names(dt_vS0) <- names(vS_list[[1]][[1]])

vS_list[[length(vS_list) + 1]] <- list(dt_vS0, NULL)

dt_vS <- rbindlist(lapply(vS_list, `[[`, 1))

dt_samp_for_vS <- rbindlist(lapply(vS_list, `[[`, 2))
data.table::setorder(dt_samp_for_vS, id_combination)
} else {
names(dt_vS0) <- names(vS_list[[1]])

vS_list[[length(vS_list) + 1]] <- dt_vS0

dt_vS <- rbindlist(vS_list)
Expand Down
10 changes: 10 additions & 0 deletions tests/testthat/_snaps/output.md
Original file line number Diff line number Diff line change
Expand Up @@ -272,3 +272,13 @@
2: 42.44 2.758 -3.325 -7.992 -7.12800 1.808
3: 42.44 6.805 -22.126 3.730 -0.09235 -5.885

# output_lm_numeric_independence_keep_samp_for_vS

Code
(out <- code)
Output
none Solar.R Wind Temp Month Day
1: 42.44 -4.537 8.269 17.517 -5.581 -3.066
2: 42.44 2.250 -3.345 -5.232 -5.581 -1.971
3: 42.44 3.708 -18.610 -1.440 -2.541 1.316

Binary file not shown.
21 changes: 21 additions & 0 deletions tests/testthat/test-output.R
Original file line number Diff line number Diff line change
Expand Up @@ -496,3 +496,24 @@ test_that("output_lm_numeric_empirical_progress", {
"output_lm_numeric_empirical_progress"
)
})


# Just checking that internal$output$dt_samp_for_vS keep_samp_for_vS
test_that("output_lm_numeric_independence_keep_samp_for_vS", {
expect_snapshot_rds(
(out <- explain(
model = model_lm_numeric,
x_explain = x_explain_numeric,
x_train = x_train_numeric,
approach = "independence",
prediction_zero = p0,
n_batches = 1,
timing = FALSE,
keep_samp_for_vS = T
)),
"output_lm_numeric_independence_keep_samp_for_vS"
)

expect_false(is.null(out$internal$output$dt_samp_for_vS))
})

0 comments on commit 111053f

Please sign in to comment.