Skip to content

Commit

Permalink
Fix python version after recent updates (#412)
Browse files Browse the repository at this point in the history
  • Loading branch information
martinju authored Oct 22, 2024
1 parent bf0780d commit 03a4206
Show file tree
Hide file tree
Showing 13 changed files with 615 additions and 175 deletions.
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ S3method(setup_approach,vaeac)
export(additional_regression_setup)
export(aicc_full_single_cpp)
export(check_convergence)
export(cli_compute_vS)
export(cli_iter)
export(cli_startup)
export(coalition_matrix_cpp)
export(compute_estimates)
export(compute_shapley_new)
Expand Down
44 changes: 42 additions & 2 deletions R/cli.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
cli_startup <- function(internal, model, verbose) {
#' Printing startup messages with cli
#'
#' @param model_class String.
#' Class of the model as a string
#' @inheritParams default_doc_explain
#' @inheritParams explain
#'
#' @export
#' @keywords internal
cli_startup <- function(internal, model_class, verbose) {
init_time <- internal$timing_list$init_time

is_groupwise <- internal$parameters$is_groupwise
Expand All @@ -20,7 +29,7 @@ cli_startup <- function(internal, model, verbose) {
confounding <- internal$parameters$confounding


line_vec <- "Model class: {.cls {class(model)}}"
line_vec <- "Model class: {.cls {model_class}}"
line_vec <- c(line_vec, "Approach: {.emph {approach}}")
line_vec <- c(line_vec, "Iterative estimation: {.emph {iterative}}")
line_vec <- c(line_vec, "Number of {.emph {feat_group_txt}} Shapley values: {n_shapley_values}")
Expand Down Expand Up @@ -63,7 +72,38 @@ cli_startup <- function(internal, model, verbose) {
}
}

#' Printing messages in compute_vS with cli
#'
#' @inheritParams default_doc_explain
#' @inheritParams explain
#'
#' @export
#' @keywords internal
cli_compute_vS <- function(internal) {

verbose <- internal$parameters$verbose
approach <- internal$parameters$approach

if ("progress" %in% verbose) {
cli::cli_progress_step("Computing vS")
}
if ("vS_details" %in% verbose) {
if ("regression_separate" %in% approach) {
tuning <- internal$parameters$regression.tune
if (isTRUE(tuning)) {
cli::cli_h2("Extra info about the tuning of the regression model")
}
}
}
}

#' Printing messages in iterative procedure with cli
#'
#' @inheritParams default_doc_explain
#' @inheritParams explain
#'
#' @export
#' @keywords internal
cli_iter <- function(verbose, internal, iter) {
iterative <- internal$parameters$iterative
asymmetric <- internal$parameters$asymmetric
Expand Down
91 changes: 46 additions & 45 deletions R/compute_vS.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,8 @@ compute_vS <- function(internal, model, predict_model, method = "future") {

S_batch <- internal$iter_list[[iter]]$S_batch

verbose <- internal$parameters$verbose
approach <- internal$parameters$approach

# verbose
if ("progress" %in% verbose) {
cli::cli_progress_step("Computing vS")
}
if ("vS_details" %in% verbose) {
if ("regression_separate" %in% approach) {
tuning <- internal$parameters$regression.tune
if (isTRUE(tuning)) {
cli::cli_h2("Extra info about the tuning of the regression model")
}
}
}

cli_compute_vS(internal)

if (method == "future") {
vS_list <- future_compute_vS_batch(
Expand All @@ -53,36 +39,7 @@ compute_vS <- function(internal, model, predict_model, method = "future") {
}

#### Adds v_S output above to any vS_list already computed ####
### Need to map the old id_coalitions to the new numbers for this merging to work out
if (iter > 1) {
prev_coalition_map <- internal$iter_list[[iter - 1]]$coalition_map
prev_vS_list <- internal$iter_list[[iter - 1]]$vS_list

current_coalition_map <- internal$iter_list[[iter]]$coalition_map


# Creates a mapper from the last id_coalition to the new id_coalition numbering
id_coalitions_mapper <- merge(prev_coalition_map,
current_coalition_map,
by = "coalitions_str",
suffixes = c("", "_new")
)
prev_vS_list_new <- list()

# Applies the mapper to update the prev_vS_list ot the new id_coalition numbering
for (k in seq_along(prev_vS_list)) {
prev_vS_list_new[[k]] <- merge(prev_vS_list[[k]],
id_coalitions_mapper[, .(id_coalition, id_coalition_new)],
by = "id_coalition"
)
prev_vS_list_new[[k]][, id_coalition := id_coalition_new]
prev_vS_list_new[[k]][, id_coalition_new := NULL]
}

# Merge the new vS_list with the old vS_list
vS_list <- c(prev_vS_list_new, vS_list)
}

vS_list <- append_vS_list(vS_list, internal)


return(vS_list)
Expand Down Expand Up @@ -264,3 +221,47 @@ compute_MCint <- function(dt, pred_cols = "p_hat") {

return(dt_mat)
}

#' Appends the new vS_list to the prev vS_list
#'
#'
#' @inheritParams compute_estimates
#'
#' @export
#' @keywords internal
append_vS_list <- function(vS_list, internal) {

iter <- length(internal$iter_list)

# Adds v_S output above to any vS_list already computed
if (iter > 1) {
prev_coalition_map <- internal$iter_list[[iter - 1]]$coalition_map
prev_vS_list <- internal$iter_list[[iter - 1]]$vS_list

# Need to map the old id_coalitions to the new numbers for this merging to work out
current_coalition_map <- internal$iter_list[[iter]]$coalition_map

# Creates a mapper from the last id_coalition to the new id_coalition numbering
id_coalitions_mapper <- merge(prev_coalition_map,
current_coalition_map,
by = "coalitions_str",
suffixes = c("", "_new")
)
prev_vS_list_new <- list()

# Applies the mapper to update the prev_vS_list ot the new id_coalition numbering
for (k in seq_along(prev_vS_list)) {
prev_vS_list_new[[k]] <- merge(prev_vS_list[[k]],
id_coalitions_mapper[, .(id_coalition, id_coalition_new)],
by = "id_coalition"
)
prev_vS_list_new[[k]][, id_coalition := id_coalition_new]
prev_vS_list_new[[k]][, id_coalition_new := NULL]
}

# Merge the new vS_list with the old vS_list
vS_list <- c(prev_vS_list_new, vS_list)
}
return(vS_list)

}
3 changes: 3 additions & 0 deletions R/documentation.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ default_doc <- function(internal, model, predict_model, output_size, extra, ...)

#' Exported documentation helper function.
#'
#' @param iter Integer.
#' The iteration number. Only used internally.
#'
#' @param internal List.
#' Not used directly, but passed through from [explain()].
#'
Expand Down
2 changes: 1 addition & 1 deletion R/explain.R
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ explain <- function(model,
set.seed(seed)
}

cli_startup(internal, model, verbose)
cli_startup(internal, class(model), verbose)


while (converged == FALSE) {
Expand Down
2 changes: 1 addition & 1 deletion R/explain_forecast.R
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ explain_forecast <- function(model,
set.seed(seed)
}

cli_startup(internal, model, verbose)
cli_startup(internal, class(model), verbose)

while (converged == FALSE) {
cli_iter(verbose, internal, iter)
Expand Down
2 changes: 1 addition & 1 deletion R/plot.R
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ plot.shapr <- function(x,
desc_mat[, i] <- paste0(shap_names[i], " = ", desc_mat[, i])
}
} else {
desc_mat <- trimws(format(x$shapley_values_est[, -c("explain_id", "none")], digits = digits))
desc_mat <- trimws(format(x$shapley_values_est[, -c("none")], digits = digits))
for (i in seq_len(ncol(desc_mat))) {
desc_mat[, i] <- paste0(shap_names[i])
}
Expand Down
16 changes: 16 additions & 0 deletions man/cli_compute_vS.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

33 changes: 33 additions & 0 deletions man/cli_iter.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

36 changes: 36 additions & 0 deletions man/cli_startup.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 03a4206

Please sign in to comment.