From 3f69adbcc475a128162dcbfe9b0c855a55b0184c Mon Sep 17 00:00:00 2001 From: Marc Becker <33069354+be-marc@users.noreply.github.com> Date: Fri, 23 Aug 2024 16:39:51 +0200 Subject: [PATCH] refactor: returned predictions with different sets (#1113) * refactor: returned predictions with different sets * docs: empty list --- R/ResampleResult.R | 3 ++- R/as_prediction.R | 2 +- man/ResampleResult.Rd | 7 ++++--- tests/testthat/test_resample.R | 32 ++++++++++++++++++++++++++++++++ 4 files changed, 39 insertions(+), 5 deletions(-) diff --git a/R/ResampleResult.R b/R/ResampleResult.R index 5097c523d..83b79ada3 100644 --- a/R/ResampleResult.R +++ b/R/ResampleResult.R @@ -96,8 +96,8 @@ ResampleResult = R6Class("ResampleResult", #' If you calculate the performance on this prediction object directly, this is called micro averaging. #' #' @param predict_sets (`character()`)\cr - #' @return [Prediction]. #' Subset of `{"train", "test"}`. + #' @return [Prediction] or empty `list()` if no predictions are available. prediction = function(predict_sets = "test") { private$.data$prediction(private$.view, predict_sets) }, @@ -113,6 +113,7 @@ ResampleResult = R6Class("ResampleResult", #' @param predict_sets (`character()`)\cr #' Subset of `{"train", "test", "internal_valid"}`. #' @return List of [Prediction] objects, one per element in `predict_sets`. + #' Or list of empty `list()`s if no predictions are available. predictions = function(predict_sets = "test") { assert_subset(predict_sets, mlr_reflections$predict_sets, empty.ok = FALSE) private$.data$predictions(private$.view, predict_sets) diff --git a/R/as_prediction.R b/R/as_prediction.R index cab59dc8d..a5ab76a90 100644 --- a/R/as_prediction.R +++ b/R/as_prediction.R @@ -42,7 +42,7 @@ as_predictions = function(x, predict_sets = "test", ...) { #' @rdname as_prediction #' @export as_predictions.list = function(x, predict_sets = "test", ...) { # nolint - result = vector("list", length(x)) + result = replicate(length(x), list()) ii = lengths(x) > 0L result[ii] = map(x[ii], function(li) { assert_list(li, "PredictionData") diff --git a/man/ResampleResult.Rd b/man/ResampleResult.Rd index de18cdca3..8ea3c7d18 100644 --- a/man/ResampleResult.Rd +++ b/man/ResampleResult.Rd @@ -198,13 +198,13 @@ If you calculate the performance on this prediction object directly, this is cal \subsection{Arguments}{ \if{html}{\out{