From 8fb38187601b16be85539cae93d3a8fc6a8e6a03 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Thu, 25 Jan 2024 13:05:27 +0100 Subject: [PATCH] feat(Learner): support bundling property This is useful for learners that e.g. rely on external pointers and that need some processing of the model to be serializable. --- DESCRIPTION | 2 +- NEWS.md | 4 ++ R/Learner.R | 48 ++++++++++++++++++- R/benchmark.R | 5 +- R/mlr_reflections.R | 2 +- R/resample.R | 5 +- R/worker.R | 6 ++- man-roxygen/param_bundle.R | 3 ++ man-roxygen/param_learner_properties.R | 2 + man/Learner.Rd | 29 ++++++++++++ man/LearnerClassif.Rd | 4 ++ man/LearnerRegr.Rd | 4 ++ man/benchmark.Rd | 7 ++- man/mlr_learners_classif.debug.Rd | 2 + man/mlr_learners_classif.featureless.Rd | 2 + man/mlr_learners_classif.rpart.Rd | 2 + man/mlr_learners_regr.debug.Rd | 2 + man/mlr_learners_regr.featureless.Rd | 2 + man/mlr_learners_regr.rpart.Rd | 2 + man/resample.Rd | 7 ++- tests/testthat/test_Learner.R | 61 +++++++++++++++++++++++++ tests/testthat/test_benchmark.R | 39 ++++++++++++++++ tests/testthat/test_resample.R | 37 +++++++++++++++ 23 files changed, 267 insertions(+), 10 deletions(-) create mode 100644 man-roxygen/param_bundle.R diff --git a/DESCRIPTION b/DESCRIPTION index 288f93209..cdf32f34c 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -74,7 +74,7 @@ Config/testthat/edition: 3 Config/testthat/parallel: false NeedsCompilation: no Roxygen: list(markdown = TRUE, r6 = TRUE) -RoxygenNote: 7.2.3 +RoxygenNote: 7.2.3.9000 Collate: 'mlr_reflections.R' 'BenchmarkResult.R' diff --git a/NEWS.md b/NEWS.md index e1274404e..e6e2cdcf7 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,9 @@ # mlr3 (development version) +* Feat: added support for `"bundle"` property, which allows learners to process +models so they can be serialized. This happens automatically during `resample()` +and `benchmark()`. The naming was inspired by the {bundle} package. + # mlr3 0.17.2 * Skip new `data.table` tests on mac. 
diff --git a/R/Learner.R b/R/Learner.R index 9f1863548..a1b76a48c 100644 --- a/R/Learner.R +++ b/R/Learner.R @@ -164,6 +164,10 @@ Learner = R6Class("Learner", self$predict_types = assert_ordered_set(predict_types, names(mlr_reflections$learner_predict_types[[task_type]]), empty.ok = FALSE, .var.name = "predict_types") private$.predict_type = predict_types[1L] + + if (!is.null(private$.bundle) && !is.null(private$.unbundle)) { + properties = c("bundle", properties) + } self$properties = sort(assert_subset(properties, mlr_reflections$learner_properties[[task_type]])) self$data_formats = assert_subset(data_formats, mlr_reflections$data_formats) self$packages = union("mlr3", assert_character(packages, any.missing = FALSE, min.chars = 1L)) @@ -184,7 +188,7 @@ Learner = R6Class("Learner", #' @param ... (ignored). print = function(...) { catn(format(self), if (is.null(self$label) || is.na(self$label)) "" else paste0(": ", self$label)) - catn(str_indent("* Model:", if (is.null(self$model)) "-" else class(self$model)[1L])) + catn(str_indent("* Model:", if (is.null(self$model)) "-" else if (isTRUE(self$bundled)) "" else paste0(class(self$model)[1L]))) catn(str_indent("* Parameters:", as_short_string(self$param_set$values, 1000L))) catn(str_indent("* Packages:", self$packages)) catn(str_indent("* Predict Types: ", replace(self$predict_types, self$predict_types == self$predict_type, paste0("[", self$predict_type, "]")))) @@ -206,6 +210,38 @@ Learner = R6Class("Learner", open_help(self$man) }, + #' @description + #' Bundles the learner's model so it can be serialized and deserialized. + #' Does nothing if the learner does not support bundling. 
+ bundle = function() { + if (is.null(self$model)) { + stopf("Cannot bundle, Learner '%s' has not been trained yet", self$id) + } + if (isTRUE(self$bundled)) { + lg$warn("Learner '%s' has already been bundled, skipping.", self$id) + } else if ("bundle" %in% self$properties) { + self$model = private$.bundle(self$model) + self$state$bundled = TRUE + } + invisible(self) + }, + + #' @description + #' Unbundles the learner's model so it can be used for prediction. + #' Does nothing if the learner does not support (un)bundling. + unbundle = function() { + if (is.null(self$model)) { + stopf("Cannot unbundle, Learner '%s' has not been trained yet", self$id) + } + if (isFALSE(self$bundled)) { + lg$warn("Learner '%s' has not been bundled, skipping.", self$id) + } else if (isTRUE(self$bundled)) { + self$model = private$.unbundle(self$model) + self$state$bundled = FALSE + } + invisible(self) + }, + #' @description #' Train the learner on a set of observations of the provided `task`. #' Mutates the learner by reference, i.e. stores the model alongside other information in field `$state`. @@ -279,6 +315,10 @@ stopf("Cannot predict, Learner '%s' has not been trained yet", self$id) } + if (isTRUE(self$bundled)) { + stopf("Cannot predict, Learner '%s' has not been unbundled yet", self$id) + } + if (isTRUE(self$parallel_predict) && nbrOfWorkers() > 1L) { row_ids = row_ids %??% task$row_ids chunked = chunk_vector(row_ids, n_chunks = nbrOfWorkers(), shuffle = FALSE) @@ -388,6 +428,12 @@ self$state$model }, + #' @field bundled (`logical(1)` or `NULL`)\cr + #' Indicates whether the model has been bundled (`TRUE`), unbundled (`FALSE`), or neither (`NULL`). + bundled = function(rhs) { + assert_ro_binding(rhs) + self$state$bundled + }, #' @field timings (named `numeric(2)`)\cr #' Elapsed time in seconds for the steps `"train"` and `"predict"`. 
diff --git a/R/benchmark.R b/R/benchmark.R index d1223a46f..cfe62d127 100644 --- a/R/benchmark.R +++ b/R/benchmark.R @@ -14,6 +14,7 @@ #' @template param_encapsulate #' @template param_allow_hotstart #' @template param_clone +#' @template param_bundle #' #' @return [BenchmarkResult]. #' @@ -77,7 +78,7 @@ #' ## Get the training set of the 2nd iteration of the featureless learner on penguins #' rr = bmr$aggregate()[learner_id == "classif.featureless"]$resample_result[[1]] #' rr$resampling$train_set(2) -benchmark = function(design, store_models = FALSE, store_backends = TRUE, encapsulate = NA_character_, allow_hotstart = FALSE, clone = c("task", "learner", "resampling")) { +benchmark = function(design, store_models = FALSE, store_backends = TRUE, encapsulate = NA_character_, allow_hotstart = FALSE, clone = c("task", "learner", "resampling"), bundle = TRUE) { assert_subset(clone, c("task", "learner", "resampling")) assert_data_frame(design, min.rows = 1L) assert_names(names(design), must.include = c("task", "learner", "resampling")) @@ -183,7 +184,7 @@ benchmark = function(design, store_models = FALSE, store_backends = TRUE, encaps res = future_map(n, workhorse, task = grid$task, learner = grid$learner, resampling = grid$resampling, iteration = grid$iteration, param_values = grid$param_values, mode = grid$mode, - MoreArgs = list(store_models = store_models, lgr_threshold = lgr_threshold, pb = pb) + MoreArgs = list(store_models = store_models, lgr_threshold = lgr_threshold, pb = pb, bundle = bundle) ) grid = insert_named(grid, list( diff --git a/R/mlr_reflections.R b/R/mlr_reflections.R index 7998c30f6..ebe30d99f 100644 --- a/R/mlr_reflections.R +++ b/R/mlr_reflections.R @@ -125,7 +125,7 @@ local({ ) ### Learner - tmp = c("featureless", "missings", "weights", "importance", "selected_features", "oob_error", "loglik", "hotstart_forward", "hotstart_backward") + tmp = c("featureless", "missings", "weights", "importance", "selected_features", "oob_error", "loglik", 
"hotstart_forward", "hotstart_backward", "bundle") mlr_reflections$learner_properties = list( classif = c(tmp, "twoclass", "multiclass"), regr = tmp diff --git a/R/resample.R b/R/resample.R index 9e3467e46..5a3efdd79 100644 --- a/R/resample.R +++ b/R/resample.R @@ -14,6 +14,7 @@ #' @template param_encapsulate #' @template param_allow_hotstart #' @template param_clone +#' @template param_bundle #' @return [ResampleResult]. #' #' @template section_predict_sets @@ -54,7 +55,7 @@ #' bmr1 = as_benchmark_result(rr) #' bmr2 = as_benchmark_result(rr_featureless) #' print(bmr1$combine(bmr2)) -resample = function(task, learner, resampling, store_models = FALSE, store_backends = TRUE, encapsulate = NA_character_, allow_hotstart = FALSE, clone = c("task", "learner", "resampling")) { +resample = function(task, learner, resampling, store_models = FALSE, store_backends = TRUE, encapsulate = NA_character_, allow_hotstart = FALSE, clone = c("task", "learner", "resampling"), bundle = TRUE) { assert_subset(clone, c("task", "learner", "resampling")) task = assert_task(as_task(task, clone = "task" %in% clone)) learner = assert_learner(as_learner(learner, clone = "learner" %in% clone)) @@ -111,7 +112,7 @@ resample = function(task, learner, resampling, store_models = FALSE, store_backe } res = future_map(n, workhorse, iteration = seq_len(n), learner = grid$learner, mode = grid$mode, - MoreArgs = list(task = task, resampling = resampling, store_models = store_models, lgr_threshold = lgr_threshold, pb = pb) + MoreArgs = list(task = task, resampling = resampling, store_models = store_models, lgr_threshold = lgr_threshold, pb = pb, bundle = bundle) ) data = data.table( diff --git a/R/worker.R b/R/worker.R index 8fa0008df..0baa85b87 100644 --- a/R/worker.R +++ b/R/worker.R @@ -72,6 +72,7 @@ learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NUL model = result$result, log = log, train_time = train_time, + bundled = if ("bundle" %in% learner$properties) FALSE else 
NULL, param_vals = learner$param_set$values, task_hash = task$hash, feature_names = task$feature_names, @@ -217,7 +218,7 @@ learner_predict = function(learner, task, row_ids = NULL) { } -workhorse = function(iteration, task, learner, resampling, param_values = NULL, lgr_threshold, store_models = FALSE, pb = NULL, mode = "train", is_sequential = TRUE) { +workhorse = function(iteration, task, learner, resampling, param_values = NULL, lgr_threshold, store_models = FALSE, pb = NULL, mode = "train", is_sequential = TRUE, bundle = TRUE) { if (!is.null(pb)) { pb(sprintf("%s|%s|i:%i", task$id, learner$id, iteration)) } @@ -266,6 +267,9 @@ workhorse = function(iteration, task, learner, resampling, param_values = NULL, if (!store_models) { lg$debug("Erasing stored model for learner '%s'", learner$id) learner$state$model = NULL + } else if (bundle && "bundle" %in% learner$properties) { + lg$debug("Bundling model for learner '%s'", learner$id) + learner$bundle() } list(learner_state = learner$state, prediction = pdatas, param_values = learner$param_set$values, learner_hash = learner_hash) diff --git a/man-roxygen/param_bundle.R b/man-roxygen/param_bundle.R new file mode 100644 index 000000000..819d8a763 --- /dev/null +++ b/man-roxygen/param_bundle.R @@ -0,0 +1,3 @@ +#' @param bundle (`logical(1)`)\cr +#' Whether to bundle the learner(s) after the train-predict loop. +#' Default is `TRUE`. diff --git a/man-roxygen/param_learner_properties.R b/man-roxygen/param_learner_properties.R index 48f3817b1..c97d97de9 100644 --- a/man-roxygen/param_learner_properties.R +++ b/man-roxygen/param_learner_properties.R @@ -7,3 +7,5 @@ #' * `"importance"`: The learner supports extraction of importance scores, i.e. comes with an `$importance()` extractor function (see section on optional extractors in [Learner]). #' * `"selected_features"`: The learner supports extraction of the set of selected features, i.e. 
comes with a `$selected_features()` extractor function (see section on optional extractors in [Learner]). #' * `"oob_error"`: The learner supports extraction of estimated out of bag error, i.e. comes with a `oob_error()` extractor function (see section on optional extractors in [Learner]). +#' * `"bundle"`: To save learners with this property, you need to call `$bundle()` first. +#' If a learner is in a bundled state, you first need to call `$unbundle()` to use its model, e.g. for prediction. diff --git a/man/Learner.Rd b/man/Learner.Rd index ac6c02aaf..7e11f1882 100644 --- a/man/Learner.Rd +++ b/man/Learner.Rd @@ -185,6 +185,9 @@ Defaults to \code{NA}, but can be set by child classes.} \item{\code{model}}{(any)\cr The fitted model. Only available after \verb{$train()} has been called.} +\item{\code{bundled}}{(\code{logical(1)} or \code{NULL})\cr +Indicates whether the model has been bundled (\code{TRUE}), unbundled (\code{FALSE}), or neither (\code{NULL}).} + \item{\code{timings}}{(named \code{numeric(2)})\cr Elapsed time in seconds for the steps \code{"train"} and \code{"predict"}. Measured via \code{\link[mlr3misc:encapsulate]{mlr3misc::encapsulate()}}.} @@ -244,6 +247,8 @@ Stores \code{HotstartStack}.} \item \href{#method-Learner-format}{\code{Learner$format()}} \item \href{#method-Learner-print}{\code{Learner$print()}} \item \href{#method-Learner-help}{\code{Learner$help()}} +\item \href{#method-Learner-bundle}{\code{Learner$bundle()}} +\item \href{#method-Learner-unbundle}{\code{Learner$unbundle()}} \item \href{#method-Learner-train}{\code{Learner$train()}} \item \href{#method-Learner-predict}{\code{Learner$predict()}} \item \href{#method-Learner-predict_newdata}{\code{Learner$predict_newdata()}} @@ -303,6 +308,8 @@ The following properties are currently standardized and understood by learners i \item \code{"importance"}: The learner supports extraction of importance scores, i.e. 
comes with an \verb{$importance()} extractor function (see section on optional extractors in \link{Learner}). \item \code{"selected_features"}: The learner supports extraction of the set of selected features, i.e. comes with a \verb{$selected_features()} extractor function (see section on optional extractors in \link{Learner}). \item \code{"oob_error"}: The learner supports extraction of estimated out of bag error, i.e. comes with a \code{oob_error()} extractor function (see section on optional extractors in \link{Learner}). +\item \code{"bundle"}: To save learners with this property, you need to call \verb{$bundle()} first. +If a learner is in a bundled state, you call first need to call \verb{$unbundle()} to use it's model, e.g. for prediction. }} \item{\code{data_formats}}{(\code{character()})\cr @@ -367,6 +374,28 @@ Opens the corresponding help page referenced by field \verb{$man}. \if{html}{\out{
}}\preformatted{Learner$help()}\if{html}{\out{
}} } +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-Learner-bundle}{}}} +\subsection{Method \code{bundle()}}{ +Bundles the learner's model so it can be serialized and deserialized. +Does nothing if the learner does not support bundling. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{Learner$bundle()}\if{html}{\out{
}} +} + +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-Learner-unbundle}{}}} +\subsection{Method \code{unbundle()}}{ +Unbundles the learner's model so it can be used for prediction. +Does nothing if the learner does not support (un)bundling. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{Learner$unbundle()}\if{html}{\out{
}} +} + } \if{html}{\out{
}} \if{html}{\out{}} diff --git a/man/LearnerClassif.Rd b/man/LearnerClassif.Rd index f159744a6..7ef7b7703 100644 --- a/man/LearnerClassif.Rd +++ b/man/LearnerClassif.Rd @@ -85,6 +85,7 @@ Other Learner:
Inherited methods
}} @@ -139,6 +141,8 @@ The following properties are currently standardized and understood by learners i \item \code{"importance"}: The learner supports extraction of importance scores, i.e. comes with an \verb{$importance()} extractor function (see section on optional extractors in \link{Learner}). \item \code{"selected_features"}: The learner supports extraction of the set of selected features, i.e. comes with a \verb{$selected_features()} extractor function (see section on optional extractors in \link{Learner}). \item \code{"oob_error"}: The learner supports extraction of estimated out of bag error, i.e. comes with a \code{oob_error()} extractor function (see section on optional extractors in \link{Learner}). +\item \code{"bundle"}: To save learners with this property, you need to call \verb{$bundle()} first. +If a learner is in a bundled state, you call first need to call \verb{$unbundle()} to use it's model, e.g. for prediction. }} \item{\code{data_formats}}{(\code{character()})\cr diff --git a/man/LearnerRegr.Rd b/man/LearnerRegr.Rd index 5c9cbbcbf..700edf847 100644 --- a/man/LearnerRegr.Rd +++ b/man/LearnerRegr.Rd @@ -75,6 +75,7 @@ Other Learner:
Inherited methods
}} @@ -129,6 +131,8 @@ The following properties are currently standardized and understood by learners i \item \code{"importance"}: The learner supports extraction of importance scores, i.e. comes with an \verb{$importance()} extractor function (see section on optional extractors in \link{Learner}). \item \code{"selected_features"}: The learner supports extraction of the set of selected features, i.e. comes with a \verb{$selected_features()} extractor function (see section on optional extractors in \link{Learner}). \item \code{"oob_error"}: The learner supports extraction of estimated out of bag error, i.e. comes with a \code{oob_error()} extractor function (see section on optional extractors in \link{Learner}). +\item \code{"bundle"}: To save learners with this property, you need to call \verb{$bundle()} first. +If a learner is in a bundled state, you call first need to call \verb{$unbundle()} to use it's model, e.g. for prediction. }} \item{\code{data_formats}}{(\code{character()})\cr diff --git a/man/benchmark.Rd b/man/benchmark.Rd index 2bd1a35db..edb37a682 100644 --- a/man/benchmark.Rd +++ b/man/benchmark.Rd @@ -10,7 +10,8 @@ benchmark( store_backends = TRUE, encapsulate = NA_character_, allow_hotstart = FALSE, - clone = c("task", "learner", "resampling") + clone = c("task", "learner", "resampling"), + bundle = TRUE ) } \arguments{ @@ -57,6 +58,10 @@ Select the input objects to be cloned before proceeding by providing a set with possible values \code{"task"}, \code{"learner"} and \code{"resampling"} for \link{Task}, \link{Learner} and \link{Resampling}, respectively. Per default, all input objects are cloned.} + +\item{bundle}{(\code{logical(1)})\cr +Whether to bundle the learner(s) after the train-predict loop. +Default is \code{TRUE}.} } \value{ \link{BenchmarkResult}. 
diff --git a/man/mlr_learners_classif.debug.Rd b/man/mlr_learners_classif.debug.Rd index f5ca719ea..3b66b12d3 100644 --- a/man/mlr_learners_classif.debug.Rd +++ b/man/mlr_learners_classif.debug.Rd @@ -127,6 +127,7 @@ Other Learner:
Inherited methods
}} diff --git a/man/mlr_learners_classif.featureless.Rd b/man/mlr_learners_classif.featureless.Rd index 94c1b8595..53441620b 100644 --- a/man/mlr_learners_classif.featureless.Rd +++ b/man/mlr_learners_classif.featureless.Rd @@ -95,6 +95,7 @@ Other Learner:
Inherited methods
}} diff --git a/man/mlr_learners_classif.rpart.Rd b/man/mlr_learners_classif.rpart.Rd index 139dfe41c..614a73b49 100644 --- a/man/mlr_learners_classif.rpart.Rd +++ b/man/mlr_learners_classif.rpart.Rd @@ -109,6 +109,7 @@ Other Learner:
Inherited methods
}} diff --git a/man/mlr_learners_regr.debug.Rd b/man/mlr_learners_regr.debug.Rd index 66ffd60f2..e0f9a7f4d 100644 --- a/man/mlr_learners_regr.debug.Rd +++ b/man/mlr_learners_regr.debug.Rd @@ -100,6 +100,7 @@ Other Learner:
Inherited methods
}} diff --git a/man/mlr_learners_regr.featureless.Rd b/man/mlr_learners_regr.featureless.Rd index f06fcf4d6..20e840b74 100644 --- a/man/mlr_learners_regr.featureless.Rd +++ b/man/mlr_learners_regr.featureless.Rd @@ -84,6 +84,7 @@ Other Learner:
Inherited methods
}} diff --git a/man/mlr_learners_regr.rpart.Rd b/man/mlr_learners_regr.rpart.Rd index 9a83fec05..04f4c8997 100644 --- a/man/mlr_learners_regr.rpart.Rd +++ b/man/mlr_learners_regr.rpart.Rd @@ -109,6 +109,7 @@ Other Learner:
Inherited methods
}} diff --git a/man/resample.Rd b/man/resample.Rd index 825f086da..935c6a97c 100644 --- a/man/resample.Rd +++ b/man/resample.Rd @@ -12,7 +12,8 @@ resample( store_backends = TRUE, encapsulate = NA_character_, allow_hotstart = FALSE, - clone = c("task", "learner", "resampling") + clone = c("task", "learner", "resampling"), + bundle = TRUE ) } \arguments{ @@ -58,6 +59,10 @@ Select the input objects to be cloned before proceeding by providing a set with possible values \code{"task"}, \code{"learner"} and \code{"resampling"} for \link{Task}, \link{Learner} and \link{Resampling}, respectively. Per default, all input objects are cloned.} + +\item{bundle}{(\code{logical(1)})\cr +Whether to bundle the learner(s) after the train-predict loop. +Default is \code{TRUE}.} } \value{ \link{ResampleResult}. diff --git a/tests/testthat/test_Learner.R b/tests/testthat/test_Learner.R index bdd5f709d..7d1c14435 100644 --- a/tests/testthat/test_Learner.R +++ b/tests/testthat/test_Learner.R @@ -324,3 +324,64 @@ test_that("Models can be replaced", { learner$model$location = 1 expect_equal(learner$model$location, 1) }) + +test_that("bundling", { + task = tsk("mtcars") + LearnerRegrTest = R6Class("LearnerRegrTest", + inherit = LearnerRegrFeatureless, + private = list( + .bundle = function(model) { + private$.tmp_model = model + "bundle" + }, + .unbundle = function(model) { + model = private$.tmp_model + private$.tmp_model = NULL + model + }, + .tmp_model = NULL + ) + ) + + # bundled property + learner = LearnerRegrTest$new() + expect_true("bundle" %in% learner$properties) + expect_true(is.null(learner$bundled)) + + # (un)bundling only possible after training + expect_error(learner$bundle(), "has not been trained") + expect_error(learner$unbundle(), "has not been trained") + + learner$train(task) + model = learner$model + expect_false(learner$bundled) + learner$bundle() + expect_true(learner$bundled) + + # cannot predict with bundled learner + expect_error(learner$predict(task), "has not 
been unbundled") + expect_true(identical(learner$model, "bundle")) + + # unbundling works + learner$unbundle() + # can predict after unbundling + expect_prediction(learner$predict(task)) + # model is reset + expect_equal(learner$model, model) + # bundled is set accordingly + expect_false(learner$bundled) + + # when re-training, bundled is reset + learner$predict(task) + expect_false(learner$train(task)$bundled) + + ## unbundleable learners + lrn_rpart = lrn("regr.rpart") + + # bundled is NULL + expect_true(is.null(lrn_rpart$bundled)) + expect_true(is.null(lrn_rpart$train(task)$bundled)) + # calling (un)bundle does nothing + expect_true(is.null(lrn_rpart$bundle()$bundled)) + expect_true(is.null(lrn_rpart$unbundle()$bundled)) +}) diff --git a/tests/testthat/test_benchmark.R b/tests/testthat/test_benchmark.R index 2faaa19f6..9d1841238 100644 --- a/tests/testthat/test_benchmark.R +++ b/tests/testthat/test_benchmark.R @@ -476,3 +476,42 @@ test_that("param_values in benchmark", { expect_equal(bmr$learners$learner[[1]]$param_set$values, list(xval = 0, minsplit = 12, minbucket = 2)) expect_equal(bmr$learners$learner[[2]]$param_set$values, list(xval = 0, minsplit = 12, cp = 0.1)) }) + + +test_that("bundling", { + task = tsk("mtcars") + LearnerRegrTest = R6Class("LearnerRegrTest", + inherit = LearnerRegrFeatureless, + private = list( + .bundle = function(model) { + private$.tmp_model = model + "bundle" + }, + .unbundle = function(model) { + model = private$.tmp_model + private$.tmp_model = NULL + model + }, + .tmp_model = NULL + ) + ) + learner = LearnerRegrTest$new() + resampling = rsmp("holdout")$instantiate(task) + + # Learner can be bundled during benchmark() + bmr1 = benchmark(benchmark_grid(task, learner, resampling), store_models = TRUE, bundle = TRUE) + lrn_rec = bmr1$resample_results$resample_result[[1]]$learners[[1]] + expect_true(lrn_rec$bundled) + + # learner can be unbundled after benchmark() + lrn_rec$unbundle() + expect_false(lrn_rec$bundled) + + # result is 
the same with and without bundling + bmr2 = benchmark(benchmark_grid(task, lrn("regr.featureless"), resampling), store_models = TRUE, bundle = TRUE) + expect_equal(as.data.table(bmr1$aggregate())$regr.mse, as.data.table(bmr2$aggregate())$regr.mse) + + # bundling can be disabled + bmr3 = benchmark(benchmark_grid(task, learner, resampling), store_models = TRUE, bundle = FALSE) + expect_false(bmr3$resample_results$resample_result[[1]]$learners[[1]]$bundled) +}) diff --git a/tests/testthat/test_resample.R b/tests/testthat/test_resample.R index 91cefa4f7..48c38f87f 100644 --- a/tests/testthat/test_resample.R +++ b/tests/testthat/test_resample.R @@ -156,3 +156,40 @@ test_that("as_resample_result works for result data", { rr2 = as_resample_result(result_data) expect_class(rr2, "ResampleResult") }) + +test_that("bundling", { + task = tsk("mtcars") + LearnerRegrTest = R6Class("LearnerRegrTest", + inherit = LearnerRegrFeatureless, + private = list( + .bundle = function(model) { + private$.tmp_model = model + "bundle" + }, + .unbundle = function(model) { + model = private$.tmp_model + private$.tmp_model = NULL + model + }, + .tmp_model = NULL + ) + ) + learner = LearnerRegrTest$new() + + # allow to bundle during resample() + resampling = rsmp("holdout")$instantiate(task) + rr1 = resample(task, learner, resampling, store_models = TRUE, bundle = TRUE) + lrn_rec = rr1$learners[[1L]] + expect_true(lrn_rec$bundled) + expect_false(lrn_rec$unbundle()$bundled) + + # bundled resamples are equivalent to unbundled results + rr2 = resample(task, lrn("regr.featureless"), resampling, store_models = TRUE) + rr3 = resample(task, learner, resampling, store_models = FALSE) + expect_equal(rr1$aggregate(), rr3$aggregate()) + expect_equal(rr1$aggregate(), rr2$aggregate()) + + # bundling can be disabled + rr4 = resample(task, learner, resampling, bundle = FALSE, store_models = TRUE) + expect_false(rr4$learners[[1]]$bundled) +})