Skip to content

Commit

Permalink
refactor: returned predictions with different sets (#1113)
Browse files Browse the repository at this point in the history
* refactor: returned predictions with different sets

* docs: empty list
  • Loading branch information
be-marc authored Aug 23, 2024
1 parent 1e0342e commit 3f69adb
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 5 deletions.
3 changes: 2 additions & 1 deletion R/ResampleResult.R
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ ResampleResult = R6Class("ResampleResult",
#' If you calculate the performance on this prediction object directly, this is called micro averaging.
#'
#' @param predict_sets (`character()`)\cr
#' @return [Prediction].
#' Subset of `{"train", "test"}`.
#' @return [Prediction] or empty `list()` if no predictions are available.
prediction = function(predict_sets = "test") {
private$.data$prediction(private$.view, predict_sets)
},
Expand All @@ -113,6 +113,7 @@ ResampleResult = R6Class("ResampleResult",
#' @param predict_sets (`character()`)\cr
#' Subset of `{"train", "test", "internal_valid"}`.
#' @return List of [Prediction] objects, one per element in `predict_sets`.
#' Or list of empty `list()`s if no predictions are available.
predictions = function(predict_sets = "test") {
assert_subset(predict_sets, mlr_reflections$predict_sets, empty.ok = FALSE)
private$.data$predictions(private$.view, predict_sets)
Expand Down
2 changes: 1 addition & 1 deletion R/as_prediction.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ as_predictions = function(x, predict_sets = "test", ...) {
#' @rdname as_prediction
#' @export
as_predictions.list = function(x, predict_sets = "test", ...) { # nolint
result = vector("list", length(x))
result = replicate(length(x), list())
ii = lengths(x) > 0L
result[ii] = map(x[ii], function(li) {
assert_list(li, "PredictionData")
Expand Down
7 changes: 4 additions & 3 deletions man/ResampleResult.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

32 changes: 32 additions & 0 deletions tests/testthat/test_resample.R
Original file line number Diff line number Diff line change
Expand Up @@ -450,3 +450,35 @@ test_that("multiple named measures", {
expect_numeric(res[["classif.acc"]])
expect_numeric(res[["classif.ce"]])
})

test_that("resample result works with not predicted predict set", {
learner = lrn("classif.debug", predict_sets = "train")
task = tsk("iris")
resampling = rsmp("holdout")

rr = resample(task, learner, resampling)

expect_list(rr$prediction(predict_sets = "test"), len = 0)
expect_list(rr$predictions(predict_sets = "test"), len = 1)
expect_list(rr$predictions(predict_sets = "test")[[1L]], len = 0)

tab = as.data.table(rr)
expect_list(tab$prediction, len = 1)
expect_list(tab$prediction[[1]], len = 0)
})

test_that("resample results works with no predicted predict set", {
learner = lrn("classif.debug", predict_sets = NULL)
task = tsk("iris")
resampling = rsmp("holdout")

rr = resample(task, learner, resampling)

expect_list(rr$prediction(predict_sets = "test"), len = 0)
expect_list(rr$predictions(predict_sets = "test"), len = 1)
expect_list(rr$predictions(predict_sets = "test")[[1L]], len = 0)

tab = as.data.table(rr)
expect_list(tab$prediction, len = 1)
expect_list(tab$prediction[[1]], len = 0)
})

0 comments on commit 3f69adb

Please sign in to comment.