Skip to content

Commit

Permalink
...
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc committed Nov 27, 2024
1 parent 01e6a4b commit ae85cc0
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 6 deletions.
9 changes: 6 additions & 3 deletions R/benchmark.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#' @template param_allow_hotstart
#' @template param_clone
#' @template param_unmarshal
#' @template param_callbacks
#'
#' @return [BenchmarkResult].
#'
Expand Down Expand Up @@ -81,7 +82,7 @@
#' ## Get the training set of the 2nd iteration of the featureless learner on penguins
#' rr = bmr$aggregate()[learner_id == "classif.featureless"]$resample_result[[1]]
#' rr$resampling$train_set(2)
benchmark = function(design, store_models = FALSE, store_backends = TRUE, encapsulate = NA_character_, allow_hotstart = FALSE, clone = c("task", "learner", "resampling"), unmarshal = TRUE) {
benchmark = function(design, store_models = FALSE, store_backends = TRUE, encapsulate = NA_character_, allow_hotstart = FALSE, clone = c("task", "learner", "resampling"), unmarshal = TRUE, callbacks = NULL) {
assert_subset(clone, c("task", "learner", "resampling"))
assert_data_frame(design, min.rows = 1L)
assert_names(names(design), must.include = c("task", "learner", "resampling"))
Expand All @@ -96,6 +97,7 @@ benchmark = function(design, store_models = FALSE, store_backends = TRUE, encaps
}
assert_flag(store_models)
assert_flag(store_backends)
callbacks = assert_callbacks(as_callbacks(callbacks))

# check for multiple task types
task_types = unique(map_chr(design$task, "task_type"))
Expand Down Expand Up @@ -187,14 +189,15 @@ benchmark = function(design, store_models = FALSE, store_backends = TRUE, encaps

res = future_map(n, workhorse,
task = grid$task, learner = grid$learner, resampling = grid$resampling, iteration = grid$iteration, param_values = grid$param_values, mode = grid$mode,
MoreArgs = list(store_models = store_models, lgr_threshold = lgr_threshold, pb = pb, unmarshal = unmarshal)
MoreArgs = list(store_models = store_models, lgr_threshold = lgr_threshold, pb = pb, unmarshal = unmarshal, callbacks = callbacks)
)

grid = insert_named(grid, list(
learner_state = map(res, "learner_state"),
prediction = map(res, "prediction"),
param_values = map(res, "param_values"),
learner_hash = map_chr(res, "learner_hash")
learner_hash = map_chr(res, "learner_hash"),
data_extra = map(res, "data_extra")
))

lg$info("Finished benchmark")
Expand Down
4 changes: 2 additions & 2 deletions R/resample.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#' @template param_allow_hotstart
#' @template param_clone
#' @template param_unmarshal
#' @template param_callbacks
#' @return [ResampleResult].
#'
#' @template section_predict_sets
Expand Down Expand Up @@ -67,8 +68,6 @@ resample = function(
unmarshal = TRUE,
callbacks = NULL
) {
callbacks = assert_callbacks(as_callbacks(callbacks))

assert_subset(clone, c("task", "learner", "resampling"))
task = assert_task(as_task(task, clone = "task" %in% clone))
learner = assert_learner(as_learner(learner, clone = "learner" %in% clone, discard_state = TRUE))
Expand All @@ -78,6 +77,7 @@ resample = function(
# this does not check the internal validation task as it might not be set yet
assert_learnable(task, learner)
assert_flag(unmarshal)
callbacks = assert_callbacks(as_callbacks(callbacks))

set_encapsulation(list(learner), encapsulate)
if (!resampling$is_instantiated) {
Expand Down
3 changes: 3 additions & 0 deletions man-roxygen/param_callbacks.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#' @param callbacks (List of [mlr3misc::Callback])\cr
#' Callbacks to be executed during the resampling process.
#' See [CallbackEvaluation] and [ContextEvaluation] for details.
7 changes: 6 additions & 1 deletion man/benchmark.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions man/resample.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit ae85cc0

Please sign in to comment.