Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(Learner): Allow for additional input checks on task #996

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# mlr3 (development version)

* feat(Learner): Add support for a `$check_learnable()` method that can be
implemented by `Learner`s to perform additional compatibility checks
* feat: dictionary conversion of `mlr_learners` respects prototype arguments
recently added in mlr3misc
* perf: skip unnecessary clone of learner's state in `resample()`
Expand Down
8 changes: 6 additions & 2 deletions R/Learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
#' The dictionary [mlr_learners] gets automatically populated with the new learners as soon as the respective packages are loaded.
#'
#' More (experimental) learners can be found in the GitHub repository: \url{https://github.com/mlr-org/mlr3extralearners}.
#' A guide on how to extend \CRANpkg{mlr3} with custom learners can be found in the [mlr3book](https://mlr3book.mlr-org.com).
#' A guide on how to extend \CRANpkg{mlr3} with custom learners can be found in [mlr3extralearners](https://mlr3extralearners.mlr-org.com/articles/extending.html).
#'
#' To combine the learner with preprocessing operations like factor encoding, \CRANpkg{mlr3pipelines} is recommended.
#' Hyperparameters stored in the `param_set` can be tuned with \CRANpkg{mlr3tuning}.
Expand All @@ -37,7 +37,6 @@
#' @template param_label
#' @template param_man
#'
#'
#' @section Optional Extractors:
#'
#' Specific learner implementations are free to implement additional getters to ease the access of certain parts
Expand Down Expand Up @@ -83,6 +82,11 @@
#' lrn$param_set$add(paradox::ParamFct$new("foo", levels = c("a", "b")))
#' ```
#'
#' @section Additional Task Checks:
#' Learner may perform custom compatibility checks on a task that determine whether a learner is applicable to a task
#' using the optional `$check_learnable(task)` public method.
#' When providing this method, it should either return `TRUE` or return an error message as a `character(1)`.
#'
#' @template seealso_learner
#' @export
Learner = R6Class("Learner",
Expand Down
6 changes: 6 additions & 0 deletions R/assertions.R
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,12 @@ assert_learnable = function(task, learner) {
if (task$task_type == "unsupervised") {
stopf("%s cannot be trained with %s", learner$format(), task$format())
}
if (exists("check_learnable", envir = learner, inherits = FALSE)) {
msg = learner$check_learnable(task)
if (!isTRUE(msg)) {
stopf("Learner '%s' incompatible with task '%s': %s", learner$id, task$id, msg)
}
}
assert_task_learner(task, learner)
}

Expand Down
9 changes: 8 additions & 1 deletion man/Learner.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

30 changes: 30 additions & 0 deletions tests/testthat/test_Learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -324,3 +324,33 @@ test_that("Models can be replaced", {
learner$model$location = 1
expect_equal(learner$model$location, 1)
})

test_that("check_learnable", {
learner = R6Class("LearnerClassifIris",
inherit = LearnerClassifRpart,
public = list(
check_learnable = function(task) {
if (task$id == "iris") {
return(TRUE)
}
"This learner can only be trained on iris."
}
)
)$new()
# assert_learnable
expect_error(assert_learnable(tsk("iris"), learner), regexp = NA)
expect_error(assert_learnable(tsk("sonar"), learner), regexp = "iris")

# benchmark() fails despite fallback
learner$fallback = lrn("classif.featureless")
expect_error(benchmark(benchmark_grid(tsk("iris"), learner, rsmp("holdout"))), regexp = NA)
expect_error(benchmark(benchmark_grid(tsk("sonar"), learner, rsmp("holdout"))), regexp = "iris")

# resample() fails despite fallback
expect_error(resample(tsk("iris"), learner, rsmp("holdout")), regexp = NA)
expect_error(resample(tsk("sonar"), learner, rsmp("holdout")), regexp = "iris")

# $train() does not trigger fallback
expect_error(learner$train(tsk("sonar")))
})

Loading