diff --git a/R/worker.R b/R/worker.R index 0baa85b87..61065b2e9 100644 --- a/R/worker.R +++ b/R/worker.R @@ -18,6 +18,10 @@ learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NUL stopf("Learner '%s' on task '%s' returned NULL during internal %s()", learner$id, task$id, mode) } + if ("bundle" %in% learner$properties && identical(learner$encapsulate[["train"]], "callr")) { + model = get_private(learner)$.bundle(model) + } + model } @@ -68,11 +72,16 @@ learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NUL log = append_log(NULL, "train", result$log$class, result$log$msg) train_time = result$elapsed + # callr encapsualtion causes dangling pointers between train and predict + bundled = if ("bundle" %in% learner$properties) { + identical(learner$encapsulate[["train"]], "callr") + } + learner$state = insert_named(learner$state, list( model = result$result, log = log, train_time = train_time, - bundled = if ("bundle" %in% learner$properties) FALSE else NULL, + bundled = bundled, param_vals = learner$param_set$values, task_hash = task$hash, feature_names = task$feature_names, @@ -102,6 +111,10 @@ learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NUL fb$id, learner = fb$clone()) } + if (isTRUE(bundled)) { + learner$unbundle() + } + learner } diff --git a/tests/testthat/test_Learner.R b/tests/testthat/test_Learner.R index 7d1c14435..dd818884e 100644 --- a/tests/testthat/test_Learner.R +++ b/tests/testthat/test_Learner.R @@ -329,6 +329,12 @@ test_that("bundling", { task = tsk("mtcars") LearnerRegrTest = R6Class("LearnerRegrTest", inherit = LearnerRegrFeatureless, + public = list( + initialize = function() { + super$initialize() + self$id = "regr.test" + } + ), private = list( .bundle = function(model) { private$.tmp_model = model @@ -337,9 +343,11 @@ test_that("bundling", { .unbundle = function(model) { model = private$.tmp_model private$.tmp_model = NULL + private$.counter = private$.counter + 1 model }, - .tmp_model = NULL + .tmp_model = NULL, + .counter = 0 ) ) @@ -384,4 +392,20 @@ test_that("bundling", { # calling (un)bundle does nothing expect_true(is.null(lrn_rpart$bundle()$bundled)) expect_true(is.null(lrn_rpart$unbundle()$bundled)) + + # callr encapsulation causes bundling + learner2 = LearnerRegrTest$new() + learner2$encapsulate = c(train = "callr") + learner2$train(task) + expect_false(learner2$bundled) + + learner3 = LearnerRegrTest$new() + learner3$encapsulate = c(train = "try") + learner3$train(task) + expect_false(learner3$bundled) + + # for callr, we had to unbundle + expect_equal(get_private(learner2)$.counter, 1) + # for other encapsulation, no need to unbundle becausse it was not bundled + expect_equal(get_private(learner3)$.counter, 0) })