From 694e21c4cf97c22e2f2ddc7ae16441917ba32f78 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Thu, 9 Jan 2025 09:51:29 +0100 Subject: [PATCH] fix: as_measures (#1242) --- R/BenchmarkResult.R | 12 ++---------- R/Prediction.R | 12 ++---------- R/ResampleResult.R | 18 +++--------------- R/as_measure.R | 14 +++++++------- man/as_measure.Rd | 16 ++++++++-------- man/mlr_learners_classif.featureless.Rd | 2 +- man/mlr_learners_regr.featureless.Rd | 2 +- 7 files changed, 24 insertions(+), 52 deletions(-) diff --git a/R/BenchmarkResult.R b/R/BenchmarkResult.R index d3abea9bc..1ebd980a0 100644 --- a/R/BenchmarkResult.R +++ b/R/BenchmarkResult.R @@ -175,11 +175,7 @@ BenchmarkResult = R6Class("BenchmarkResult", #' #' @return [data.table::data.table()]. score = function(measures = NULL, ids = TRUE, conditions = FALSE, predictions = TRUE) { - measures = if (is.null(measures)) { - default_measures(self$task_type) - } else { - assert_measures(as_measures(measures)) - } + measures = assert_measures(as_measures(measures, task_type = self$task_type)) assert_flag(ids) assert_flag(conditions) assert_flag(predictions) @@ -234,11 +230,7 @@ BenchmarkResult = R6Class("BenchmarkResult", #' @param predict_sets (`character()`)\cr #' The predict sets. obs_loss = function(measures = NULL, predict_sets = "test") { - measures = if (is.null(measures)) { - default_measures(self$task_type) - } else { - assert_measures(as_measures(measures)) - } + measures = assert_measures(as_measures(measures, task_type = self$task_type)) map_dtr(self$resample_results$resample_result, function(rr) { rr$obs_loss(measures, predict_sets) diff --git a/R/Prediction.R b/R/Prediction.R index ad7a0c8ce..c2000e5ec 100644 --- a/R/Prediction.R +++ b/R/Prediction.R @@ -90,11 +90,7 @@ Prediction = R6Class("Prediction", #' #' @return [Prediction]. score = function(measures = NULL, task = NULL, learner = NULL, train_set = NULL) { - measures = if (is.null(measures)) { - default_measures(self$task_type) - } else { - assert_measures(as_measures(measures)) - } + measures = assert_measures(as_measures(measures, task_type = self$task_type)) scores = map_dbl(measures, function(m) m$score(prediction = self, task = task, learner = learner, train_set = train_set)) set_names(scores, ids(measures)) }, @@ -109,11 +105,7 @@ Prediction = R6Class("Prediction", #' Note that some measures such as RMSE, do have an `$obs_loss`, but they require an #' additional transformation after aggregation, in this example taking the square-root. obs_loss = function(measures = NULL) { - measures = if (is.null(measures)) { - default_measures(self$task_type) - } else { - assert_measures(as_measures(measures)) - } + measures = assert_measures(as_measures(measures, task_type = self$task_type)) get_obs_loss(as.data.table(self), measures) }, diff --git a/R/ResampleResult.R b/R/ResampleResult.R index 895696d54..53da71cdc 100644 --- a/R/ResampleResult.R +++ b/R/ResampleResult.R @@ -143,11 +143,7 @@ ResampleResult = R6Class("ResampleResult", #' #' @return [data.table::data.table()]. score = function(measures = NULL, ids = TRUE, conditions = FALSE, predictions = TRUE) { - measures = if (is.null(measures)) { - default_measures(self$task_type) - } else { - assert_measures(as_measures(measures)) - } + measures = assert_measures(as_measures(measures, task_type = self$task_type)) assert_flag(ids) assert_flag(conditions) assert_flag(predictions) @@ -200,11 +196,7 @@ ResampleResult = R6Class("ResampleResult", #' @param predict_sets (`character()`)\cr #' The predict sets. obs_loss = function(measures = NULL, predict_sets = "test") { - measures = if (is.null(measures)) { - default_measures(self$task_type) - } else { - assert_measures(as_measures(measures)) - } + measures = assert_measures(as_measures(measures, task_type = self$task_type)) tab = map_dtr(self$predictions(predict_sets), as.data.table, .idcol = "iteration") get_obs_loss(tab, measures) }, @@ -216,11 +208,7 @@ ResampleResult = R6Class("ResampleResult", #' #' @return Named `numeric()`. aggregate = function(measures = NULL) { - measures = if (is.null(measures)) { - default_measures(self$task_type) - } else { - assert_measures(as_measures(measures)) - } + measures = assert_measures(as_measures(measures, task_type = self$task_type)) resample_result_aggregate(self, measures) }, diff --git a/R/as_measure.R b/R/as_measure.R index 97f2798ef..bc63c32ab 100644 --- a/R/as_measure.R +++ b/R/as_measure.R @@ -10,7 +10,7 @@ #' #' @return [Measure]. #' @export -as_measure = function(x, ...) { # nolint +as_measure = function(x, task_type = NULL, ...) { # nolint UseMethod("as_measure") } @@ -23,21 +23,21 @@ as_measure.NULL = function(x, task_type = NULL, ...) { # nolint #' @export #' @rdname as_measure -as_measure.Measure = function(x, clone = FALSE, ...) { # nolint +as_measure.Measure = function(x, task_type = NULL, clone = FALSE, ...) { # nolint assert_empty_ellipsis(...) if (isTRUE(clone)) x$clone() else x } #' @export #' @rdname as_measure -as_measures = function(x, ...) { # nolint +as_measures = function(x, task_type = NULL, ...) { # nolint UseMethod("as_measures") } #' @export #' @rdname as_measure -as_measures.default = function(x, ...) { # nolint - list(as_measure(x, ...)) +as_measures.default = function(x, task_type = NULL, ...) { # nolint + list(as_measure(x, task_type = task_type, ...)) } #' @export @@ -48,6 +48,6 @@ as_measures.NULL = function(x, task_type = NULL, ...) { # nolint #' @export #' @rdname as_measure -as_measures.list = function(x, ...) { # nolint - lapply(x, as_measure, ...) +as_measures.list = function(x, task_type = NULL, ...) { # nolint + lapply(x, as_measure, task_type = NULL, ...) } diff --git a/man/as_measure.Rd b/man/as_measure.Rd index 79d52a565..ad17bc4a8 100644 --- a/man/as_measure.Rd +++ b/man/as_measure.Rd @@ -10,31 +10,31 @@ \alias{as_measures.list} \title{Convert to a Measure} \usage{ -as_measure(x, ...) +as_measure(x, task_type = NULL, ...) \method{as_measure}{`NULL`}(x, task_type = NULL, ...) -\method{as_measure}{Measure}(x, clone = FALSE, ...) +\method{as_measure}{Measure}(x, task_type = NULL, clone = FALSE, ...) -as_measures(x, ...) +as_measures(x, task_type = NULL, ...) -\method{as_measures}{default}(x, ...) +\method{as_measures}{default}(x, task_type = NULL, ...) \method{as_measures}{`NULL`}(x, task_type = NULL, ...) -\method{as_measures}{list}(x, ...) +\method{as_measures}{list}(x, task_type = NULL, ...) } \arguments{ \item{x}{(any)\cr Object to convert.} -\item{...}{(any)\cr -Additional arguments.} - \item{task_type}{(\code{character(1)})\cr Used if \code{x} is \code{NULL} to construct a default measure for the respective task type. The default measures are stored in \code{\link[=mlr_reflections]{mlr_reflections$default_measures}}.} +\item{...}{(any)\cr +Additional arguments.} + \item{clone}{(\code{logical(1)})\cr If \code{TRUE}, ensures that the returned object is not the same as the input \code{x}.} } diff --git a/man/mlr_learners_classif.featureless.Rd b/man/mlr_learners_classif.featureless.Rd index cf9724c9a..e701c7369 100644 --- a/man/mlr_learners_classif.featureless.Rd +++ b/man/mlr_learners_classif.featureless.Rd @@ -36,7 +36,7 @@ lrn("classif.featureless") \itemize{ \item Task type: \dQuote{classif} \item Predict Types: \dQuote{response}, \dQuote{prob} -\item Feature Types: \dQuote{logical}, \dQuote{integer}, \dQuote{numeric}, \dQuote{character}, \dQuote{factor}, \dQuote{ordered}, \dQuote{POSIXct} +\item Feature Types: \dQuote{logical}, \dQuote{integer}, \dQuote{numeric}, \dQuote{character}, \dQuote{factor}, \dQuote{ordered}, \dQuote{POSIXct}, \dQuote{Date} \item Required Packages: \CRANpkg{mlr3} } } diff --git a/man/mlr_learners_regr.featureless.Rd b/man/mlr_learners_regr.featureless.Rd index 4cd8de85e..64a69f4cb 100644 --- a/man/mlr_learners_regr.featureless.Rd +++ b/man/mlr_learners_regr.featureless.Rd @@ -25,7 +25,7 @@ lrn("regr.featureless") \itemize{ \item Task type: \dQuote{regr} \item Predict Types: \dQuote{response}, \dQuote{se}, \dQuote{quantiles} -\item Feature Types: \dQuote{logical}, \dQuote{integer}, \dQuote{numeric}, \dQuote{character}, \dQuote{factor}, \dQuote{ordered}, \dQuote{POSIXct} +\item Feature Types: \dQuote{logical}, \dQuote{integer}, \dQuote{numeric}, \dQuote{character}, \dQuote{factor}, \dQuote{ordered}, \dQuote{POSIXct}, \dQuote{Date} \item Required Packages: \CRANpkg{mlr3}, 'stats' } }