Skip to content

Commit

Permalink
feat(measure): allow empty predict_sets (#1100)
Browse files Browse the repository at this point in the history
* feat(measure): allow empty predict_sets

* ...

* cleanup

* cleanup cleanup

* add 'requires_no_prediction' property

* Update R/Measure.R

* tests: predict set

* tests: empty list

* test: dont use default

---------

Co-authored-by: be-marc <[email protected]>
  • Loading branch information
sebffischer and be-marc authored Aug 23, 2024
1 parent 3f69adb commit 27531f6
Show file tree
Hide file tree
Showing 19 changed files with 91 additions and 30 deletions.
5 changes: 4 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
* fix: column info is now checked for compatibility during `Learner$predict` (#943).
* BREAKING CHANGE: the predict time of the learner now stores the cumulative duration for all predict sets (#992).
* feat: `$internal_valid_task` can now be set to an `integer` vector.
* refactor: Deprecated the `$divide()` method.
* feat: Measures can now have an empty `$predict_sets` (#1094).
this is relevant for measures that only extract information from
the model of a learner (such as internal validation scores or AIC / BIC)
* refactor: Deprecated the `$divide()` method
* fix: `Task$cbind()` now works with non-standard primary keys for `data.frames` (#961).
* fix: Triggering of fallback learner now has log-level `"info"` instead of `"debug"` (#972).
* feat: Added new measure `pinballs `.
Expand Down
32 changes: 23 additions & 9 deletions R/Measure.R
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,6 @@ Measure = R6Class("Measure",
#' Required predict type of the [Learner].
predict_type = NULL,

#' @template field_predict_sets
predict_sets = NULL,

#' @field check_prerequisites (`character(1)`)\cr
#' How to proceed if one of the following prerequisites is not met:
#'
Expand Down Expand Up @@ -120,6 +117,7 @@ Measure = R6Class("Measure",
predict_sets = "test", task_properties = character(), packages = character(),
label = NA_character_, man = NA_character_, trafo = NULL) {

self$properties = unique(properties)
self$id = assert_string(id, min.chars = 1L)
self$label = assert_string(label, na.ok = TRUE)
self$task_type = task_type
Expand All @@ -142,9 +140,8 @@ Measure = R6Class("Measure",
assert_subset(task_properties, mlr_reflections$task_properties[[task_type]])
}

self$properties = unique(properties)
self$predict_type = predict_type
self$predict_sets = assert_subset(predict_sets, mlr_reflections$predict_sets, empty.ok = FALSE)
self$predict_sets = predict_sets
self$task_properties = task_properties
self$packages = union("mlr3", assert_character(packages, any.missing = FALSE, min.chars = 1L))
self$man = assert_string(man, na.ok = TRUE)
Expand Down Expand Up @@ -198,8 +195,7 @@ Measure = R6Class("Measure",
#' @return `numeric(1)`.
score = function(prediction, task = NULL, learner = NULL, train_set = NULL) {
assert_measure(self, task = task, learner = learner)
assert_prediction(prediction)

assert_prediction(prediction, null.ok = "requires_no_prediction" %nin% self$properties)

if ("requires_task" %in% self$properties && is.null(task)) {
stopf("Measure '%s' requires a task", self$id)
Expand Down Expand Up @@ -263,6 +259,14 @@ Measure = R6Class("Measure",
),

active = list(
#' @template field_predict_sets
predict_sets = function(rhs) {
if (!missing(rhs)) {
private$.predict_sets = assert_subset(rhs, mlr_reflections$predict_sets, empty.ok = "requires_no_prediction" %in% self$properties)
}
private$.predict_sets
},

#' @template field_hash
hash = function(rhs) {
assert_ro_binding(rhs)
Expand Down Expand Up @@ -303,6 +307,7 @@ Measure = R6Class("Measure",
),

private = list(
.predict_sets = NULL,
.extra_hash = character(),
.average = NULL,
.aggregator = NULL
Expand All @@ -326,8 +331,11 @@ Measure = R6Class("Measure",
#' @return (`numeric()`).
#' @noRd
score_single_measure = function(measure, task, learner, train_set, prediction) {
if (is.null(prediction)) {
return(NaN)
if (!length(measure$predict_sets)) {
score = get_private(measure)$.score(
prediction = NULL, task = task, learner = learner, train_set = train_set
)
return(score)
}

# merge multiple predictions (on different predict sets) to a single one
Expand All @@ -343,6 +351,12 @@ score_single_measure = function(measure, task, learner, train_set, prediction) {
# convert pdata to regular prediction
prediction = as_prediction(prediction, check = FALSE)

if (is.null(prediction) && length(measure$predict_sets)) {
return(NaN)
}



if (!is_scalar_na(measure$predict_type) && measure$predict_type %nin% prediction$predict_types) {
# TODO lgr$debug()
return(NaN)
Expand Down
3 changes: 2 additions & 1 deletion R/MeasureAIC.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ MeasureAIC = R6Class("MeasureAIC",
id = "aic",
task_type = NA_character_,
param_set = param_set,
properties = c("na_score", "requires_learner", "requires_model"),
predict_sets = NULL,
properties = c("na_score", "requires_learner", "requires_model", "requires_no_prediction"),
predict_type = NA_character_,
minimize = TRUE,
label = "Akaike Information Criterion",
Expand Down
3 changes: 2 additions & 1 deletion R/MeasureBIC.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ MeasureBIC = R6Class("MeasureBIC",
super$initialize(
id = "bic",
task_type = NA_character_,
properties = c("na_score", "requires_learner", "requires_model"),
properties = c("na_score", "requires_learner", "requires_model", "requires_no_prediction"),
predict_sets = NULL,
predict_type = NA_character_,
minimize = TRUE,
label = "Bayesian Information Criterion",
Expand Down
3 changes: 2 additions & 1 deletion R/MeasureElapsedTime.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,11 @@ MeasureElapsedTime = R6Class("MeasureElapsedTime",
super$initialize(
id = id,
task_type = NA_character_,
predict_sets = NULL,
predict_type = NA_character_,
range = c(0, Inf),
minimize = TRUE,
properties = "requires_learner",
properties = c("requires_learner", "requires_no_prediction"),
label = "Elapsed Time",
man = "mlr3::mlr_measures_elapsed_time"
)
Expand Down
3 changes: 2 additions & 1 deletion R/MeasureInternalValidScore.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ MeasureInternalValidScore = R6Class("MeasureInternalValidScore",
super$initialize(
id = select %??% "internal_valid_score",
task_type = NA_character_,
properties = c("na_score", "requires_learner"),
properties = c("na_score", "requires_model", "requires_learner", "requires_no_prediction"),
predict_sets = NULL,
predict_type = NA_character_,
range = c(-Inf, Inf),
minimize = assert_flag(minimize, na.ok = TRUE),
Expand Down
3 changes: 2 additions & 1 deletion R/MeasureOOBError.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ MeasureOOBError = R6Class("MeasureOOBError",
super$initialize(
id = "oob_error",
task_type = NA_character_,
properties = c("na_score", "requires_learner"),
properties = c("na_score", "requires_learner", "requires_no_prediction"),
predict_sets = NULL,
predict_type = NA_character_,
range = c(-Inf, Inf),
minimize = TRUE,
Expand Down
3 changes: 2 additions & 1 deletion R/MeasureSelectedFeatures.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ MeasureSelectedFeatures = R6Class("MeasureSelectedFeatures",
id = "selected_features",
param_set = param_set,
task_type = NA_character_,
properties = c("requires_task", "requires_learner", "requires_model"),
properties = c("requires_task", "requires_learner", "requires_model", "requires_no_prediction"),
predict_sets = NULL,
predict_type = NA_character_,
range = c(0, Inf),
minimize = TRUE,
Expand Down
3 changes: 2 additions & 1 deletion R/assertions.R
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,8 @@ assert_resamplings = function(resamplings, instantiated = NULL, .var.name = vnam
#' @export
#' @param prediction ([Prediction]).
#' @rdname mlr_assertions
assert_prediction = function(prediction, .var.name = vname(prediction)) {
assert_prediction = function(prediction, .var.name = vname(prediction), null.ok = FALSE) {
if (null.ok && is.null(prediction)) return(prediction)
assert_class(prediction, "Prediction", .var.name = .var.name)
}

Expand Down
2 changes: 1 addition & 1 deletion R/mlr_reflections.R
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ local({


### Measures
tmp = c("na_score", "requires_task", "requires_learner", "requires_model", "requires_train_set", "primary_iters")
tmp = c("na_score", "requires_task", "requires_learner", "requires_model", "requires_train_set", "primary_iters", "requires_no_prediction")
mlr_reflections$measure_properties = list(
classif = tmp,
regr = tmp
Expand Down
4 changes: 4 additions & 0 deletions inst/testthat/helper_expectations.R
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,10 @@ expect_measure = function(m) {
expect_man_exists(m$man)
testthat::expect_output(print(m), "Measure")

if ("requires_no_prediction" %in% m$properties) {
testthat::expect_true(is.null(m$predict_sets))
}

expect_id(m$id)
checkmate::expect_subset(m$task_type, c(NA_character_, mlr3::mlr_reflections$task_types$type), empty.ok = FALSE)
checkmate::expect_numeric(m$range, len = 2, any.missing = FALSE)
Expand Down
2 changes: 2 additions & 0 deletions man-roxygen/param_measure_properties.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,5 @@
#' * `"na_score"` (the measure is expected to occasionally return `NA` or `NaN`).
#' * `"primary_iters"` (the measure explictly handles resamplings that only use a subset
#' of their iterations for the point estimate).
#' * `"requires_no_prediction"` (No prediction is required; This usually means that the
#' measure extracts some information from the learner state.).
24 changes: 13 additions & 11 deletions man/Measure.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions man/MeasureClassif.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions man/MeasureRegr.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions man/MeasureSimilarity.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/mlr_assertions.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 15 additions & 0 deletions tests/testthat/test_Measure.R
Original file line number Diff line number Diff line change
Expand Up @@ -173,3 +173,18 @@ test_that("primary iters are respected", {
x2 = rr3$aggregate(jaccard)
expect_equal(x1, x2)
})

test_that("no predict_sets required (#1094)", {
m = msr("internal_valid_score")
expect_equal(m$predict_sets, NULL)
rr = resample(tsk("iris"), lrn("classif.debug", validate = 0.3, predict_sets = NULL), rsmp("holdout"))
expect_double(rr$aggregate(m))
expect_warning(rr$aggregate(msr("classif.ce")), "needs predict sets")
})

test_that("checks on predict_sets", {
m = msr("classif.ce")
expect_error({m$predict_sets = NULL}, "Must be a subset")
expect_error({m$predict_sets = "imaginary"}, "Must be a subset")
})

8 changes: 8 additions & 0 deletions tests/testthat/test_resample.R
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,14 @@ test_that("multiple named measures", {
expect_numeric(res[["classif.ce"]])
})

test_that("empty predictions", {
rr = resample(tsk("iris"), lrn("classif.debug", validate = 0.3, predict_sets = NULL), rsmp("holdout"))
preds = rr$predictions()
expect_equal(preds, list(list()))
pred = rr$prediction()
expect_equal(pred, list())
})

test_that("resample result works with not predicted predict set", {
learner = lrn("classif.debug", predict_sets = "train")
task = tsk("iris")
Expand Down

0 comments on commit 27531f6

Please sign in to comment.