diff --git a/R/ResampleResult.R b/R/ResampleResult.R index 5097c523d..83b79ada3 100644 --- a/R/ResampleResult.R +++ b/R/ResampleResult.R @@ -96,8 +96,8 @@ ResampleResult = R6Class("ResampleResult", #' If you calculate the performance on this prediction object directly, this is called micro averaging. #' #' @param predict_sets (`character()`)\cr - #' @return [Prediction]. #' Subset of `{"train", "test"}`. + #' @return [Prediction] or empty `list()` if no predictions are available. prediction = function(predict_sets = "test") { private$.data$prediction(private$.view, predict_sets) }, @@ -113,6 +113,7 @@ ResampleResult = R6Class("ResampleResult", #' @param predict_sets (`character()`)\cr #' Subset of `{"train", "test", "internal_valid"}`. #' @return List of [Prediction] objects, one per element in `predict_sets`. + #' Or list of empty `list()`s if no predictions are available. predictions = function(predict_sets = "test") { assert_subset(predict_sets, mlr_reflections$predict_sets, empty.ok = FALSE) private$.data$predictions(private$.view, predict_sets) diff --git a/R/as_prediction.R b/R/as_prediction.R index cab59dc8d..a5ab76a90 100644 --- a/R/as_prediction.R +++ b/R/as_prediction.R @@ -42,7 +42,7 @@ as_predictions = function(x, predict_sets = "test", ...) { #' @rdname as_prediction #' @export as_predictions.list = function(x, predict_sets = "test", ...) { # nolint - result = vector("list", length(x)) + result = replicate(length(x), list()) ii = lengths(x) > 0L result[ii] = map(x[ii], function(li) { assert_list(li, "PredictionData") diff --git a/man/ResampleResult.Rd b/man/ResampleResult.Rd index de18cdca3..8ea3c7d18 100644 --- a/man/ResampleResult.Rd +++ b/man/ResampleResult.Rd @@ -198,13 +198,13 @@ If you calculate the performance on this prediction object directly, this is cal \subsection{Arguments}{ \if{html}{\out{
}} \describe{ -\item{\code{predict_sets}}{(\code{character()})\cr} +\item{\code{predict_sets}}{(\code{character()})\cr +Subset of \verb{\{"train", "test"\}}.} } \if{html}{\out{
}} } \subsection{Returns}{ -\link{Prediction}. -Subset of \verb{\{"train", "test"\}}. +\link{Prediction} or empty \code{list()} if no predictions are available. } } \if{html}{\out{
}} @@ -231,6 +231,7 @@ Subset of \verb{\{"train", "test", "internal_valid"\}}.} } \subsection{Returns}{ List of \link{Prediction} objects, one per element in \code{predict_sets}. +Or list of empty \code{list()}s if no predictions are available. } } \if{html}{\out{
}} diff --git a/tests/testthat/test_resample.R b/tests/testthat/test_resample.R index 1473c7cf6..1efcd09ef 100644 --- a/tests/testthat/test_resample.R +++ b/tests/testthat/test_resample.R @@ -450,3 +450,35 @@ test_that("multiple named measures", { expect_numeric(res[["classif.acc"]]) expect_numeric(res[["classif.ce"]]) }) + +test_that("resample result works with not predicted predict set", { + learner = lrn("classif.debug", predict_sets = "train") + task = tsk("iris") + resampling = rsmp("holdout") + + rr = resample(task, learner, resampling) + + expect_list(rr$prediction(predict_sets = "test"), len = 0) + expect_list(rr$predictions(predict_sets = "test"), len = 1) + expect_list(rr$predictions(predict_sets = "test")[[1L]], len = 0) + + tab = as.data.table(rr) + expect_list(tab$prediction, len = 1) + expect_list(tab$prediction[[1]], len = 0) +}) + +test_that("resample results works with no predicted predict set", { + learner = lrn("classif.debug", predict_sets = NULL) + task = tsk("iris") + resampling = rsmp("holdout") + + rr = resample(task, learner, resampling) + + expect_list(rr$prediction(predict_sets = "test"), len = 0) + expect_list(rr$predictions(predict_sets = "test"), len = 1) + expect_list(rr$predictions(predict_sets = "test")[[1L]], len = 0) + + tab = as.data.table(rr) + expect_list(tab$prediction, len = 1) + expect_list(tab$prediction[[1]], len = 0) +})