Skip to content

Commit

Permalink
fix: as_measures (#1242)
Browse files Browse the repository at this point in the history
  • Loading branch information
sebffischer authored Jan 9, 2025
1 parent e21782a commit 694e21c
Show file tree
Hide file tree
Showing 7 changed files with 24 additions and 52 deletions.
12 changes: 2 additions & 10 deletions R/BenchmarkResult.R
Original file line number Diff line number Diff line change
Expand Up @@ -175,11 +175,7 @@ BenchmarkResult = R6Class("BenchmarkResult",
#'
#' @return [data.table::data.table()].
score = function(measures = NULL, ids = TRUE, conditions = FALSE, predictions = TRUE) {
measures = if (is.null(measures)) {
default_measures(self$task_type)
} else {
assert_measures(as_measures(measures))
}
measures = assert_measures(as_measures(measures, task_type = self$task_type))
assert_flag(ids)
assert_flag(conditions)
assert_flag(predictions)
Expand Down Expand Up @@ -234,11 +230,7 @@ BenchmarkResult = R6Class("BenchmarkResult",
#' @param predict_sets (`character()`)\cr
#' The predict sets.
obs_loss = function(measures = NULL, predict_sets = "test") {
measures = if (is.null(measures)) {
default_measures(self$task_type)
} else {
assert_measures(as_measures(measures))
}
measures = assert_measures(as_measures(measures, task_type = self$task_type))
map_dtr(self$resample_results$resample_result,
function(rr) {
rr$obs_loss(measures, predict_sets)
Expand Down
12 changes: 2 additions & 10 deletions R/Prediction.R
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,7 @@ Prediction = R6Class("Prediction",
#'
#' @return [Prediction].
score = function(measures = NULL, task = NULL, learner = NULL, train_set = NULL) {
measures = if (is.null(measures)) {
default_measures(self$task_type)
} else {
assert_measures(as_measures(measures))
}
measures = assert_measures(as_measures(measures, task_type = self$task_type))
scores = map_dbl(measures, function(m) m$score(prediction = self, task = task, learner = learner, train_set = train_set))
set_names(scores, ids(measures))
},
Expand All @@ -109,11 +105,7 @@ Prediction = R6Class("Prediction",
#' Note that some measures such as RMSE, do have an `$obs_loss`, but they require an
#' additional transformation after aggregation, in this example taking the square-root.
obs_loss = function(measures = NULL) {
measures = if (is.null(measures)) {
default_measures(self$task_type)
} else {
assert_measures(as_measures(measures))
}
measures = assert_measures(as_measures(measures, task_type = self$task_type))
get_obs_loss(as.data.table(self), measures)
},

Expand Down
18 changes: 3 additions & 15 deletions R/ResampleResult.R
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,7 @@ ResampleResult = R6Class("ResampleResult",
#'
#' @return [data.table::data.table()].
score = function(measures = NULL, ids = TRUE, conditions = FALSE, predictions = TRUE) {
measures = if (is.null(measures)) {
default_measures(self$task_type)
} else {
assert_measures(as_measures(measures))
}
measures = assert_measures(as_measures(measures, task_type = self$task_type))
assert_flag(ids)
assert_flag(conditions)
assert_flag(predictions)
Expand Down Expand Up @@ -200,11 +196,7 @@ ResampleResult = R6Class("ResampleResult",
#' @param predict_sets (`character()`)\cr
#' The predict sets.
obs_loss = function(measures = NULL, predict_sets = "test") {
measures = if (is.null(measures)) {
default_measures(self$task_type)
} else {
assert_measures(as_measures(measures))
}
measures = assert_measures(as_measures(measures, task_type = self$task_type))
tab = map_dtr(self$predictions(predict_sets), as.data.table, .idcol = "iteration")
get_obs_loss(tab, measures)
},
Expand All @@ -216,11 +208,7 @@ ResampleResult = R6Class("ResampleResult",
#'
#' @return Named `numeric()`.
aggregate = function(measures = NULL) {
measures = if (is.null(measures)) {
default_measures(self$task_type)
} else {
assert_measures(as_measures(measures))
}
measures = assert_measures(as_measures(measures, task_type = self$task_type))
resample_result_aggregate(self, measures)
},

Expand Down
14 changes: 7 additions & 7 deletions R/as_measure.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#'
#' @return [Measure].
#' @export
as_measure = function(x, ...) { # nolint
as_measure = function(x, task_type = NULL, ...) { # nolint
UseMethod("as_measure")
}

Expand All @@ -23,21 +23,21 @@ as_measure.NULL = function(x, task_type = NULL, ...) { # nolint

#' @export
#' @rdname as_measure
as_measure.Measure = function(x, clone = FALSE, ...) { # nolint
as_measure.Measure = function(x, task_type = NULL, clone = FALSE, ...) { # nolint
assert_empty_ellipsis(...)
if (isTRUE(clone)) x$clone() else x
}

#' @export
#' @rdname as_measure
as_measures = function(x, ...) { # nolint
as_measures = function(x, task_type = NULL, ...) { # nolint
UseMethod("as_measures")
}

#' @export
#' @rdname as_measure
as_measures.default = function(x, ...) { # nolint
list(as_measure(x, ...))
as_measures.default = function(x, task_type = NULL, ...) { # nolint
list(as_measure(x, task_type = task_type, ...))
}

#' @export
Expand All @@ -48,6 +48,6 @@ as_measures.NULL = function(x, task_type = NULL, ...) { # nolint

#' @export
#' @rdname as_measure
as_measures.list = function(x, ...) { # nolint
lapply(x, as_measure, ...)
as_measures.list = function(x, task_type = NULL, ...) { # nolint
lapply(x, as_measure, task_type = NULL, ...)
}
16 changes: 8 additions & 8 deletions man/as_measure.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/mlr_learners_classif.featureless.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/mlr_learners_regr.featureless.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 694e21c

Please sign in to comment.