Skip to content

Commit

Permalink
feat: improve docs for converters and better checks (#1231)
Browse files Browse the repository at this point in the history
* feat: improve docs for converters and better checks

* ...

* fix failing tests
  • Loading branch information
sebffischer authored Jan 6, 2025
1 parent 5c24ba8 commit e21782a
Show file tree
Hide file tree
Showing 22 changed files with 133 additions and 15 deletions.
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ S3method(as_resampling,Resampling)
S3method(as_resamplings,default)
S3method(as_resamplings,list)
S3method(as_task,Task)
S3method(as_task,default)
S3method(as_task_classif,DataBackend)
S3method(as_task_classif,Matrix)
S3method(as_task_classif,TaskClassif)
Expand Down Expand Up @@ -200,6 +201,7 @@ export(as_tasks)
export(as_tasks_unsupervised)
export(assert_backend)
export(assert_benchmark_result)
export(assert_empty_ellipsis)
export(assert_learnable)
export(assert_learner)
export(assert_learners)
Expand Down
18 changes: 15 additions & 3 deletions R/BenchmarkResult.R
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,11 @@ BenchmarkResult = R6Class("BenchmarkResult",
#'
#' @return [data.table::data.table()].
score = function(measures = NULL, ids = TRUE, conditions = FALSE, predictions = TRUE) {
measures = as_measures(measures, task_type = self$task_type)
measures = if (is.null(measures)) {
default_measures(self$task_type)
} else {
assert_measures(as_measures(measures))
}
assert_flag(ids)
assert_flag(conditions)
assert_flag(predictions)
Expand Down Expand Up @@ -230,7 +234,11 @@ BenchmarkResult = R6Class("BenchmarkResult",
#' @param predict_sets (`character()`)\cr
#' The predict sets.
obs_loss = function(measures = NULL, predict_sets = "test") {
measures = as_measures(measures, task_type = private$.data$task_type)
measures = if (is.null(measures)) {
default_measures(self$task_type)
} else {
assert_measures(as_measures(measures))
}
map_dtr(self$resample_results$resample_result,
function(rr) {
rr$obs_loss(measures, predict_sets)
Expand Down Expand Up @@ -276,7 +284,11 @@ BenchmarkResult = R6Class("BenchmarkResult",
#'
#' @return [data.table::data.table()].
aggregate = function(measures = NULL, ids = TRUE, uhashes = FALSE, params = FALSE, conditions = FALSE) {
measures = assert_measures(as_measures(measures, task_type = self$task_type))
measures = if (is.null(measures)) {
default_measures(self$task_type)
} else {
assert_measures(as_measures(measures))
}
assert_flag(ids)
assert_flag(uhashes)
assert_flag(params)
Expand Down
12 changes: 10 additions & 2 deletions R/Prediction.R
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,11 @@ Prediction = R6Class("Prediction",
#'
#' @return [Prediction].
score = function(measures = NULL, task = NULL, learner = NULL, train_set = NULL) {
measures = as_measures(measures, task_type = self$task_type)
measures = if (is.null(measures)) {
default_measures(self$task_type)
} else {
assert_measures(as_measures(measures))
}
scores = map_dbl(measures, function(m) m$score(prediction = self, task = task, learner = learner, train_set = train_set))
set_names(scores, ids(measures))
},
Expand All @@ -105,7 +109,11 @@ Prediction = R6Class("Prediction",
#' Note that some measures such as RMSE, do have an `$obs_loss`, but they require an
#' additional transformation after aggregation, in this example taking the square-root.
obs_loss = function(measures = NULL) {
measures = as_measures(measures, task_type = self$task_type)
measures = if (is.null(measures)) {
default_measures(self$task_type)
} else {
assert_measures(as_measures(measures))
}
get_obs_loss(as.data.table(self), measures)
},

Expand Down
18 changes: 15 additions & 3 deletions R/ResampleResult.R
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,11 @@ ResampleResult = R6Class("ResampleResult",
#'
#' @return [data.table::data.table()].
score = function(measures = NULL, ids = TRUE, conditions = FALSE, predictions = TRUE) {
measures = as_measures(measures, task_type = private$.data$task_type)
measures = if (is.null(measures)) {
default_measures(self$task_type)
} else {
assert_measures(as_measures(measures))
}
assert_flag(ids)
assert_flag(conditions)
assert_flag(predictions)
Expand Down Expand Up @@ -196,7 +200,11 @@ ResampleResult = R6Class("ResampleResult",
#' @param predict_sets (`character()`)\cr
#' The predict sets.
obs_loss = function(measures = NULL, predict_sets = "test") {
measures = as_measures(measures, task_type = self$task_type)
measures = if (is.null(measures)) {
default_measures(self$task_type)
} else {
assert_measures(as_measures(measures))
}
tab = map_dtr(self$predictions(predict_sets), as.data.table, .idcol = "iteration")
get_obs_loss(tab, measures)
},
Expand All @@ -208,7 +216,11 @@ ResampleResult = R6Class("ResampleResult",
#'
#' @return Named `numeric()`.
aggregate = function(measures = NULL) {
measures = as_measures(measures, task_type = private$.data$task_type)
measures = if (is.null(measures)) {
default_measures(self$task_type)
} else {
assert_measures(as_measures(measures))
}
resample_result_aggregate(self, measures)
},

Expand Down
1 change: 1 addition & 0 deletions R/as_learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ as_learner = function(x, ...) { # nolint
#' Whether to discard the state.
#' @rdname as_learner
as_learner.Learner = function(x, clone = FALSE, discard_state = FALSE, ...) { # nolint
assert_empty_ellipsis(...)
if (isTRUE(clone) && isTRUE(discard_state)) {
clone_without(x, "state")
} else if (isTRUE(clone)) {
Expand Down
2 changes: 2 additions & 0 deletions R/as_measure.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@ as_measure = function(x, ...) { # nolint
#' @export
#' @rdname as_measure
as_measure.NULL = function(x, task_type = NULL, ...) { # nolint
assert_empty_ellipsis(...)
default_measures(task_type)[[1L]]
}

#' @export
#' @rdname as_measure
as_measure.Measure = function(x, clone = FALSE, ...) { # nolint
assert_empty_ellipsis(...)
if (isTRUE(clone)) x$clone() else x
}

Expand Down
3 changes: 2 additions & 1 deletion R/as_resampling.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#'
#' @description
#' Convert object to a [Resampling] or a list of [Resampling].
#'
#' This method e.g. allows to convert an [`mlr3oml::OMLTask`] to a [`Resampling`].
#' @inheritParams as_task
#' @export
as_resampling = function(x, ...) { # nolint
Expand All @@ -12,6 +12,7 @@ as_resampling = function(x, ...) { # nolint
#' @export
#' @rdname as_resampling
as_resampling.Resampling = function(x, clone = FALSE, ...) { # nolint
assert_empty_ellipsis(...)
if (isTRUE(clone)) x$clone() else x
}

Expand Down
8 changes: 8 additions & 0 deletions R/as_task.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#'
#' @description
#' Convert object to a [Task] or a list of [Task].
#' This method e.g. allows to convert an [`mlr3oml::OMLTask`] to a [`Task`] and additionally supports cloning.
#' In order to construct a [Task] from a `data.frame`, use task-specific converters such as [`as_task_classif()`] or [`as_task_regr()`].
#'
#' @param x (any)\cr
#' Object to convert.
Expand All @@ -12,11 +14,17 @@ as_task = function(x, ...) {
UseMethod("as_task")
}

#' @export
as_task.default = function(x, ...) {
stopf("No method for class '%s'. To create a task from a `data.frame`, use dedicated converters such as `as_task_classif()` or `as_task_regr()`.", class(x)[1L])
}

#' @rdname as_task
#' @param clone (`logical(1)`)\cr
#' If `TRUE`, ensures that the returned object is not the same as the input `x`.
#' @export
as_task.Task = function(x, clone = FALSE, ...) { # nolint
assert_empty_ellipsis(...)
if (isTRUE(clone)) x$clone(deep = TRUE) else x
}

Expand Down
2 changes: 1 addition & 1 deletion R/as_task_classif.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#' Convert object to a [TaskClassif].
#' This is a S3 generic. mlr3 ships with methods for the following objects:
#'
#' 1. [TaskClassif]: ensure the identity
#' 1. [TaskClassif]: returns the object as-is, possibly cloned.
#' 2. [`formula`], [data.frame()], [matrix()], [Matrix::Matrix()] and [DataBackend]: provides an alternative to the constructor of [TaskClassif].
#' 3. [TaskRegr]: Calls [convert_task()].
#'
Expand Down
2 changes: 1 addition & 1 deletion R/as_task_regr.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#' Convert object to a [TaskRegr].
#' This is a S3 generic. mlr3 ships with methods for the following objects:
#'
#' 1. [TaskRegr]: ensure the identity
#' 1. [TaskRegr]: returns the object as-is, possibly cloned.
#' 2. [`formula`], [data.frame()], [matrix()], [Matrix::Matrix()] and [DataBackend]: provides an alternative to the constructor of [TaskRegr].
#' 3. [TaskClassif]: Calls [convert_task()].
#'
Expand Down
26 changes: 26 additions & 0 deletions R/assertions.R
Original file line number Diff line number Diff line change
Expand Up @@ -405,3 +405,29 @@ assert_param_values = function(x, n_learners = NULL, .var.name = vname(x)) {
}
invisible(x)
}

#' @title Assert Empty Ellipsis
#' @description
#' Assert that `...` arguments are empty.
#' Use this function in S3-methods to ensure that misspelling of arguments does not go unnoticed.
#' @param ... (any)\cr
#' Ellipsis arguments to check.
#' @keywords internal
#' @return `NULL`
#' @export
assert_empty_ellipsis = function(...) {
if (...length()) {
names = ...names()
if (is.null(names)) {
stopf("Received %i unnamed argument that was not used.", ...length())
} else {
names2 = names[names != ""]
if (length(names2) == length(names)) {
stopf("Received the following named arguments that were unused: %s.", paste0(names2, collapse = ", "))
} else {
stopf("Received unused arguments: %i unnamed, as well as named arguments %s.", length(names) - length(names2), paste0(names2, collapse = ", "))
}
}
}
NULL
}
1 change: 1 addition & 0 deletions man/as_resampling.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions man/as_task.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/as_task_classif.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/as_task_regr.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 20 additions & 0 deletions man/assert_empty_ellipsis.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 1 addition & 2 deletions tests/testthat/test_Learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -551,8 +551,7 @@ test_that("learner state contains internal valid task information", {
test_that("validation task with 0 observations", {
learner = lrn("classif.debug", validate = "predefined")
task = tsk("iris")
task$internal_valid_task = integer(0)
expect_error({learner$train(task)}, "has 0 observations")
expect_warning({task$internal_valid_task = integer(0)})
})

test_that("column info is compared during predict", {
Expand Down
4 changes: 4 additions & 0 deletions tests/testthat/test_as_learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,7 @@ test_that("discard_state", {
as_learner(learner3, clone = FALSE, discard_state = TRUE)
expect_null(learner3$state)
})

test_that("error when arguments are misspelled", {
expect_error(as_learner(lrn("classif.rpart"), clone2 = TRUE), "Received the following")
})
4 changes: 4 additions & 0 deletions tests/testthat/test_as_measure.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,7 @@ test_that("as_measure conversion", {
default = as_measures(NULL, task_type = "classif")
expect_list(default, types = "Measure")
})

test_that("error when arguments are misspelled", {
expect_error(as_measure(msr("classif.acc"), clone2 = TRUE), "Received the following")
})
4 changes: 4 additions & 0 deletions tests/testthat/test_as_resampling.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,7 @@ test_that("as_resampling conversion", {
expect_list(as_resamplings(resampling), types = "Resampling")
expect_list(as_resamplings(list(resampling)), types = "Resampling")
})

test_that("error when arguments are misspelled", {
expect_error(as_resampling(rsmp("holdout"), clone2 = TRUE), "Received the following")
})
4 changes: 4 additions & 0 deletions tests/testthat/test_as_task.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,7 @@ test_that("as_task_xx error messages (#944)", {
"subset of"
)
})

test_that("error when arguments are misspelled", {
expect_error(as_task(tsk("iris"), clone2 = TRUE), "Received the following")
})
8 changes: 8 additions & 0 deletions tests/testthat/test_assertions.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
test_that("assert_empty_ellipsis works", {
expect_error(assert_empty_ellipsis(1), "Received 1 unnamed argument")
expect_error(assert_empty_ellipsis(1, 2), "Received 2 unnamed argument")
expect_error(assert_empty_ellipsis(a = 1), "that were unused: a")
expect_error(assert_empty_ellipsis(a = 1, b = 2), "that were unused: a, b")
expect_error(assert_empty_ellipsis(a = 1, b = 1, 2), "1 unnamed, as well as named arguments a, b")
expect_null(assert_empty_ellipsis())
})

0 comments on commit e21782a

Please sign in to comment.