Skip to content

Commit

Permalink
fix(bundle): always (un)bundle for callr encapsulation
Browse files Browse the repository at this point in the history
  • Loading branch information
sebffischer committed Jan 25, 2024
1 parent 8fb3818 commit 851e64c
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 2 deletions.
15 changes: 14 additions & 1 deletion R/worker.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NUL
stopf("Learner '%s' on task '%s' returned NULL during internal %s()", learner$id, task$id, mode)
}

if ("bundle" %in% learner$properties && identical(learner$encapsulate[["train"]], "callr")) {
model = get_private(learner)$.bundle(model)
}

model
}

Expand Down Expand Up @@ -68,11 +72,16 @@ learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NUL
log = append_log(NULL, "train", result$log$class, result$log$msg)
train_time = result$elapsed

# callr encapsualtion causes dangling pointers between train and predict
bundled = if ("bundle" %in% learner$properties) {
identical(learner$encapsulate[["train"]], "callr")
}

learner$state = insert_named(learner$state, list(
model = result$result,
log = log,
train_time = train_time,
bundled = if ("bundle" %in% learner$properties) FALSE else NULL,
bundled = bundled,
param_vals = learner$param_set$values,
task_hash = task$hash,
feature_names = task$feature_names,
Expand Down Expand Up @@ -102,6 +111,10 @@ learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NUL
fb$id, learner = fb$clone())
}

if (isTRUE(bundled)) {
learner$unbundle()
}

learner
}

Expand Down
26 changes: 25 additions & 1 deletion tests/testthat/test_Learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,12 @@ test_that("bundling", {
task = tsk("mtcars")
LearnerRegrTest = R6Class("LearnerRegrTest",
inherit = LearnerRegrFeatureless,
public = list(
initialize = function() {
super$initialize()
self$id = "regr.test"
}
),
private = list(
.bundle = function(model) {
private$.tmp_model = model
Expand All @@ -337,9 +343,11 @@ test_that("bundling", {
.unbundle = function(model) {
model = private$.tmp_model
private$.tmp_model = NULL
private$.counter = private$.counter + 1
model
},
.tmp_model = NULL
.tmp_model = NULL,
.counter = 0
)
)

Expand Down Expand Up @@ -384,4 +392,20 @@ test_that("bundling", {
# calling (un)bundle does nothing
expect_true(is.null(lrn_rpart$bundle()$bundled))
expect_true(is.null(lrn_rpart$unbundle()$bundled))

# callr encapsulation causes bundling
learner2 = LearnerRegrTest$new()
learner2$encapsulate = c(train = "callr")
learner2$train(task)
expect_false(learner2$bundled)

learner3 = LearnerRegrTest$new()
learner3$encapsulate = c(train = "try")
learner3$train(task)
expect_false(learner3$bundled)

# for callr, we had to unbundle
expect_equal(get_private(learner2)$.counter, 1)
# for other encapsulation, no need to unbundle becausse it was not bundled
expect_equal(get_private(learner3)$.counter, 0)
})

0 comments on commit 851e64c

Please sign in to comment.