diff --git a/NEWS.md b/NEWS.md index 582f24f99..8da711c30 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,6 @@ # mlr3 (development version) +* feat: Throw warning when prediction and measure type do not match. * fix: The `mlr_reflections` were broken when an extension package was not loaded on the workers. Extension packages must now register themselves in the `mlr_reflections$loaded_packages` field. diff --git a/R/Measure.R b/R/Measure.R index 649dd077c..a398c318e 100644 --- a/R/Measure.R +++ b/R/Measure.R @@ -194,7 +194,7 @@ Measure = R6Class("Measure", #' #' @return `numeric(1)`. score = function(prediction, task = NULL, learner = NULL, train_set = NULL) { - assert_measure(self, task = task, learner = learner) + assert_measure(self, task = task, learner = learner, prediction = prediction) assert_prediction(prediction, null.ok = "requires_no_prediction" %nin% self$properties) if ("requires_task" %in% self$properties && is.null(task)) { diff --git a/R/assertions.R b/R/assertions.R index cc5be49dc..af9dabf65 100644 --- a/R/assertions.R +++ b/R/assertions.R @@ -194,8 +194,9 @@ assert_predictable = function(task, learner) { #' @export #' @param measure ([Measure]). +#' @param prediction ([Prediction]). #' @rdname mlr_assertions -assert_measure = function(measure, task = NULL, learner = NULL, .var.name = vname(measure)) { +assert_measure = function(measure, task = NULL, learner = NULL, prediction = NULL, .var.name = vname(measure)) { assert_class(measure, "Measure", .var.name = .var.name) if (!is.null(task)) { @@ -236,6 +237,13 @@ assert_measure = function(measure, task = NULL, learner = NULL, .var.name = vnam } } + if (!is.null(prediction)) { + # same as above but works without learner e.g. measure$score(prediction) + if (measure$check_prerequisites != "ignore" && measure$predict_type %nin% prediction$predict_types) { + warningf("Measure '%s' is missing predict type '%s' of prediction", measure$id, measure$predict_type) + } + } + invisible(measure) } diff --git a/man/mlr_assertions.Rd b/man/mlr_assertions.Rd index 008bdbf10..e88f7625d 100644 --- a/man/mlr_assertions.Rd +++ b/man/mlr_assertions.Rd @@ -62,6 +62,7 @@ assert_measure( measure, task = NULL, learner = NULL, + prediction = NULL, .var.name = vname(measure) ) @@ -115,14 +116,14 @@ Set of required task properties.} \item{measure}{(\link{Measure}).} +\item{prediction}{(\link{Prediction}).} + \item{measures}{(list of \link{Measure}).} \item{resampling}{(\link{Resampling}).} \item{resamplings}{(list of \link{Resampling}).} -\item{prediction}{(\link{Prediction}).} - \item{rr}{(\link{ResampleResult}).} \item{bmr}{(\link{BenchmarkResult}).} diff --git a/tests/testthat/test_Measure.R b/tests/testthat/test_Measure.R index 786c94b84..bb343dd64 100644 --- a/tests/testthat/test_Measure.R +++ b/tests/testthat/test_Measure.R @@ -188,3 +188,13 @@ test_that("checks on predict_sets", { expect_error({m$predict_sets = "imaginary"}, "Must be a subset") }) +test_that("measure and prediction type is checked", { + learner = lrn("classif.rpart") + task = tsk("pima") + learner$train(task) + pred = learner$predict(task) + + measure = msr("classif.logloss") + expect_warning(measure$score(pred), "is missing predict type") +}) +