Skip to content

Commit

Permalink
...
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc committed Nov 27, 2024
1 parent a2ad28f commit 01e6a4b
Show file tree
Hide file tree
Showing 11 changed files with 141 additions and 57 deletions.
5 changes: 5 additions & 0 deletions R/ContextEvaluation.R
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ ContextEvaluation = R6Class("ContextEvaluation",
#' The data is available on stage `on_evaluation_end`.
pdatas = NULL,

#' @field data_extra (list())\cr
#' Data saved in the [ResampleResult] or [BenchmarkResult].
#' Use this field to save results.
data_extra = NULL,

#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
#'
Expand Down
10 changes: 8 additions & 2 deletions R/ResampleResult.R
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,12 @@ ResampleResult = R6Class("ResampleResult",
private$.data$learners(private$.view)$learner
},

#' @field data_extra (list())\cr
#' Additional data stored in the [ResampleResult].
data_extra = function() {
private$.data$data_extra(private$.view)
},

#' @field warnings ([data.table::data.table()])\cr
#' A table with all warning messages.
#' Column names are `"iteration"` and `"msg"`.
Expand Down Expand Up @@ -370,10 +376,10 @@ ResampleResult = R6Class("ResampleResult",
)

#' @export
as.data.table.ResampleResult = function(x, ..., predict_sets = "test") { # nolint
as.data.table.ResampleResult = function(x, ..., predict_sets = "test", data_extra = FALSE) { # nolint
private = get_private(x)
tab = private$.data$as_data_table(view = private$.view, predict_sets = predict_sets)
tab[, c("task", "learner", "resampling", "iteration", "prediction"), with = FALSE]
tab[, c("task", "learner", "resampling", "iteration", "prediction", if (data_extra) "data_extra"), with = FALSE]
}

# #' @export
Expand Down
18 changes: 14 additions & 4 deletions R/ResultData.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@
#' print(ResultData$new()$data)
ResultData = R6Class("ResultData",
public = list(

#' @field data (`list()`)\cr
#' List of [data.table::data.table()], arranged in a star schema.
#' Do not operate directly on this list.
data = NULL,


#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
#' An alternative construction method is provided by [as_result_data()].
Expand All @@ -40,12 +40,12 @@ ResultData = R6Class("ResultData",
self$data = star_init()
} else {
assert_names(names(data),
permutation.of = c("task", "learner", "learner_state", "resampling", "iteration", "param_values", "prediction", "uhash", "learner_hash"))
permutation.of = c("task", "learner", "learner_state", "resampling", "iteration", "param_values", "prediction", "uhash", "learner_hash", "data_extra"))

if (nrow(data) == 0L) {
self$data = star_init()
} else {
setcolorder(data, c("uhash", "iteration", "learner_state", "prediction", "task", "learner", "resampling", "param_values", "learner_hash"))
setcolorder(data, c("uhash", "iteration", "learner_state", "prediction", "data_extra", "task", "learner", "resampling", "param_values", "learner_hash"))
uhashes = data.table(uhash = unique(data$uhash))
setkeyv(data, c("uhash", "iteration"))

Expand Down Expand Up @@ -189,6 +189,15 @@ ResultData = R6Class("ResultData",
do.call(c, self$predictions(view = view, predict_sets = predict_sets))
},

#' @description
#' Returns additional data stored.
#'
#' @return `list()`.
data_extra = function(view = NULL) {
.__ii__ = private$get_view_index(view)
self$data$fact[.__ii__, "data_extra", with = FALSE][[1L]]
},

#' @description
#' Combines multiple [ResultData] objects, modifying `self` in-place.
#'
Expand Down Expand Up @@ -315,7 +324,7 @@ ResultData = R6Class("ResultData",
}

cns = c("uhash", "task", "task_hash", "learner", "learner_hash", "learner_param_vals", "resampling",
"resampling_hash", "iteration", "prediction")
"resampling_hash", "iteration", "prediction", "data_extra")
merge(self$data$uhashes, tab[, cns, with = FALSE], by = "uhash", sort = FALSE)
},

Expand Down Expand Up @@ -375,6 +384,7 @@ star_init = function() {
iteration = integer(),
learner_state = list(),
prediction = list(),
data_extra = list(),

learner_hash = character(),
task_hash = character(),
Expand Down
3 changes: 2 additions & 1 deletion R/resample.R
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,8 @@ resample = function(
prediction = map(res, "prediction"),
uhash = UUIDgenerate(),
param_values = map(res, "param_values"),
learner_hash = map_chr(res, "learner_hash")
learner_hash = map_chr(res, "learner_hash"),
data_extra = map(res, "data_extra")
)

result_data = ResultData$new(data, store_backends = store_backends)
Expand Down
7 changes: 6 additions & 1 deletion R/worker.R
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,12 @@ workhorse = function(

learner_state = set_class(learner$state, c("learner_state", "list"))

list(learner_state = learner_state, prediction = pdatas, param_values = learner$param_set$values, learner_hash = learner_hash)
list(
learner_state = learner_state,
prediction = pdatas,
param_values = learner$param_set$values,
learner_hash = learner_hash,
data_extra = ctx$data_extra)
}

# creates the tasks and row ids for the selected predict sets
Expand Down
4 changes: 4 additions & 0 deletions man/ContextEvaluation.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions man/ResampleResult.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

22 changes: 22 additions & 0 deletions man/ResultData.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions pkgdown/_pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,10 @@ reference:
- mlr_sugar
- mlr_reflections
- set_threads
- title: Callbacks
contents:
- CallbackEvaluation
- ContextEvaluation
- title: Internal Objects and Functions
contents:
- marshaling
Expand Down
120 changes: 71 additions & 49 deletions tests/testthat/test_CallbackEvaluation.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,19 @@ test_that("on_evaluation_begin works", {
callback = callback_evaluation("test",

on_evaluation_begin = function(callback, context) {
expect_task(context$task)
expect_learner(context$learner)
expect_resampling(context$resampling)
expect_null(context$param_values)
expect_null(context$sets)
expect_null(context$test_set)
expect_null(context$predict_sets)
expect_null(context$pdatas)
# expect_* does not work
assert_task(context$task)
assert_learner(context$learner)
assert_resampling(context$resampling)
assert_null(context$param_values)
assert_null(context$sets)
assert_null(context$test_set)
assert_null(context$predict_sets)
assert_null(context$pdatas)
}
)

resample(task, learner, resampling, callbacks = callback)

expect_resample_result(resample(task, learner, resampling, callbacks = callback))
})

test_that("on_evaluation_before_train works", {
Expand All @@ -29,21 +29,21 @@ test_that("on_evaluation_before_train works", {
callback = callback_evaluation("test",

on_evaluation_before_train = function(callback, context) {
expect_task(context$task)
expect_learner(context$learner)
expect_resampling(context$resampling)
expect_null(context$param_values)
expect_list(context$sets, len = 2)
expect_equal(names(context$sets), c("train", "test"))
expect_integer(context$sets$train)
expect_integer(context$sets$test)
expect_null(context$test_set)
expect_null(context$predict_sets)
expect_null(context$pdatas)
assert_task(context$task)
assert_learner(context$learner)
assert_resampling(context$resampling)
assert_null(context$param_values)
assert_list(context$sets, len = 2)
assert_names(names(context$sets), identical.to = c("train", "test"))
assert_integer(context$sets$train)
assert_integer(context$sets$test)
assert_null(context$test_set)
assert_null(context$predict_sets)
assert_null(context$pdatas)
}
)

resample(task, learner, resampling, callbacks = callback)
expect_resample_result(resample(task, learner, resampling, callbacks = callback))

})

Expand All @@ -55,22 +55,22 @@ test_that("on_evaluation_before_predict works", {
callback = callback_evaluation("test",

on_evaluation_before_predict = function(callback, context) {
expect_task(context$task)
expect_learner(context$learner)
expect_resampling(context$resampling)
expect_null(context$param_values)
expect_list(context$sets, len = 2)
expect_equal(names(context$sets), c("train", "test"))
expect_integer(context$sets$train)
expect_integer(context$sets$test)
expect_class(context$learner$model, "rpart")
expect_null(context$test_set)
expect_equal(context$predict_sets, "test")
expect_null(context$pdatas)
assert_task(context$task)
assert_learner(context$learner)
assert_resampling(context$resampling)
assert_null(context$param_values)
assert_list(context$sets, len = 2)
assert_names(names(context$sets), identical.to = c("train", "test"))
assert_integer(context$sets$train)
assert_integer(context$sets$test)
assert_class(context$learner$model, "rpart")
assert_null(context$test_set)
assert_true(context$predict_sets == "test")
assert_null(context$pdatas)
}
)

resample(task, learner, resampling, callbacks = callback)
expect_resample_result(resample(task, learner, resampling, callbacks = callback))
})

test_that("on_evaluation_end works", {
Expand All @@ -81,22 +81,22 @@ test_that("on_evaluation_end works", {
callback = callback_evaluation("test",

on_evaluation_end = function(callback, context) {
expect_task(context$task)
expect_learner(context$learner)
expect_resampling(context$resampling)
expect_null(context$param_values)
expect_list(context$sets, len = 2)
expect_equal(names(context$sets), c("train", "test"))
expect_integer(context$sets$train)
expect_integer(context$sets$test)
expect_class(context$learner$model, "rpart")
expect_null(context$test_set)
expect_equal(context$predict_sets, "test")
expect_class(context$pdatas$test, "PredictionData")
assert_task(context$task)
assert_learner(context$learner)
assert_resampling(context$resampling)
assert_null(context$param_values)
assert_list(context$sets, len = 2)
assert_names(names(context$sets), identical.to = c("train", "test"))
assert_integer(context$sets$train)
assert_integer(context$sets$test)
assert_class(context$learner$model, "rpart")
assert_null(context$test_set)
assert_true(context$predict_sets == "test")
assert_class(context$pdatas$test, "PredictionData")
}
)

resample(task, learner, resampling, callbacks = callback)
expect_resample_result(resample(task, learner, resampling, callbacks = callback))
})

test_that("writing to learner$state works", {
Expand All @@ -112,5 +112,27 @@ test_that("writing to learner$state works", {
)

rr = resample(task, learner, resampling, callbacks = callback)
expect_equal(rr$learners[[1]]$state$test, 1)

walk(rr$learners, function(learner) {
expect_equal(learner$state$test, 1)
})
})

test_that("writing to data_extra works", {
task = tsk("pima")
learner = lrn("classif.rpart")
resampling = rsmp("cv", folds = 3)

callback = callback_evaluation("test",

on_evaluation_end = function(callback, context) {
context$data_extra$test = 1
}
)

rr = resample(task, learner, resampling, callbacks = callback)

walk(rr$data_extra, function(x) {
expect_equal(x$test, 1)
})
})
2 changes: 2 additions & 0 deletions tests/testthat/test_mlr_callbacks.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,6 @@ test_that("score_measure works", {
walk(rr$learners, function(learner) {
expect_number(learner$state$selected_features)
})

expect_names(names(as.data.table(rr, data_extra = TRUE)), must.include = "data_extra")
})

0 comments on commit 01e6a4b

Please sign in to comment.