diff --git a/R/MeasureRegrRSQ.R b/R/MeasureRegrRSQ.R
index 8019b1d90..5fbd1e2e4 100644
--- a/R/MeasureRegrRSQ.R
+++ b/R/MeasureRegrRSQ.R
@@ -29,7 +29,7 @@
 #' @template seealso_measure
 #' @export
 MeasureRegrRSQ = R6Class("MeasureRSQ",
-  inherit = Measure,
+  inherit = MeasureRegr,
   public = list(
     #' @description
     #' Creates a new instance of this [R6][R6::R6Class] class.
@@ -40,10 +40,10 @@ MeasureRegrRSQ = R6Class("MeasureRSQ",
 
       super$initialize(
         id = "rsq",
-        task_type = "regr",
         properties = if (!private$.pred_set_mean) c("requires_task", "requires_train_set") else character(0),
         predict_type = "response",
         minimize = FALSE,
+        range = c(-Inf, 1),
         man = "mlr3::mlr_measures_regr.rsq"
       )
     }
diff --git a/R/as_prediction.R b/R/as_prediction.R
index a5ab76a90..0b87e1b1b 100644
--- a/R/as_prediction.R
+++ b/R/as_prediction.R
@@ -8,6 +8,8 @@
 #' @return [Prediction].
 #' @export
 as_prediction = function(x, check = FALSE, ...) {
+  if (is.null(x)) return(list())
+
   UseMethod("as_prediction")
 }
 
diff --git a/R/worker.R b/R/worker.R
index 3ee363ebc..f147adbeb 100644
--- a/R/worker.R
+++ b/R/worker.R
@@ -326,6 +326,10 @@ workhorse = function(iteration, task, learner, resampling, param_values = NULL,
     lg$debug("Creating Prediction for predict set '%s'", set)
     learner_predict(learner, task, row_ids)
   }, set = predict_sets, row_ids = pred_data$sets, task = pred_data$tasks)
+
+  if (!length(predict_sets)) {
+    learner$state$predict_time = 0L
+  }
   pdatas = discard(pdatas, is.null)
 
   # set the model slot after prediction so it can be sent back to the main process
diff --git a/man/mlr_measures_regr.rsq.Rd b/man/mlr_measures_regr.rsq.Rd
index 8798a5594..5cfb3295a 100644
--- a/man/mlr_measures_regr.rsq.Rd
+++ b/man/mlr_measures_regr.rsq.Rd
@@ -33,7 +33,7 @@ msr("regr.rsq")
 \itemize{
 \item Task type: \dQuote{regr}
-\item Range: \eqn{(-\infty, \infty)}{(-Inf, Inf)}
+\item Range: \eqn{(-\infty, 1]}{(-Inf, 1]}
 \item Minimize: FALSE
 \item Average: macro
 \item Required Prediction: \dQuote{response}
@@ -76,8 +76,8 @@ Other Measure:
 \code{\link{mlr_measures_selected_features}}
 }
 \concept{Measure}
-\section{Super class}{
-\code{\link[mlr3:Measure]{mlr3::Measure}} -> \code{MeasureRSQ}
+\section{Super classes}{
+\code{\link[mlr3:Measure]{mlr3::Measure}} -> \code{\link[mlr3:MeasureRegr]{mlr3::MeasureRegr}} -> \code{MeasureRSQ}
 }
 \section{Methods}{
 \subsection{Public methods}{
diff --git a/tests/testthat/test_benchmark.R b/tests/testthat/test_benchmark.R
index 6e73c799e..e83c9dcbf 100644
--- a/tests/testthat/test_benchmark.R
+++ b/tests/testthat/test_benchmark.R
@@ -567,3 +567,16 @@ test_that("predictions retrieved with as.data.table and predictions method are e
   predictions = unlist(map(bmr$resample_results$resample_result, function(rr) rr$predictions(predict_sets = "train")), recursive = FALSE)
   expect_equal(tab$prediction, predictions)
 })
+
+test_that("score works with predictions and empty predictions", {
+  learner_1 = lrn("classif.rpart", predict_sets = "train", id = "learner_1")
+  learner_2 = lrn("classif.rpart", predict_sets = "test", id = "learner_2")
+  task = tsk("pima")
+
+  design = benchmark_grid(task, list(learner_1, learner_2), rsmp("holdout"))
+
+  bmr = benchmark(design)
+
+  expect_warning({tab = bmr$score(msr("classif.ce", predict_sets = "test"))}, "Measure")
+  expect_equal(tab$classif.ce[1], NaN)
+})
diff --git a/tests/testthat/test_resample.R b/tests/testthat/test_resample.R
index 35a2a5afa..9896f8b48 100644
--- a/tests/testthat/test_resample.R
+++ b/tests/testthat/test_resample.R
@@ -473,6 +473,9 @@ test_that("resample result works with not predicted predict set", {
   tab = as.data.table(rr)
   expect_list(tab$prediction, len = 1)
   expect_list(tab$prediction[[1]], len = 0)
+
+  expect_warning({tab = rr$score(msr("classif.ce", predict_sets = "test"))}, "Measure")
+  expect_equal(tab$classif.ce, NaN)
 })
 
 test_that("resample results works with no predicted predict set", {
@@ -489,4 +492,14 @@ test_that("resample results works with no predicted predict set", {
   tab = as.data.table(rr)
   expect_list(tab$prediction, len = 1)
   expect_list(tab$prediction[[1]], len = 0)
+
+  expect_warning({tab = rr$score(msr("classif.ce", predict_sets = "test"))}, "Measure")
+  expect_equal(tab$classif.ce, NaN)
+})
+
+test_that("predict_time is 0 if no predict_set is specified", {
+  learner = lrn("classif.featureless", predict_sets = NULL)
+  rr = resample(task, learner, resampling)
+  times = rr$score(msr("time_predict"))$time_predict
+  expect_true(all(times == 0))
 })