From 6e3dba37df70e1311693fc19d587f32836edf830 Mon Sep 17 00:00:00 2001 From: Lars van der Laan Date: Wed, 17 Mar 2021 16:00:51 -0700 Subject: [PATCH 01/10] Fix pooled_hazard_task bug Pooled hazards task does not work as intended (or errors) when using tasks that have non-null row_index internal variables. This occurs because when passing in a new dataset to next_in_chain, it does not reset the internal row_index variable. This makes using CV learners that utilize pooled hazards tasks internally (e.g. Lrnr_cv$new(Lrnr_pooled_hazards$new())) break down. For tmle3 survival, the pooled hazards task was created externally as the main task, so the hazard estimation with CV was not impacted by this bug. It might also be worth changing sl3_Task so that row_index is reset when passing in a new dataset. Or, maybe we should just not allow datasets to be passed in through next_in_chain. --- R/survival_utils.R | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/R/survival_utils.R b/R/survival_utils.R index dd11429c..e7af4e30 100644 --- a/R/survival_utils.R +++ b/R/survival_utils.R @@ -32,11 +32,17 @@ pooled_hazard_task <- function(task, trim = TRUE) { repeated_data <- underlying_data[index, ] new_folds <- origami::id_folds_to_folds(task$folds, index) - repeated_task <- task$next_in_chain( - column_names = column_names, - data = repeated_data, id = "id", - folds = new_folds - ) + rnodes <- task$nodes + nodes$id <- "id" + repeated_task <- sl3_Task$new(repeated_data, column_names = column_names, nodes = task$nodes, folds = new_folds, outcome_levels = outcome_levels, outcome_type = task$outcome_type) + # If "task" has a non-null row_index then this will fail. + # The next_in_chain function does not reset the row_index if data is passed in. 
+ # So CV learners and pooled hazards don't work + # repeated_task <- task$next_in_chain( + # column_names = column_names, + # data = repeated_data, id = "id", + # folds = new_folds, row_index = NULL + # ) # make bin indicators bin_number <- rep(level_index, each = task$nrow) From bd7e62628914b008e0386b2523461012a3ed6a41 Mon Sep 17 00:00:00 2001 From: Lars van der Laan Date: Wed, 17 Mar 2021 16:33:43 -0700 Subject: [PATCH 02/10] Update Lrnr_hal9001.R Cheap fix for Lrnr_hal to work with latest version of hal9001/devel --- R/Lrnr_hal9001.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/Lrnr_hal9001.R b/R/Lrnr_hal9001.R index 6bbc31f6..cfcf58f5 100644 --- a/R/Lrnr_hal9001.R +++ b/R/Lrnr_hal9001.R @@ -121,7 +121,7 @@ Lrnr_hal9001 <- R6Class( args$id <- task$id } - fit_object <- call_with_args(hal9001::fit_hal, args) + fit_object <- call_with_args(hal9001::fit_hal, args, keep_all = TRUE) return(fit_object) }, .predict = function(task = NULL) { From 7c0054f365b2cbfbacbc769549adf5930d2b7927 Mon Sep 17 00:00:00 2001 From: Lars van der Laan Date: Wed, 17 Mar 2021 16:54:59 -0700 Subject: [PATCH 03/10] Update survival_utils.R Fixed type --- R/survival_utils.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/survival_utils.R b/R/survival_utils.R index e7af4e30..65fb0dec 100644 --- a/R/survival_utils.R +++ b/R/survival_utils.R @@ -32,9 +32,9 @@ pooled_hazard_task <- function(task, trim = TRUE) { repeated_data <- underlying_data[index, ] new_folds <- origami::id_folds_to_folds(task$folds, index) - rnodes <- task$nodes + nodes <- task$nodes nodes$id <- "id" - repeated_task <- sl3_Task$new(repeated_data, column_names = column_names, nodes = task$nodes, folds = new_folds, outcome_levels = outcome_levels, outcome_type = task$outcome_type) + repeated_task <- sl3_Task$new(repeated_data, column_names = column_names, nodes = task$nodes, folds = new_folds, outcome_levels = outcome_levels, outcome_type = task$outcome_type$type) # If "task" has a 
non-null row_index then this will fail. # The next_in_chain function does not reset the row_index if data is passed in. # So CV learners and pooled hazards don't work From eb9e2ee90dbfc381060a95ea2f9ab6f90fd7a0bc Mon Sep 17 00:00:00 2001 From: Lars van der Laan Date: Wed, 17 Mar 2021 17:10:37 -0700 Subject: [PATCH 04/10] Squash hal --- R/Lrnr_hal9001.R | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/R/Lrnr_hal9001.R b/R/Lrnr_hal9001.R index cfcf58f5..74a5a3d3 100644 --- a/R/Lrnr_hal9001.R +++ b/R/Lrnr_hal9001.R @@ -88,6 +88,7 @@ Lrnr_hal9001 <- R6Class( return_x_basis = FALSE, basis_list = NULL, cv_select = TRUE, + squash = TRUE, ...) { params <- args_to_list() super$initialize(params = params, ...) @@ -122,6 +123,9 @@ Lrnr_hal9001 <- R6Class( } fit_object <- call_with_args(hal9001::fit_hal, args, keep_all = TRUE) + if(self$params$squash) { + fit_object <- hal9001::squash_hal_fit(fit_object) + } return(fit_object) }, .predict = function(task = NULL) { From 6e1748a14f8b3c1fa119ac82c8fa36cc0119f91d Mon Sep 17 00:00:00 2001 From: Lars van der Laan Date: Wed, 17 Mar 2021 17:12:32 -0700 Subject: [PATCH 05/10] Update Lrnr_hal9001.R --- R/Lrnr_hal9001.R | 2 ++ 1 file changed, 2 insertions(+) diff --git a/R/Lrnr_hal9001.R b/R/Lrnr_hal9001.R index 74a5a3d3..0229144d 100644 --- a/R/Lrnr_hal9001.R +++ b/R/Lrnr_hal9001.R @@ -70,6 +70,8 @@ #' to \code{TRUE}) or to fit along the sequence of values (or a single value #' using \code{\link[glmnet]{glmnet}} (when set to \code{FALSE}). #' } +#' \item{\code{squash=TRUE}}{A \code{logical} specifying whether to call \code{\link[hal9001]{squash_hal_fit}} on the returned hal9001 fit object. +#' } #' \item{\code{...}}{Other parameters passed directly to #' \code{\link[hal9001]{fit_hal}}. See its documentation for details. 
#' } From 75a6552505193ba800ab7fa85fd71d03a34e9110 Mon Sep 17 00:00:00 2001 From: Lars van der Laan Date: Thu, 18 Mar 2021 14:22:09 -0700 Subject: [PATCH 06/10] HAL add argument --- R/Lrnr_hal9001.R | 3 +- man/Lrnr_hal9001.Rd | 2 + vignettes/testing.Rmd | 89 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 93 insertions(+), 1 deletion(-) create mode 100644 vignettes/testing.Rmd diff --git a/R/Lrnr_hal9001.R b/R/Lrnr_hal9001.R index 0229144d..f964d016 100644 --- a/R/Lrnr_hal9001.R +++ b/R/Lrnr_hal9001.R @@ -91,6 +91,7 @@ Lrnr_hal9001 <- R6Class( basis_list = NULL, cv_select = TRUE, squash = TRUE, + p_reserve = 0.5, ...) { params <- args_to_list() super$initialize(params = params, ...) @@ -131,7 +132,7 @@ Lrnr_hal9001 <- R6Class( return(fit_object) }, .predict = function(task = NULL) { - predictions <- predict(self$fit_object, new_data = as.matrix(task$X)) + predictions <- predict(self$fit_object, new_data = as.matrix(task$X), p_reserve = self$params$p_reserve) if (!is.na(safe_dim(predictions)[2])) { p <- ncol(predictions) colnames(predictions) <- sprintf("lambda_%0.3e", self$params$lambda) diff --git a/man/Lrnr_hal9001.Rd b/man/Lrnr_hal9001.Rd index f66d40d4..9d8c8203 100644 --- a/man/Lrnr_hal9001.Rd +++ b/man/Lrnr_hal9001.Rd @@ -69,6 +69,8 @@ in order to pick the optimal value (based on cross-validation) (when set to \code{TRUE}) or to fit along the sequence of values (or a single value using \code{\link[glmnet]{glmnet}} (when set to \code{FALSE}). } +\item{\code{squash=TRUE}}{A \code{logical} specifying whether to call \code{\link[hal9001]{squash_hal_fit}} on the returned hal9001 fit object. +} \item{\code{...}}{Other parameters passed directly to \code{\link[hal9001]{fit_hal}}. See its documentation for details. 
} diff --git a/vignettes/testing.Rmd b/vignettes/testing.Rmd new file mode 100644 index 00000000..2bc812f5 --- /dev/null +++ b/vignettes/testing.Rmd @@ -0,0 +1,89 @@ +--- +title: "Untitled" +output: html_document +--- + +```{r setup, include=FALSE} +knitr::opts_chunk$set(echo = TRUE) +``` + +```{r} +library(sl3) + + +``` + +```{r} +n <- 2500 +library(simcausal) +#library(sl3) +D <- DAG.empty() + + +D <- D + + node("W1f", distr = "runif", min = -1, max = 1) + + node("W2f", distr = "runif", min = -1, max = 1) + + node("W3f", distr = "runif", min = -1, max = 1) + + node("W1", distr = "rconst", const = W1f) + + node("W2", distr = "rconst", const = W2f) + + node("W3", distr = "rconst", const = W3f) + + node("g", distr = "rconst", const = 0.2 + 0.65*plogis(sin(W1*5) + W1*sin(W1*5) + cos(W2*5) + 2*W1*W2 - sin(W3*5) + sin(5*W1*W3) + 2*W1*W2*W3 + W3*sin(W1*5) + cos(W2*4)*sin(W1*5) ) ) + + node("A", distr = "rbinom", size = 1, prob = g )+ + node("gR", distr = "rconst", const = 2*(W1 + W2 + W3) + A*(W1 + W2 + W3 + W1*W2 + W2*W3 + W1*W3 ) + W1*W2 + W2*W3 + W1*W3 + W1^2 -W2^2 + W3^2 ) + + node("R", distr = "rnorm", mean = gR, sd = 1) + +setD <- set.DAG(D) +data <- sim(setD, n = n) +data + +``` + + +```{r} +#call_with_args <- sl3:::call_with_args +library(R6) +task <- sl3_Task$new(data, covariates = c("W1", "W2", "W3", "A"), outcome = "R") +task$data + +lrnr_ranger <- Lrnr_ranger$new(num.trees = 50, predict.all = TRUE ) +lrnr_ranger <- lrnr_ranger$train(task) +data.table::as.data.table(lrnr_ranger$predict(task)) +``` + + + +```{r} +lrnr_xgboost <- Lrnr_xgboost$new(nrounds = 20, predict.all.trees = FALSE ) +lrnr_xgboost <- lrnr_xgboost$train(task) +data.table::as.data.table(lrnr_xgboost$predict(task)) + +lrnr_xgboost_stacked <- make_learner(Pipeline, Lrnr_cv$new(Lrnr_xgboost$new(nrounds = 20, predict.all.trees = TRUE )), Lrnr_nnls$new(convex = FALSE)) +lrnr_xgboost_stacked <- lrnr_xgboost_stacked$train(task) +data.table::as.data.table(lrnr_xgboost_stacked$predict(task)) +``` + + 
+```{r} + +lrnr_xg_stack <- make_learner(Stack, Lrnr_xgboost$new(nrounds = 20, predict.all.rounds = TRUE, max_depth = 3 ), Lrnr_xgboost$new(nrounds = 20, predict.all.rounds = TRUE, max_depth = 5 ), + Lrnr_xgboost$new(nrounds = 20, predict.all.rounds = TRUE, max_depth = 7 ), + Lrnr_xgboost$new(nrounds = 20, predict.all.rounds = TRUE, max_depth = 10 )) +lrnr_xg_stack <- Lrnr_sl$new(lrnr_xg_stack, metalearner = Lrnr_$new()) +``` +```{r} +lrnr_stack <- make_learner(Stack, + Lrnr_xgboost$new(nrounds = 20, predict.all.trees = FALSE, max_depth = 3 ), + Lrnr_xgboost$new(nrounds = 20, predict.all.trees = FALSE, max_depth = 5 ), + Lrnr_xgboost$new(nrounds = 20, predict.all.trees = FALSE, max_depth = 7 ), Lrnr_xgboost$new(nrounds = 20, predict.all.trees = FALSE, max_depth = 10 ), lrnr_xg_stack) +lrnr_stack <- lrnr_stack$train(task) +preds <- lrnr_stack$predict(task) +as.data.frame(apply(preds - data$gR, 2, function(v) {mean(v^2)})) +#lrnr_cv <- Lrnr_cv$new(lrnr_stack) +#lrnr_cv <- lrnr_cv$train(task) +#lrnr_cv$cv_risk(loss_squared_error) +``` + + + + + From f3fe6eb2f60e2693aee6b2d94c9f0430e996c492 Mon Sep 17 00:00:00 2001 From: Lars van der Laan Date: Sat, 20 Mar 2021 17:08:39 -0700 Subject: [PATCH 07/10] hal changes --- R/Lrnr_hal9001.R | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/R/Lrnr_hal9001.R b/R/Lrnr_hal9001.R index f964d016..91a2ef8e 100644 --- a/R/Lrnr_hal9001.R +++ b/R/Lrnr_hal9001.R @@ -92,6 +92,7 @@ Lrnr_hal9001 <- R6Class( cv_select = TRUE, squash = TRUE, p_reserve = 0.5, + formula = NULL, ...) { params <- args_to_list() super$initialize(params = params, ...) 
@@ -124,8 +125,13 @@ Lrnr_hal9001 <- R6Class( if (task$has_node("id")) { args$id <- task$id } - - fit_object <- call_with_args(hal9001::fit_hal, args, keep_all = TRUE) + if(!is.null(self$params$formula)) { + args$data <- task$data + formula <- call_with_args(hal9001::formula_hal(self$params$formula, args), keep_all = TRUE) + fit_object <- hal9001::fit_hal_formula(formula) + } else { + fit_object <- call_with_args(hal9001::fit_hal, args, keep_all = TRUE) + } if(self$params$squash) { fit_object <- hal9001::squash_hal_fit(fit_object) } From a4748de63c0c094c35dd73655fe9148f0ac823d7 Mon Sep 17 00:00:00 2001 From: Lars van der Laan Date: Sat, 20 Mar 2021 17:14:28 -0700 Subject: [PATCH 08/10] hal changes --- R/Lrnr_hal9001.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/Lrnr_hal9001.R b/R/Lrnr_hal9001.R index 91a2ef8e..2ef08a09 100644 --- a/R/Lrnr_hal9001.R +++ b/R/Lrnr_hal9001.R @@ -127,7 +127,7 @@ Lrnr_hal9001 <- R6Class( } if(!is.null(self$params$formula)) { args$data <- task$data - formula <- call_with_args(hal9001::formula_hal(self$params$formula, args), keep_all = TRUE) + formula <- call_with_args(hal9001::formula_hal, args, keep_all = TRUE) fit_object <- hal9001::fit_hal_formula(formula) } else { fit_object <- call_with_args(hal9001::fit_hal, args, keep_all = TRUE) From b6cd084a5550ffc5491ca409c181bfdfa29296fa Mon Sep 17 00:00:00 2001 From: Lars van der Laan Date: Sat, 20 Mar 2021 17:19:04 -0700 Subject: [PATCH 09/10] hal changes --- R/Lrnr_hal9001.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/Lrnr_hal9001.R b/R/Lrnr_hal9001.R index 2ef08a09..56a8dcc2 100644 --- a/R/Lrnr_hal9001.R +++ b/R/Lrnr_hal9001.R @@ -127,7 +127,7 @@ Lrnr_hal9001 <- R6Class( } if(!is.null(self$params$formula)) { args$data <- task$data - formula <- call_with_args(hal9001::formula_hal, args, keep_all = TRUE) + formula <- call_with_args(hal9001::formula_hal, args, ignore = c("X", "Y")) fit_object <- hal9001::fit_hal_formula(formula) } else { 
fit_object <- call_with_args(hal9001::fit_hal, args, keep_all = TRUE) From 77dce977eae7a6968ba86289592ab3976cebd2cd Mon Sep 17 00:00:00 2001 From: Lars van der Laan Date: Tue, 15 Mar 2022 18:30:01 -0700 Subject: [PATCH 10/10] Update Lrnr_hal9001.R --- R/Lrnr_hal9001.R | 167 +++++++++++++++++++---------------------------- 1 file changed, 66 insertions(+), 101 deletions(-) diff --git a/R/Lrnr_hal9001.R b/R/Lrnr_hal9001.R index 56a8dcc2..9cdd02aa 100644 --- a/R/Lrnr_hal9001.R +++ b/R/Lrnr_hal9001.R @@ -1,150 +1,115 @@ -#' The Scalable Highly Adaptive Lasso +#' Scalable Highly Adaptive Lasso (HAL) #' -#' The Highly Adaptive Lasso is an estimation procedure that generates a design -#' matrix consisting of basis functions corresponding to covariates and -#' interactions of covariates and fits Lasso regression to this (usually) very -#' wide matrix, recovering a nonparametric functional form that describes the -#' target prediction function as a composition of subset functions with finite -#' variation norm. This implementation uses \pkg{hal9001}, which provides both -#' a custom implementation (based on \pkg{origami}) of the cross-validated -#' lasso as well the standard call to \code{\link[glmnet]{cv.glmnet}} from the -#' \pkg{glmnet}. +#' The Highly Adaptive Lasso (HAL) is a nonparametric regression function that +#' has been demonstrated to optimally estimate functions with bounded (finite) +#' variation norm. The algorithm proceeds by first building an adaptive basis +#' (i.e., the HAL basis) based on indicator basis functions (or higher-order +#' spline basis functions) representing covariates and interactions of the +#' covariates up to a pre-specified degree. The fitting procedures included in +#' this learner use \code{\link[hal9001]{fit_hal}} from the \pkg{hal9001} +#' package. 
For details on HAL regression, consider consulting the following: +#' \insertCite{benkeser2016hal;textual}{sl3}, +#' \insertCite{coyle2020hal9001-rpkg;textual}{sl3}, and +#' \insertCite{hejazi2020hal9001-joss;textual}{sl3}. #' #' @docType class +#' #' @importFrom R6 R6Class +#' @importFrom origami folds2foldvec +#' @importFrom stats predict quasibinomial #' #' @export #' #' @keywords data #' -#' @return Learner object with methods for training and prediction. See -#' \code{\link{Lrnr_base}} for documentation on learners. +#' @return A learner object inheriting from \code{\link{Lrnr_base}} with +#' methods for training and prediction. For a full list of learner +#' functionality, see the complete documentation of \code{\link{Lrnr_base}}. #' -#' @format \code{\link{R6Class}} object. +#' @format An \code{\link[R6]{R6Class}} object inheriting from +#' \code{\link{Lrnr_base}}. #' #' @family Learners #' #' @section Parameters: -#' \describe{ -#' \item{\code{max_degree=3}}{ The highest order of interaction -#' terms for which the basis functions ought to be generated. The default -#' corresponds to generating basis functions up to all 3-way interactions of -#' covariates in the input matrix, matching the default in \pkg{hal9001}. -#' } -#' \item{\code{fit_type="glmnet"}}{The specific routine to be called when -#' fitting the Lasso regression in a cross-validated manner. Choosing the -#' \code{"glmnet"} option calls either \code{\link[glmnet]{cv.glmnet}} or -#' \code{\link[glmnet]{glmnet}}. -#' } -#' \item{\code{n_folds=10}}{Integer for the number of folds to be used -#' when splitting the data for cross-validation. This defaults to 10 as this -#' is the convention for V-fold cross-validation. -#' } -#' \item{\code{use_min=TRUE}}{Determines which lambda is selected from -#' \code{\link[glmnet]{cv.glmnet}}. \code{TRUE} corresponds to -#' \code{"lambda.min"} and \code{FALSE} corresponds to \code{"lambda.1se"}.
-#' } -#' \item{\code{reduce_basis=NULL}}{A \code{numeric} value bounded in the open -#' interval (0,1) indicating the minimum proportion of ones in a basis -#' function column needed for the basis function to be included in the -#' procedure to fit the Lasso. Any basis functions with a lower proportion -#' of 1's than the specified cutoff will be removed. This argument defaults -#' to \code{NULL}, in which case all basis functions are used in the Lasso -#' stage of HAL. -#' } -#' \item{\code{return_lasso=TRUE}}{A \code{logical} indicating whether or not -#' to return the \code{\link[glmnet]{glmnet}} fit of the Lasso model. -#' } -#' \item{\code{return_x_basis=FALSE}}{A \code{logical} indicating whether or -#' not to return the matrix of (possibly reduced) basis functions used in -#' the HAL Lasso fit. -#' } -#' \item{\code{basis_list=NULL}}{The full set of basis functions generated -#' from the input data (from \code{\link[hal9001]{enumerate_basis}}). The -#' dimensionality of this structure is roughly (n * 2^(d - 1)), where n is -#' the number of observations and d is the number of columns in the input. -#' } -#' \item{\code{cv_select=TRUE}}{A \code{logical} specifying whether the array -#' of values specified should be passed to \code{\link[glmnet]{cv.glmnet}} -#' in order to pick the optimal value (based on cross-validation) (when set -#' to \code{TRUE}) or to fit along the sequence of values (or a single value -#' using \code{\link[glmnet]{glmnet}} (when set to \code{FALSE}). -#' } -#' \item{\code{squash=TRUE}}{A \code{logical} specifying whether to call \code{\link[hal9001]{squash_hal_fit}} on the returned hal9001 fit object. -#' } -#' \item{\code{...}}{Other parameters passed directly to -#' \code{\link[hal9001]{fit_hal}}. See its documentation for details. -#' } -#' } -# +#' - \code{...}: Arguments passed to \code{\link[hal9001]{fit_hal}}. See +#' it's documentation for details. 
+#' +#' @examples +#' data(cpp_imputed) +#' covs <- c("apgar1", "apgar5", "parity", "gagebrth", "mage", "meducyrs") +#' task <- sl3_Task$new(cpp_imputed, covariates = covs, outcome = "haz") +#' +#' # instantiate with max 2-way interactions, 0-order splines, and binning +#' # (i.e., num_knots) that decreases with increasing interaction degree +#' hal_lrnr <- Lrnr_hal9001$new( +#' max_degree = 2, num_knots = c(20, 10), smoothness_orders = 0 +#' ) +#' hal_fit <- hal_lrnr$train(task) +#' hal_preds <- hal_fit$predict() Lrnr_hal9001 <- R6Class( - classname = "Lrnr_hal9001", inherit = Lrnr_base, - portable = TRUE, class = TRUE, + classname = "Lrnr_hal9001", + inherit = Lrnr_base, portable = TRUE, class = TRUE, public = list( - initialize = function(max_degree = 3, - fit_type = "glmnet", - n_folds = 10, - use_min = TRUE, - reduce_basis = NULL, - return_lasso = TRUE, - return_x_basis = FALSE, - basis_list = NULL, - cv_select = TRUE, - squash = TRUE, - p_reserve = 0.5, - formula = NULL, - ...) { + initialize = function(...) { params <- args_to_list() super$initialize(params = params, ...) 
} ), private = list( .properties = c("continuous", "binomial", "weights", "ids"), - .train = function(task) { args <- self$params + args$X <- as.matrix(task$X) + outcome_type <- self$get_outcome_type(task) + args$Y <- outcome_type$format(task$Y) if (is.null(args$family)) { - args$family <- args$family <- outcome_type$glm_family() + args$family <- outcome_type$glm_family() } - args$X <- as.matrix(task$X) - args$Y <- outcome_type$format(task$Y) - args$yolo <- FALSE + if (!any(grepl("fit_control", names(args)))) { + args$fit_control <- list() + } + args$fit_control$foldid <- origami::folds2foldvec(task$folds) + + if (task$has_node("id")) { + args$id <- task$id + } if (task$has_node("weights")) { - args$weights <- task$weights + args$fit_control$weights <- task$weights } if (task$has_node("offset")) { args$offset <- task$offset } - if (task$has_node("id")) { - args$id <- task$id - } - if(!is.null(self$params$formula)) { - args$data <- task$data - formula <- call_with_args(hal9001::formula_hal, args, ignore = c("X", "Y")) - fit_object <- hal9001::fit_hal_formula(formula) - } else { - fit_object <- call_with_args(hal9001::fit_hal, args, keep_all = TRUE) - } - if(self$params$squash) { - fit_object <- hal9001::squash_hal_fit(fit_object) - } + # fit HAL, allowing glmnet-fitting arguments + other_valid <- c( + names(formals(glmnet::cv.glmnet)), names(formals(glmnet::glmnet)) + ) + + fit_object <- call_with_args( + hal9001::fit_hal, args, + other_valid = other_valid + ) + return(fit_object) }, .predict = function(task = NULL) { - predictions <- predict(self$fit_object, new_data = as.matrix(task$X), p_reserve = self$params$p_reserve) + predictions <- stats::predict( + self$fit_object, + new_data = data.matrix(task$X) + ) if (!is.na(safe_dim(predictions)[2])) { p <- ncol(predictions) colnames(predictions) <- sprintf("lambda_%0.3e", self$params$lambda) } return(predictions) }, - .required_packages = c("hal9001") + .required_packages = c("hal9001", "glmnet") ) )