diff --git a/R/benchmark.R b/R/benchmark.R index 5f26c38a6..07e862d65 100644 --- a/R/benchmark.R +++ b/R/benchmark.R @@ -18,6 +18,7 @@ #' @template param_allow_hotstart #' @template param_clone #' @template param_unmarshal +#' @template param_callbacks #' #' @return [BenchmarkResult]. #' @@ -81,7 +82,7 @@ #' ## Get the training set of the 2nd iteration of the featureless learner on penguins #' rr = bmr$aggregate()[learner_id == "classif.featureless"]$resample_result[[1]] #' rr$resampling$train_set(2) -benchmark = function(design, store_models = FALSE, store_backends = TRUE, encapsulate = NA_character_, allow_hotstart = FALSE, clone = c("task", "learner", "resampling"), unmarshal = TRUE) { +benchmark = function(design, store_models = FALSE, store_backends = TRUE, encapsulate = NA_character_, allow_hotstart = FALSE, clone = c("task", "learner", "resampling"), unmarshal = TRUE, callbacks = NULL) { assert_subset(clone, c("task", "learner", "resampling")) assert_data_frame(design, min.rows = 1L) assert_names(names(design), must.include = c("task", "learner", "resampling")) @@ -96,6 +97,7 @@ benchmark = function(design, store_models = FALSE, store_backends = TRUE, encaps } assert_flag(store_models) assert_flag(store_backends) + callbacks = assert_callbacks(as_callbacks(callbacks)) # check for multiple task types task_types = unique(map_chr(design$task, "task_type")) @@ -187,14 +189,15 @@ benchmark = function(design, store_models = FALSE, store_backends = TRUE, encaps res = future_map(n, workhorse, task = grid$task, learner = grid$learner, resampling = grid$resampling, iteration = grid$iteration, param_values = grid$param_values, mode = grid$mode, - MoreArgs = list(store_models = store_models, lgr_threshold = lgr_threshold, pb = pb, unmarshal = unmarshal) + MoreArgs = list(store_models = store_models, lgr_threshold = lgr_threshold, pb = pb, unmarshal = unmarshal, callbacks = callbacks) ) grid = insert_named(grid, list( learner_state = map(res, "learner_state"), prediction = map(res, "prediction"), param_values = map(res, "param_values"), - learner_hash = map_chr(res, "learner_hash") + learner_hash = map_chr(res, "learner_hash"), + data_extra = map(res, "data_extra") )) lg$info("Finished benchmark") diff --git a/R/resample.R b/R/resample.R index 4cbcc2d64..da368373a 100644 --- a/R/resample.R +++ b/R/resample.R @@ -15,6 +15,7 @@ #' @template param_allow_hotstart #' @template param_clone #' @template param_unmarshal +#' @template param_callbacks #' @return [ResampleResult]. #' #' @template section_predict_sets @@ -67,8 +68,6 @@ resample = function( unmarshal = TRUE, callbacks = NULL ) { - callbacks = assert_callbacks(as_callbacks(callbacks)) - assert_subset(clone, c("task", "learner", "resampling")) task = assert_task(as_task(task, clone = "task" %in% clone)) learner = assert_learner(as_learner(learner, clone = "learner" %in% clone, discard_state = TRUE)) @@ -78,6 +77,7 @@ resample = function( # this does not check the internal validation task as it might not be set yet assert_learnable(task, learner) assert_flag(unmarshal) + callbacks = assert_callbacks(as_callbacks(callbacks)) set_encapsulation(list(learner), encapsulate) if (!resampling$is_instantiated) { diff --git a/man-roxygen/param_callbacks.R b/man-roxygen/param_callbacks.R new file mode 100644 index 000000000..cdb286953 --- /dev/null +++ b/man-roxygen/param_callbacks.R @@ -0,0 +1,3 @@ +#' @param callbacks (List of [mlr3misc::Callback])\cr +#' Callbacks to be executed during the resampling process. +#' See [CallbackEvaluation] and [ContextEvaluation] for details. diff --git a/man/benchmark.Rd b/man/benchmark.Rd index 9cfc995f7..9f53cecfd 100644 --- a/man/benchmark.Rd +++ b/man/benchmark.Rd @@ -11,7 +11,8 @@ benchmark( encapsulate = NA_character_, allow_hotstart = FALSE, clone = c("task", "learner", "resampling"), - unmarshal = TRUE + unmarshal = TRUE, + callbacks = NULL ) } \arguments{ @@ -63,6 +64,10 @@ Per default, all input objects are cloned.} Whether to unmarshal learners that were marshaled during the execution. If \code{TRUE} all models are stored in unmarshaled form. If \code{FALSE}, all learners (that need marshaling) are stored in marshaled form.} + +\item{callbacks}{(List of \link[mlr3misc:Callback]{mlr3misc::Callback})\cr +Callbacks to be executed during the resampling process. +See \link{CallbackEvaluation} and \link{ContextEvaluation} for details.} } \value{ \link{BenchmarkResult}. diff --git a/man/resample.Rd b/man/resample.Rd index 00378067e..41340d299 100644 --- a/man/resample.Rd +++ b/man/resample.Rd @@ -65,6 +65,10 @@ Per default, all input objects are cloned.} Whether to unmarshal learners that were marshaled during the execution. If \code{TRUE} all models are stored in unmarshaled form. If \code{FALSE}, all learners (that need marshaling) are stored in marshaled form.} + +\item{callbacks}{(List of \link[mlr3misc:Callback]{mlr3misc::Callback})\cr +Callbacks to be executed during the resampling process. +See \link{CallbackEvaluation} and \link{ContextEvaluation} for details.} } \value{ \link{ResampleResult}.