From 27531f6036372053584541109fa65c03df8c0a82 Mon Sep 17 00:00:00 2001
From: Sebastian Fischer
Date: Fri, 23 Aug 2024 17:22:28 +0200
Subject: [PATCH] feat(measure): allow empty predict_sets (#1100)

* feat(measure): allow empty predict_sets

* ...

* cleanup

* cleanup cleanup

* add 'requires_no_prediction' property

* Update R/Measure.R

* tests: predict set

* tests: empty list

* test: dont use default

---------

Co-authored-by: be-marc
---
 NEWS.md                                |  5 +++-
 R/Measure.R                            | 32 ++++++++++++++++++---------
 R/MeasureAIC.R                         |  3 ++-
 R/MeasureBIC.R                         |  3 ++-
 R/MeasureElapsedTime.R                 |  3 ++-
 R/MeasureInternalValidScore.R          |  3 ++-
 R/MeasureOOBError.R                    |  3 ++-
 R/MeasureSelectedFeatures.R            |  3 ++-
 R/assertions.R                         |  3 ++-
 R/mlr_reflections.R                    |  2 +-
 inst/testthat/helper_expectations.R    |  4 ++++
 man-roxygen/param_measure_properties.R |  2 ++
 man/Measure.Rd                         | 24 ++++++++++---------
 man/MeasureClassif.Rd                  |  2 ++
 man/MeasureRegr.Rd                     |  2 ++
 man/MeasureSimilarity.Rd               |  2 ++
 man/mlr_assertions.Rd                  |  2 +-
 tests/testthat/test_Measure.R          | 15 ++++++++++++
 tests/testthat/test_resample.R         |  8 +++++++
 19 files changed, 91 insertions(+), 30 deletions(-)

diff --git a/NEWS.md b/NEWS.md
index 32d310e78..797229fdf 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -9,7 +9,10 @@
 * fix: column info is now checked for compatibility during `Learner$predict` (#943).
 * BREAKING CHANGE: the predict time of the learner now stores the cumulative duration for all predict sets (#992).
 * feat: `$internal_valid_task` can now be set to an `integer` vector.
-* refactor: Deprecated the `$divide()` method.
+* feat: Measures can now have an empty `$predict_sets` (#1094).
+  This is relevant for measures that only extract information from
+  the model of a learner (such as internal validation scores or AIC / BIC).
+* refactor: Deprecated the `$divide()` method.
 * fix: `Task$cbind()` now works with non-standard primary keys for `data.frames` (#961).
 * fix: Triggering of fallback learner now has log-level `"info"` instead of `"debug"` (#972).
 * feat: Added new measure `pinballs `.
diff --git a/R/Measure.R b/R/Measure.R
index 83ea6ae75..649dd077c 100644
--- a/R/Measure.R
+++ b/R/Measure.R
@@ -76,9 +76,6 @@ Measure = R6Class("Measure",
     #' Required predict type of the [Learner].
     predict_type = NULL,
 
-    #' @template field_predict_sets
-    predict_sets = NULL,
-
     #' @field check_prerequisites (`character(1)`)\cr
     #' How to proceed if one of the following prerequisites is not met:
     #'
@@ -120,6 +117,7 @@ Measure = R6Class("Measure",
       predict_sets = "test", task_properties = character(), packages = character(),
       label = NA_character_, man = NA_character_, trafo = NULL) {
 
+      self$properties = unique(properties)
       self$id = assert_string(id, min.chars = 1L)
       self$label = assert_string(label, na.ok = TRUE)
       self$task_type = task_type
@@ -142,9 +140,8 @@ Measure = R6Class("Measure",
         assert_subset(task_properties, mlr_reflections$task_properties[[task_type]])
       }
 
-      self$properties = unique(properties)
       self$predict_type = predict_type
-      self$predict_sets = assert_subset(predict_sets, mlr_reflections$predict_sets, empty.ok = FALSE)
+      self$predict_sets = predict_sets
       self$task_properties = task_properties
       self$packages = union("mlr3", assert_character(packages, any.missing = FALSE, min.chars = 1L))
       self$man = assert_string(man, na.ok = TRUE)
@@ -198,8 +195,7 @@ Measure = R6Class("Measure",
     #' @return `numeric(1)`.
     score = function(prediction, task = NULL, learner = NULL, train_set = NULL) {
       assert_measure(self, task = task, learner = learner)
-      assert_prediction(prediction)
-
+      assert_prediction(prediction, null.ok = "requires_no_prediction" %in% self$properties)
 
       if ("requires_task" %in% self$properties && is.null(task)) {
         stopf("Measure '%s' requires a task", self$id)
@@ -263,6 +259,14 @@ Measure = R6Class("Measure",
   ),
 
   active = list(
+    #' @template field_predict_sets
+    predict_sets = function(rhs) {
+      if (!missing(rhs)) {
+        private$.predict_sets = assert_subset(rhs, mlr_reflections$predict_sets, empty.ok = "requires_no_prediction" %in% self$properties)
+      }
+      private$.predict_sets
+    },
+
     #' @template field_hash
     hash = function(rhs) {
       assert_ro_binding(rhs)
@@ -303,6 +307,7 @@ Measure = R6Class("Measure",
   ),
 
   private = list(
+    .predict_sets = NULL,
     .extra_hash = character(),
     .average = NULL,
     .aggregator = NULL
@@ -326,8 +331,11 @@
 #' @return (`numeric()`).
 #' @noRd
 score_single_measure = function(measure, task, learner, train_set, prediction) {
-  if (is.null(prediction)) {
-    return(NaN)
+  if (!length(measure$predict_sets)) {
+    score = get_private(measure)$.score(
+      prediction = NULL, task = task, learner = learner, train_set = train_set
+    )
+    return(score)
   }
 
   # merge multiple predictions (on different predict sets) to a single one
@@ -343,6 +351,12 @@ score_single_measure = function(measure, task, learner, train_set, prediction) {
   # convert pdata to regular prediction
   prediction = as_prediction(prediction, check = FALSE)
 
+  if (is.null(prediction) && length(measure$predict_sets)) {
+    return(NaN)
+  }
+
+
+
   if (!is_scalar_na(measure$predict_type) && measure$predict_type %nin% prediction$predict_types) {
     # TODO lgr$debug()
     return(NaN)
diff --git a/R/MeasureAIC.R b/R/MeasureAIC.R
index 2809bfa64..5a0203d1d 100644
--- a/R/MeasureAIC.R
+++ b/R/MeasureAIC.R
@@ -27,7 +27,8 @@ MeasureAIC = R6Class("MeasureAIC",
         id = "aic",
         task_type = NA_character_,
         param_set = param_set,
-        properties = c("na_score", "requires_learner", "requires_model"),
+        predict_sets = NULL,
+        properties = c("na_score", "requires_learner", "requires_model", "requires_no_prediction"),
         predict_type = NA_character_,
         minimize = TRUE,
         label = "Akaike Information Criterion",
diff --git a/R/MeasureBIC.R b/R/MeasureBIC.R
index eb9480b1c..94d67d515 100644
--- a/R/MeasureBIC.R
+++ b/R/MeasureBIC.R
@@ -25,7 +25,8 @@ MeasureBIC = R6Class("MeasureBIC",
       super$initialize(
         id = "bic",
         task_type = NA_character_,
-        properties = c("na_score", "requires_learner", "requires_model"),
+        properties = c("na_score", "requires_learner", "requires_model", "requires_no_prediction"),
+        predict_sets = NULL,
         predict_type = NA_character_,
         minimize = TRUE,
         label = "Bayesian Information Criterion",
diff --git a/R/MeasureElapsedTime.R b/R/MeasureElapsedTime.R
index 2fc583078..b2c18417d 100644
--- a/R/MeasureElapsedTime.R
+++ b/R/MeasureElapsedTime.R
@@ -41,10 +41,11 @@ MeasureElapsedTime = R6Class("MeasureElapsedTime",
       super$initialize(
         id = id,
         task_type = NA_character_,
+        predict_sets = NULL,
         predict_type = NA_character_,
         range = c(0, Inf),
         minimize = TRUE,
-        properties = "requires_learner",
+        properties = c("requires_learner", "requires_no_prediction"),
         label = "Elapsed Time",
         man = "mlr3::mlr_measures_elapsed_time"
       )
diff --git a/R/MeasureInternalValidScore.R b/R/MeasureInternalValidScore.R
index fff6f7a85..2d60bf3e0 100644
--- a/R/MeasureInternalValidScore.R
+++ b/R/MeasureInternalValidScore.R
@@ -33,7 +33,8 @@ MeasureInternalValidScore = R6Class("MeasureInternalValidScore",
       super$initialize(
         id = select %??% "internal_valid_score",
         task_type = NA_character_,
-        properties = c("na_score", "requires_learner"),
+        properties = c("na_score", "requires_model", "requires_learner", "requires_no_prediction"),
+        predict_sets = NULL,
         predict_type = NA_character_,
         range = c(-Inf, Inf),
         minimize = assert_flag(minimize, na.ok = TRUE),
diff --git a/R/MeasureOOBError.R b/R/MeasureOOBError.R
index bc6b1c652..ab4827087 100644
--- a/R/MeasureOOBError.R
+++ b/R/MeasureOOBError.R
@@ -22,7 +22,8 @@ MeasureOOBError = R6Class("MeasureOOBError",
       super$initialize(
         id = "oob_error",
         task_type = NA_character_,
-        properties = c("na_score", "requires_learner"),
+        properties = c("na_score", "requires_learner", "requires_no_prediction"),
+        predict_sets = NULL,
         predict_type = NA_character_,
         range = c(-Inf, Inf),
         minimize = TRUE,
diff --git a/R/MeasureSelectedFeatures.R b/R/MeasureSelectedFeatures.R
index 1b2891bf7..db4748608 100644
--- a/R/MeasureSelectedFeatures.R
+++ b/R/MeasureSelectedFeatures.R
@@ -37,7 +37,8 @@ MeasureSelectedFeatures = R6Class("MeasureSelectedFeatures",
         id = "selected_features",
         param_set = param_set,
         task_type = NA_character_,
-        properties = c("requires_task", "requires_learner", "requires_model"),
+        properties = c("requires_task", "requires_learner", "requires_model", "requires_no_prediction"),
+        predict_sets = NULL,
         predict_type = NA_character_,
         range = c(0, Inf),
         minimize = TRUE,
diff --git a/R/assertions.R b/R/assertions.R
index 576431b79..e20a95184 100644
--- a/R/assertions.R
+++ b/R/assertions.R
@@ -281,7 +281,8 @@ assert_resamplings = function(resamplings, instantiated = NULL, .var.name = vnam
 #' @export
 #' @param prediction ([Prediction]).
 #' @rdname mlr_assertions
-assert_prediction = function(prediction, .var.name = vname(prediction)) {
+assert_prediction = function(prediction, .var.name = vname(prediction), null.ok = FALSE) {
+  if (null.ok && is.null(prediction)) return(prediction)
   assert_class(prediction, "Prediction", .var.name = .var.name)
 }
 
diff --git a/R/mlr_reflections.R b/R/mlr_reflections.R
index 782bfd060..f103871b4 100644
--- a/R/mlr_reflections.R
+++ b/R/mlr_reflections.R
@@ -140,7 +140,7 @@ local({
 
 
   ### Measures
-  tmp = c("na_score", "requires_task", "requires_learner", "requires_model", "requires_train_set", "primary_iters")
+  tmp = c("na_score", "requires_task", "requires_learner", "requires_model", "requires_train_set", "primary_iters", "requires_no_prediction")
   mlr_reflections$measure_properties = list(
     classif = tmp,
     regr = tmp
diff --git a/inst/testthat/helper_expectations.R b/inst/testthat/helper_expectations.R
index e848d156b..91bdee35e 100644
--- a/inst/testthat/helper_expectations.R
+++ b/inst/testthat/helper_expectations.R
@@ -517,6 +517,10 @@ expect_measure = function(m) {
   expect_man_exists(m$man)
   testthat::expect_output(print(m), "Measure")
 
+  if ("requires_no_prediction" %in% m$properties) {
+    testthat::expect_true(is.null(m$predict_sets))
+  }
+
   expect_id(m$id)
   checkmate::expect_subset(m$task_type, c(NA_character_, mlr3::mlr_reflections$task_types$type), empty.ok = FALSE)
   checkmate::expect_numeric(m$range, len = 2, any.missing = FALSE)
diff --git a/man-roxygen/param_measure_properties.R b/man-roxygen/param_measure_properties.R
index 139adf0a5..103439956 100644
--- a/man-roxygen/param_measure_properties.R
+++ b/man-roxygen/param_measure_properties.R
@@ -10,3 +10,5 @@
 #' * `"na_score"` (the measure is expected to occasionally return `NA` or `NaN`).
 #' * `"primary_iters"` (the measure explictly handles resamplings that only use a subset
 #' of their iterations for the point estimate).
+#' * `"requires_no_prediction"` (the measure does not require a prediction; it usually
+#' extracts information from the learner state).
diff --git a/man/Measure.Rd b/man/Measure.Rd
index b8da38475..f52bc580d 100644
--- a/man/Measure.Rd
+++ b/man/Measure.Rd
@@ -93,17 +93,6 @@ observation-wise losses (e.g. \code{sqrt} for RMSE)
 \item{\code{predict_type}}{(\code{character(1)})\cr
 Required predict type of the \link{Learner}.}
 
-\item{\code{predict_sets}}{(\code{character()})\cr
-During \code{\link[=resample]{resample()}}/\code{\link[=benchmark]{benchmark()}}, a \link{Learner} can predict on multiple sets.
-Per default, a learner only predicts observations in the test set (\code{predict_sets == "test"}).
-To change this behavior, set \code{predict_sets} to a non-empty subset of \verb{\{"train", "test", "internal_valid"\}}.
-The \code{"train"} predict set contains the train ids from the resampling. This means that if a learner does validation and
-sets \verb{$validate} to a ratio (creating the validation data from the training data), the train predictions
-will include the predictions for the validation data.
-Each set yields a separate \link{Prediction} object.
-Those can be combined via getters in \link{ResampleResult}/\link{BenchmarkResult}, or \link{Measure}s can be configured
-to operate on specific subsets of the calculated prediction sets.}
-
 \item{\code{check_prerequisites}}{(\code{character(1)})\cr
 How to proceed if one of the following prerequisites is not met:
 \itemize{
@@ -139,6 +128,17 @@ Defaults to \code{NA}, but can be set by child classes.}
 \section{Active bindings}{
 \if{html}{\out{<div class="r6-active-bindings">}}
 \describe{
+\item{\code{predict_sets}}{(\code{character()})\cr
+During \code{\link[=resample]{resample()}}/\code{\link[=benchmark]{benchmark()}}, a \link{Learner} can predict on multiple sets.
+Per default, a learner only predicts observations in the test set (\code{predict_sets == "test"}).
+To change this behavior, set \code{predict_sets} to a non-empty subset of \verb{\{"train", "test", "internal_valid"\}}.
+The \code{"train"} predict set contains the train ids from the resampling. This means that if a learner does validation and
+sets \verb{$validate} to a ratio (creating the validation data from the training data), the train predictions
+will include the predictions for the validation data.
+Each set yields a separate \link{Prediction} object.
+Those can be combined via getters in \link{ResampleResult}/\link{BenchmarkResult}, or \link{Measure}s can be configured
+to operate on specific subsets of the calculated prediction sets.}
+
 \item{\code{hash}}{(\code{character(1)})\cr
 Hash (unique identifier) for this object.}
 
@@ -260,6 +260,8 @@ model),
 \item \code{"na_score"} (the measure is expected to occasionally return \code{NA} or \code{NaN}).
 \item \code{"primary_iters"} (the measure explictly handles resamplings that only use a subset
 of their iterations for the point estimate).
+\item \code{"requires_no_prediction"} (the measure does not require a prediction; it usually
+extracts information from the learner state).
 }}
 
 \item{\code{predict_type}}{(\code{character(1)})\cr
diff --git a/man/MeasureClassif.Rd b/man/MeasureClassif.Rd
index b2dfdc477..18aaa7fdd 100644
--- a/man/MeasureClassif.Rd
+++ b/man/MeasureClassif.Rd
@@ -139,6 +139,8 @@ model),
 \item \code{"na_score"} (the measure is expected to occasionally return \code{NA} or \code{NaN}).
 \item \code{"primary_iters"} (the measure explictly handles resamplings that only use a subset
 of their iterations for the point estimate).
+\item \code{"requires_no_prediction"} (the measure does not require a prediction; it usually
+extracts information from the learner state).
 }}
 
 \item{\code{predict_type}}{(\code{character(1)})\cr
diff --git a/man/MeasureRegr.Rd b/man/MeasureRegr.Rd
index 40f02a2c8..8c432bba7 100644
--- a/man/MeasureRegr.Rd
+++ b/man/MeasureRegr.Rd
@@ -139,6 +139,8 @@ model),
 \item \code{"na_score"} (the measure is expected to occasionally return \code{NA} or \code{NaN}).
 \item \code{"primary_iters"} (the measure explictly handles resamplings that only use a subset
 of their iterations for the point estimate).
+\item \code{"requires_no_prediction"} (the measure does not require a prediction; it usually
+extracts information from the learner state).
 }}
 
 \item{\code{predict_type}}{(\code{character(1)})\cr
diff --git a/man/MeasureSimilarity.Rd b/man/MeasureSimilarity.Rd
index 6325f6399..df2b7d40e 100644
--- a/man/MeasureSimilarity.Rd
+++ b/man/MeasureSimilarity.Rd
@@ -153,6 +153,8 @@ model),
 \item \code{"na_score"} (the measure is expected to occasionally return \code{NA} or \code{NaN}).
 \item \code{"primary_iters"} (the measure explictly handles resamplings that only use a subset
 of their iterations for the point estimate).
+\item \code{"requires_no_prediction"} (the measure does not require a prediction; it usually
+extracts information from the learner state).
 }}
 
 \item{\code{predict_type}}{(\code{character(1)})\cr
diff --git a/man/mlr_assertions.Rd b/man/mlr_assertions.Rd
index 085d8afae..008bdbf10 100644
--- a/man/mlr_assertions.Rd
+++ b/man/mlr_assertions.Rd
@@ -84,7 +84,7 @@ assert_resamplings(
   .var.name = vname(resamplings)
 )
 
-assert_prediction(prediction, .var.name = vname(prediction))
+assert_prediction(prediction, .var.name = vname(prediction), null.ok = FALSE)
 
 assert_resample_result(rr, .var.name = vname(rr))
 
diff --git a/tests/testthat/test_Measure.R b/tests/testthat/test_Measure.R
index c745e6c54..786c94b84 100644
--- a/tests/testthat/test_Measure.R
+++ b/tests/testthat/test_Measure.R
@@ -173,3 +173,18 @@ test_that("primary iters are respected", {
   x2 = rr3$aggregate(jaccard)
   expect_equal(x1, x2)
 })
+
+test_that("no predict_sets required (#1094)", {
+  m = msr("internal_valid_score")
+  expect_equal(m$predict_sets, NULL)
+  rr = resample(tsk("iris"), lrn("classif.debug", validate = 0.3, predict_sets = NULL), rsmp("holdout"))
+  expect_double(rr$aggregate(m))
+  expect_warning(rr$aggregate(msr("classif.ce")), "needs predict sets")
+})
+
+test_that("checks on predict_sets", {
+  m = msr("classif.ce")
+  expect_error({m$predict_sets = NULL}, "Must be a subset")
+  expect_error({m$predict_sets = "imaginary"}, "Must be a subset")
+})
+
diff --git a/tests/testthat/test_resample.R b/tests/testthat/test_resample.R
index 1efcd09ef..35a2a5afa 100644
--- a/tests/testthat/test_resample.R
+++ b/tests/testthat/test_resample.R
@@ -451,6 +451,14 @@ test_that("multiple named measures", {
   expect_numeric(res[["classif.ce"]])
 })
 
+test_that("empty predictions", {
+  rr = resample(tsk("iris"), lrn("classif.debug", validate = 0.3, predict_sets = NULL), rsmp("holdout"))
+  preds = rr$predictions()
+  expect_equal(preds, list(list()))
+  pred = rr$prediction()
+  expect_equal(pred, list())
+})
+
 test_that("resample result works with not predicted predict set", {
   learner = lrn("classif.debug", predict_sets = "train")
   task = tsk("iris")
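
For context, a minimal usage sketch of the behavior this patch enables, mirroring the tests added above (assumes mlr3 with this change applied; the learner and measure ids are the ones used in those tests):

library(mlr3)

# A measure with the "requires_no_prediction" property has an empty $predict_sets
# and is scored from the stored learner state rather than from predictions.
measure = msr("internal_valid_score")
measure$predict_sets  # NULL

# Resample without keeping any predictions; aggregation still works because the
# measure reads the internal validation score from the fitted model.
rr = resample(
  tsk("iris"),
  lrn("classif.debug", validate = 0.3, predict_sets = NULL),
  rsmp("holdout")
)
rr$aggregate(measure)

# Prediction-based measures still reject an empty predict set:
ce = msr("classif.ce")
# ce$predict_sets = NULL  # errors with "Must be a subset"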