Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restructured torch modules to support shapr installation without torch #393

Merged
merged 12 commits into from
Apr 18, 2024
22 changes: 16 additions & 6 deletions R/approach_vaeac.R
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,11 @@ vaeac_train_model <- function(x_train,
best_vlb <- -Inf

# Create a `progressr::progressor()` to keep track of the overall training time of the vaeac approach
progressr_bar <- progressr::progressor(steps = epochs_initiation_phase * (n_vaeacs_initialize - 1) + epochs)
if (requireNamespace("progressr", quietly = TRUE)) {
progressr_bar <- progressr::progressor(steps = epochs_initiation_phase * (n_vaeacs_initialize - 1) + epochs)
} else {
progressr_bar <- NULL
}

# Iterate over the initializations.
initialization_idx <- 1
Expand Down Expand Up @@ -835,9 +839,10 @@ vaeac_train_model_continue <- function(explanation,
# Set seed for reproducibility
set.seed(seed)

# Extract the vaeac list and load the model at the last epoch
# Extract the vaeac list and load the model at the last epoch or the best (default 'best' when path is provided)
vaeac_model <- explanation$internal$parameters$vaeac
checkpoint <- torch::torch_load(vaeac_model$models$last)
vaeac_model_path <- if (!is.null(vaeac_model$models$last)) vaeac_model$models$last else vaeac_model$models$best
checkpoint <- torch::torch_load(vaeac_model_path)

# Get which device we are to continue to train the model
device <- ifelse(checkpoint$cuda, "cuda", "cpu")
Expand Down Expand Up @@ -939,7 +944,11 @@ vaeac_train_model_continue <- function(explanation,
state_list$epochs <- epochs

# Create a `progressr::progressor()` to keep track of the new training
progressr_bar <- progressr::progressor(steps = epochs_new)
if (requireNamespace("progressr", quietly = TRUE)) {
progressr_bar <- progressr::progressor(steps = epochs_new)
} else {
progressr_bar <- NULL
}

# Train the vaeac model for `epochs_new` number of epochs
vaeac_tmp <- vaeac_train_model_auxiliary(
Expand Down Expand Up @@ -1617,8 +1626,9 @@ vaeac_check_parameters <- function(x_train,
#' then a name will be generated based on [base::Sys.time()] to ensure a unique name. We use [base::make.names()] to
#' ensure a valid file name for all operating systems.
#' @param vaeac.folder_to_save_model String (default is [base::tempdir()]). String specifying a path to a folder where
#' the function is to save the fitted vaeac model. Note that the path will be removed from the returned
#' [shapr::explain()] object if `vaeac.save_model = FALSE`.
#' the function is to save the fitted vaeac model. Note that the path will be removed from the returned
#' [shapr::explain()] object if `vaeac.save_model = FALSE`. Furthermore, the model cannot be moved from its
#' original folder if we are to use the [shapr::vaeac_train_model_continue()] function to continue training the model.
#' @param vaeac.pretrained_vaeac_model List or String (default is `NULL`). 1) Either a list of class
#' `vaeac`, i.e., the list stored in `explanation$internal$parameters$vaeac` where `explanation` is the returned list
#' from an earlier call to the [shapr::explain()] function. 2) A string containing the path to where the `vaeac`
Expand Down
Loading
Loading