Skip to content

Commit

Permalink
feat(Learner): support bundling property
Browse files Browse the repository at this point in the history
This is useful for learners that e.g. rely on external pointers
and that need some processing of the model to make it serializable.
  • Loading branch information
sebffischer committed Jan 25, 2024
1 parent 78ad4dd commit 8fb3818
Show file tree
Hide file tree
Showing 23 changed files with 267 additions and 10 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ Config/testthat/edition: 3
Config/testthat/parallel: false
NeedsCompilation: no
Roxygen: list(markdown = TRUE, r6 = TRUE)
RoxygenNote: 7.2.3
RoxygenNote: 7.2.3.9000
Collate:
'mlr_reflections.R'
'BenchmarkResult.R'
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# mlr3 (development version)

* Feat: added support for `"bundle"` property, which allows learners to process
models so they can be serialized. This happens automatically during `resample()`
and `benchmark()`. The naming was inspired by the {bundle} package.

# mlr3 0.17.2

* Skip new `data.table` tests on mac.
Expand Down
48 changes: 47 additions & 1 deletion R/Learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,10 @@ Learner = R6Class("Learner",
self$predict_types = assert_ordered_set(predict_types, names(mlr_reflections$learner_predict_types[[task_type]]),
empty.ok = FALSE, .var.name = "predict_types")
private$.predict_type = predict_types[1L]

if (!is.null(private$.bundle) && !is.null(private$.unbundle)) {
properties = c("bundle", properties)
}
self$properties = sort(assert_subset(properties, mlr_reflections$learner_properties[[task_type]]))
self$data_formats = assert_subset(data_formats, mlr_reflections$data_formats)
self$packages = union("mlr3", assert_character(packages, any.missing = FALSE, min.chars = 1L))
Expand All @@ -184,7 +188,7 @@ Learner = R6Class("Learner",
#' @param ... (ignored).
print = function(...) {
catn(format(self), if (is.null(self$label) || is.na(self$label)) "" else paste0(": ", self$label))
catn(str_indent("* Model:", if (is.null(self$model)) "-" else class(self$model)[1L]))
catn(str_indent("* Model:", if (is.null(self$model)) "-" else if (isTRUE(self$bundled)) "<bundled>" else paste0(class(self$model)[1L])))
catn(str_indent("* Parameters:", as_short_string(self$param_set$values, 1000L)))
catn(str_indent("* Packages:", self$packages))
catn(str_indent("* Predict Types: ", replace(self$predict_types, self$predict_types == self$predict_type, paste0("[", self$predict_type, "]"))))
Expand All @@ -206,6 +210,38 @@ Learner = R6Class("Learner",
open_help(self$man)
},

#' @description
#' Converts the learner's model into a serializable form via the
#' learner-provided `private$.bundle()` hook.
#' Does nothing if the learner does not support bundling.
bundle = function() {
  if (is.null(self$model)) {
    stopf("Cannot bundle, Learner '%s' has not been trained yet", self$id)
  }
  is_bundled = isTRUE(self$bundled)
  if (is_bundled) {
    # bundling twice is skipped to avoid mangling an already-bundled model
    lg$warn("Learner '%s' has already been bundled, skipping.", self$id)
  }
  if (!is_bundled && "bundle" %in% self$properties) {
    self$model = private$.bundle(self$model)
    self$state$bundled = TRUE
  }
  invisible(self)
},

#' @description
#' Restores a bundled model via the learner-provided `private$.unbundle()`
#' hook so it can be used for prediction.
#' Does nothing if the learner does not support (un)bundling.
unbundle = function() {
  if (is.null(self$model)) {
    stopf("Cannot unbundle, Learner '%s' has not been trained yet", self$id)
  }
  bundle_state = self$bundled
  if (isFALSE(bundle_state)) {
    # already in the unbundled state, nothing to restore
    lg$warn("Learner '%s' has not been bundled, skipping.", self$id)
  } else if (isTRUE(bundle_state)) {
    self$model = private$.unbundle(self$model)
    self$state$bundled = FALSE
  }
  invisible(self)
},

#' @description
#' Train the learner on a set of observations of the provided `task`.
#' Mutates the learner by reference, i.e. stores the model alongside other information in field `$state`.
Expand Down Expand Up @@ -279,6 +315,10 @@ Learner = R6Class("Learner",
stopf("Cannot predict, Learner '%s' has not been trained yet", self$id)
}

if (isTRUE(self$bundled)) {
stopf("Cannot predict, Learner '%s' has not been unbundled yet", self$id)
}

if (isTRUE(self$parallel_predict) && nbrOfWorkers() > 1L) {
row_ids = row_ids %??% task$row_ids
chunked = chunk_vector(row_ids, n_chunks = nbrOfWorkers(), shuffle = FALSE)
Expand Down Expand Up @@ -388,6 +428,12 @@ Learner = R6Class("Learner",
self$state$model
},

#' @field bundled (`logical(1)` or `NULL`)\cr
#' Indicates whether the model has been bundled (`TRUE`), unbundled (`FALSE`), or neither (`NULL`).
#' Is `NULL` before training, or when the learner does not have the `"bundle"` property.
#' Read-only; the value is stored in `$state$bundled`.
bundled = function(rhs) {
  assert_ro_binding(rhs)
  self$state$bundled
},

#' @field timings (named `numeric(2)`)\cr
#' Elapsed time in seconds for the steps `"train"` and `"predict"`.
Expand Down
5 changes: 3 additions & 2 deletions R/benchmark.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#' @template param_encapsulate
#' @template param_allow_hotstart
#' @template param_clone
#' @template param_bundle
#'
#' @return [BenchmarkResult].
#'
Expand Down Expand Up @@ -77,7 +78,7 @@
#' ## Get the training set of the 2nd iteration of the featureless learner on penguins
#' rr = bmr$aggregate()[learner_id == "classif.featureless"]$resample_result[[1]]
#' rr$resampling$train_set(2)
benchmark = function(design, store_models = FALSE, store_backends = TRUE, encapsulate = NA_character_, allow_hotstart = FALSE, clone = c("task", "learner", "resampling")) {
benchmark = function(design, store_models = FALSE, store_backends = TRUE, encapsulate = NA_character_, allow_hotstart = FALSE, clone = c("task", "learner", "resampling"), bundle = TRUE) {
assert_subset(clone, c("task", "learner", "resampling"))
assert_data_frame(design, min.rows = 1L)
assert_names(names(design), must.include = c("task", "learner", "resampling"))
Expand Down Expand Up @@ -183,7 +184,7 @@ benchmark = function(design, store_models = FALSE, store_backends = TRUE, encaps

res = future_map(n, workhorse,
task = grid$task, learner = grid$learner, resampling = grid$resampling, iteration = grid$iteration, param_values = grid$param_values, mode = grid$mode,
MoreArgs = list(store_models = store_models, lgr_threshold = lgr_threshold, pb = pb)
MoreArgs = list(store_models = store_models, lgr_threshold = lgr_threshold, pb = pb, bundle = bundle)
)

grid = insert_named(grid, list(
Expand Down
2 changes: 1 addition & 1 deletion R/mlr_reflections.R
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ local({
)

### Learner
tmp = c("featureless", "missings", "weights", "importance", "selected_features", "oob_error", "loglik", "hotstart_forward", "hotstart_backward")
tmp = c("featureless", "missings", "weights", "importance", "selected_features", "oob_error", "loglik", "hotstart_forward", "hotstart_backward", "bundle")
mlr_reflections$learner_properties = list(
classif = c(tmp, "twoclass", "multiclass"),
regr = tmp
Expand Down
5 changes: 3 additions & 2 deletions R/resample.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#' @template param_encapsulate
#' @template param_allow_hotstart
#' @template param_clone
#' @template param_bundle
#' @return [ResampleResult].
#'
#' @template section_predict_sets
Expand Down Expand Up @@ -54,7 +55,7 @@
#' bmr1 = as_benchmark_result(rr)
#' bmr2 = as_benchmark_result(rr_featureless)
#' print(bmr1$combine(bmr2))
resample = function(task, learner, resampling, store_models = FALSE, store_backends = TRUE, encapsulate = NA_character_, allow_hotstart = FALSE, clone = c("task", "learner", "resampling")) {
resample = function(task, learner, resampling, store_models = FALSE, store_backends = TRUE, encapsulate = NA_character_, allow_hotstart = FALSE, clone = c("task", "learner", "resampling"), bundle = TRUE) {
assert_subset(clone, c("task", "learner", "resampling"))
task = assert_task(as_task(task, clone = "task" %in% clone))
learner = assert_learner(as_learner(learner, clone = "learner" %in% clone))
Expand Down Expand Up @@ -111,7 +112,7 @@ resample = function(task, learner, resampling, store_models = FALSE, store_backe
}

res = future_map(n, workhorse, iteration = seq_len(n), learner = grid$learner, mode = grid$mode,
MoreArgs = list(task = task, resampling = resampling, store_models = store_models, lgr_threshold = lgr_threshold, pb = pb)
MoreArgs = list(task = task, resampling = resampling, store_models = store_models, lgr_threshold = lgr_threshold, pb = pb, bundle = bundle)
)

data = data.table(
Expand Down
6 changes: 5 additions & 1 deletion R/worker.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NUL
model = result$result,
log = log,
train_time = train_time,
bundled = if ("bundle" %in% learner$properties) FALSE else NULL,
param_vals = learner$param_set$values,
task_hash = task$hash,
feature_names = task$feature_names,
Expand Down Expand Up @@ -217,7 +218,7 @@ learner_predict = function(learner, task, row_ids = NULL) {
}


workhorse = function(iteration, task, learner, resampling, param_values = NULL, lgr_threshold, store_models = FALSE, pb = NULL, mode = "train", is_sequential = TRUE) {
workhorse = function(iteration, task, learner, resampling, param_values = NULL, lgr_threshold, store_models = FALSE, pb = NULL, mode = "train", is_sequential = TRUE, bundle = TRUE) {
if (!is.null(pb)) {
pb(sprintf("%s|%s|i:%i", task$id, learner$id, iteration))
}
Expand Down Expand Up @@ -266,6 +267,9 @@ workhorse = function(iteration, task, learner, resampling, param_values = NULL,
if (!store_models) {
lg$debug("Erasing stored model for learner '%s'", learner$id)
learner$state$model = NULL
} else if (bundle && "bundle" %in% learner$properties) {
lg$debug("Bundling model for learner '%s'", learner$id)
learner$bundle()
}

list(learner_state = learner$state, prediction = pdatas, param_values = learner$param_set$values, learner_hash = learner_hash)
Expand Down
3 changes: 3 additions & 0 deletions man-roxygen/param_bundle.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#' @param bundle (`logical(1)`)\cr
#' Whether to bundle the learner(s) after the train-predict loop.
#' Default is `TRUE`.
2 changes: 2 additions & 0 deletions man-roxygen/param_learner_properties.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,5 @@
#' * `"importance"`: The learner supports extraction of importance scores, i.e. comes with an `$importance()` extractor function (see section on optional extractors in [Learner]).
#' * `"selected_features"`: The learner supports extraction of the set of selected features, i.e. comes with a `$selected_features()` extractor function (see section on optional extractors in [Learner]).
#' * `"oob_error"`: The learner supports extraction of estimated out of bag error, i.e. comes with a `oob_error()` extractor function (see section on optional extractors in [Learner]).
#' * `"bundle"`: To save learners with this property, you need to call `$bundle()` first.
#' If a learner is in a bundled state, you first need to call `$unbundle()` to use its model, e.g. for prediction.
29 changes: 29 additions & 0 deletions man/Learner.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions man/LearnerClassif.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions man/LearnerRegr.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 6 additions & 1 deletion man/benchmark.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 8fb3818

Please sign in to comment.