docs: autotest (#1128)
* refactor: autotest

* chore: return

* docs: update

* chore: add link

* docs: render

* docs: add tasks
be-marc authored Aug 31, 2024
1 parent 57b6109 commit 7190f45
Showing 4 changed files with 183 additions and 51 deletions.
3 changes: 2 additions & 1 deletion R/mlr_test_helpers.R
@@ -24,6 +24,7 @@
#' the task, learner and prediction of the returned `result`.
#'
#' For example usages, you can look at the autotests in various mlr3 source repositories such as mlr3learners.
#' More information can be found in the `inst/testthat/helper_autotest.R` file.
#'
#' **Parameters**:
#'
@@ -42,7 +43,7 @@
#' Whether to check that running the learner twice with the same seed results in identical predictions.
#' Default is `TRUE`.
#' * `configure_learner` (`function(learner, task)`)\cr
#' Before running a `learner` on a `task`, this function allows changing its parameter values depending on the input task.
#'
#' @section run_paramtest():
#'
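The parameters documented above are easiest to see in use. A minimal sketch of a typical autotest call in an extension package's test suite, assuming `inst/testthat/helper_autotest.R` has been sourced; the learner id and the parameter adjustment in `configure_learner` are illustrative:

library(mlr3)
library(testthat)

learner = lrn("classif.rpart")
result = run_autotest(
  learner,
  exclude = "utf8_feature_names",  # regex over generated task names
  configure_learner = function(learner, task) {
    # hypothetical adjustment: loosen the minimum split size on small tasks
    if (task$nrow < 50L) {
      learner$param_set$values$minsplit = 5L
    }
  }
)
expect_true(result, info = result$error)  # `info` is only evaluated on failure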
203 changes: 153 additions & 50 deletions inst/testthat/helper_autotest.R
@@ -1,17 +1,38 @@
# Learner autotest suite
#
# `run_experiment(task, learner)` runs a single experiment.
# Returns a list with success flag "status" (`logical(1)`),
# "experiment" (partially constructed experiment), and "error"
# (informative error message).
#
# `run_autotest(learner)` generates multiple tasks, depending on the properties of the learner.
# and tests the learner on each task, with each predict type.
# To debug, simply run `result = run_autotest(learner)` and proceed with investigating
# the task, learner and prediction of the returned `result`.
#' @title Learner Autotest Suite
#'
#' @description
#' The autotest suite is a collection of functions to test learners in a standardized way.
#' Extension packages need to specialize the S3 methods in this file.
#
# NB: Extension packages need to specialize the S3 methods in the file.
#' @details
#' `run_autotest(learner)` generates multiple tasks, depending on the properties of the learner, and tests the learner on each task with each predict type.
#' Calls `generate_tasks()` to generate tasks and `run_experiment()` to run the experiments.
#' See `generate_tasks()` for a list of tasks that are generated.
#' To debug, simply run `result = run_autotest(learner)` and proceed with investigating the task, learner and prediction of the returned `result`.
#'
#' `run_experiment(task, learner)` runs a single experiment.
#' Calls `train()` and `predict()` on the learner and scores the prediction with `score()`.
#' On sanity tasks, the prediction is additionally validated with `sanity_check()`.
#'
#' `generate_tasks(learner)` generates multiple tasks for a given learner.
#' Calls `generate_data()` and `generate_generic_tasks()` to generate tasks with different feature types.
#'
#' @noRd
NULL
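
A hedged sketch of that debugging workflow; the fields follow the return contracts documented below, and `regr.rpart` is just an example learner:

result = run_autotest(lrn("regr.rpart"))
if (!isTRUE(result)) {
  result$error       # informative error message
  result$seed        # seed of the failing run
  result$task        # task that triggered the failure
  result$learner     # learner as configured for that run
  result$prediction  # prediction object, if one was produced
}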

#' @title Generate Tasks for a Learner
#'
#' @description
#' Generates multiple tasks for a given [Learner], based on its properties.
#'
#' @param learner [Learner]\cr
#' Learner to generate tasks for.
#' @param proto [Task]\cr
#' Prototype task to generate tasks from.
#'
#' @return (List of [Task]s).
#'
#' @noRd
generate_generic_tasks = function(learner, proto) {
  tasks = list()
  n = proto$nrow
@@ -76,6 +97,20 @@ generate_generic_tasks = function(learner, proto) {
  })
}
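
A hedged usage sketch: in the suite itself, `proto` is built from `generate_data()`, but any existing task can stand in as the prototype:

proto = tsk("mtcars")
tasks = generate_generic_tasks(lrn("regr.rpart"), proto = proto)
names(tasks)  # tasks restricted to single feature types, plus combinations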

#' @title Generate Data for a Learner
#'
#' @description
#' Generates data for a given [Learner], based on its supported feature types.
#' Data is created for logical, integer, numeric, character, factor, ordered, and POSIXct features.
#'
#' @param learner [Learner]\cr
#' Learner to generate data for.
#' @param N `integer(1)`\cr
#' Number of rows of generated data.
#'
#' @return [data.table::data.table()]
#'
#' @noRd
generate_data = function(learner, N) {
  generate_feature = function(type) {
    switch(type,
@@ -96,14 +131,22 @@ generate_data = function(learner, N) {
#'
#' @description
#' Generates multiple tasks for a given [Learner], based on its properties.
#' This function is primarily used for unit tests, but can also assist while
#' writing custom learners.
#' This function is primarily used for unit tests, but can also assist while writing custom learners.
#' The following tasks are created:
#' * `feat_single_*`: Tasks with a single feature type.
#' * `feat_all_*`: Task with all supported feature types.
#' * `missings_*`: Task with missing values.
#' * `utf8_feature_names_*`: Task with non-ascii feature names.
#' * `sanity`: Task with a simple dataset to check if the learner is working.
#' * `sanity_reordered`: Task with the same dataset as `sanity`, but with reordered columns.
#' * `sanity_switched`: Task with the same dataset as `sanity`, but with the positive class switched.
#'
#' @param learner :: [Learner].
#' @param N :: `integer(1)`\cr
#' @param learner [Learner]\cr
#' Learner to generate tasks for.
#' @param N `integer(1)`\cr
#' Number of rows of generated tasks.
#'
#' @return (List of [Task]s).
#' @return `list` of [Task]s
#' @keywords internal
#' @export
#' @examples
@@ -184,6 +227,17 @@ generate_tasks.LearnerRegr = function(learner, N = 30L) {
}
registerS3method("generate_tasks", "LearnerRegr", generate_tasks.LearnerRegr)

#' @title Sanity Check for Predictions
#'
#' @description
#' Checks the sanity of a prediction.
#'
#' @param prediction [Prediction]\cr
#' Prediction to check.
#'
#' @return (`logical(1)`).
#'
#' @noRd
sanity_check = function(prediction, ...) {
  UseMethod("sanity_check")
}
@@ -199,7 +253,34 @@ sanity_check.PredictionRegr = function(prediction, ...) {
}
registerS3method("sanity_check", "LearnerRegr", sanity_check.PredictionRegr)


#' @title Run a Single Learner Test
#'
#' @description
#' Runs a single experiment with a given task and learner.
#'
#' @param task [Task]\cr
#' Task to run the experiment on.
#' @param learner [Learner]\cr
#' Learner to run the experiment with.
#' @param seed `integer(1)`\cr
#' Seed to use for the experiment.
#' If `NULL`, a random seed is generated.
#' @param configure_learner `function(learner, task)`\cr
#' Function to configure the learner before training.
#' Useful when learner settings need to be adjusted for a specific task.
#'
#' @return `list` with the following elements:
#' - `ok` (`logical(1)`): Success flag.
#' - `learner` ([Learner]): Learner used for the experiment.
#' - `prediction` ([Prediction]): Prediction object.
#' - `error` (`character()`): Error message if `ok` is `FALSE`.
#' - `seed` (`integer(1)`): Seed used for the experiment.
#'
#' @noRd
run_experiment = function(task, learner, seed = NULL, configure_learner = NULL) {

  # function to collect error message and objects
  err = function(info, ...) {
    info = sprintf(info, ...)
    list(
@@ -210,6 +291,7 @@ run_experiment = function(task, learner, seed = NULL, configure_learner = NULL)
    )
  }

  # seed handling
  if (is.null(seed)) {
    seed = sample.int(floor(.Machine$integer.max / 2L), 1L)
  }
@@ -230,31 +312,27 @@
  }
  prediction = NULL
  score = NULL
  learner$encapsulate = c(train = "evaluate", predict = "evaluate")

  # check train
  stage = "train()"

  ok = try(learner$train(task), silent = TRUE)
  if (inherits(ok, "try-error")) {
    return(err(as.character(ok)))
  }
  log = learner$log[stage == "train"]
  if ("error" %in% log$class) {
    return(err("train log has errors: %s", mlr3misc::str_collapse(log[class == "error", msg])))
  }
  if (is.null(learner$model)) {
    return(err("model is NULL"))
  }

  # check predict
  stage = "predict()"

  prediction = try(learner$predict(task), silent = TRUE)
  if (inherits(ok, "try-error")) {
  if (inherits(prediction, "try-error")) {
    ok = prediction
    prediction = NULL
    return(err(as.character(ok)))
  }
  log = learner$log[stage == "predict"]
  if ("error" %in% log$class) {
    return(err("predict log has errors: %s", mlr3misc::str_collapse(log[class == "error", msg])))
  }
  msg = checkmate::check_class(prediction, "Prediction")
  if (!isTRUE(msg)) {
    return(err(msg))
Expand Down Expand Up @@ -294,28 +372,31 @@ run_experiment = function(task, learner, seed = NULL, configure_learner = NULL)
    }
  }


  # check score
  stage = "score()"

  score = try(
    prediction$score(mlr3::default_measures(learner$task_type),
      task = task,
      learner = learner,
      train_set = task$row_ids
    ), silent = TRUE)
  if (inherits(score, "try-error")) {
    return(err(as.character(score)))
    ok = score
    score = NULL
    return(err(as.character(ok)))
  }
  msg = checkmate::check_numeric(score, any.missing = FALSE)
  if (!isTRUE(msg)) {
    return(err(msg))
  }

  # run sanity check on sanity task
  if (startsWith(task$id, "sanity") && !
    sanity_check(prediction, task = task, learner = learner, train_set = task$row_ids)) {
  if (startsWith(task$id, "sanity") && !sanity_check(prediction, task = task, learner = learner, train_set = task$row_ids)) {
    return(err("sanity check failed"))
  }

  # check importance, selected_features and oob_error methods
  if (startsWith(task$id, "feat_all")) {
    if ("importance" %in% learner$properties) {
      importance = learner$importance()
@@ -352,30 +433,56 @@ run_experiment = function(task, learner, seed = NULL, configure_learner = NULL)
  return(list(ok = TRUE, learner = learner, prediction = prediction, error = character(), seed = seed))
}
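
A hedged sketch of calling `run_experiment()` directly, with a standard example task and learner; the fields follow the return contract documented above:

result = run_experiment(tsk("spam"), lrn("classif.rpart"), seed = 1L)
result$ok    # TRUE on success
result$seed  # 1L, the seed the experiment ran with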

#' @title Run Autotest for a Learner
#'
#' @description
#' Runs a series of experiments with a given learner on multiple tasks.
#'
#' @param learner ([Learner])\cr
#' The learner to test.
#' @param N (`integer(1)`)\cr
#' Number of rows of generated tasks.
#' @param exclude (`character()`)\cr
#' Regular expression to exclude tasks from the test.
#' Run `generate_tasks(learner)` to see all available tasks.
#' @param predict_types (`character()`)\cr
#' Predict types to test.
#' @param check_replicable (`logical(1)`)\cr
#' Check if the results are replicable.
#' @param configure_learner (`function(learner, task)`)\cr
#' Function to configure the learner before training.
#' Useful when learner settings need to be adjusted for a specific task.
#'
#' @return If the test was successful, `TRUE` is returned.
#' If the test failed, a `list` with the following elements is returned:
#' - `ok` (`logical(1)`): Success flag.
#' - `seed` (`integer(1)`): Seed used for the experiment.
#' - `task` ([Task]): Task used for the experiment.
#' - `learner` ([Learner]): Learner used for the experiment.
#' - `prediction` ([Prediction]): Prediction object.
#' - `score` (`numeric(1)`): Score of the prediction.
#' - `error` (`character()`): Error message if `ok` is `FALSE`.
#'
#' @noRd
run_autotest = function(learner, N = 30L, exclude = NULL, predict_types = learner$predict_types, check_replicable = TRUE, configure_learner = NULL) { # nolint
  if (!is.null(configure_learner)) {
    checkmate::assert_function(configure_learner, args = c("learner", "task"))
  }
  learner = learner$clone(deep = TRUE)
  id = learner$id
  tasks = generate_tasks(learner, N = N)

  if (!is.null(exclude)) {
    tasks = tasks[!grepl(exclude, names(tasks))]
  }


  sanity_runs = list()
  make_err = function(msg, ...) {
    run$ok = FALSE
    run$error = sprintf(msg, ...)
    run
  }

  # param_tags = unique(unlist(learner$param_set$tags))
  # if (!test_subset(param_tags, mlr_reflections$learner_param_tags)) {
  #   return(list(ok = FALSE, error = "Invalid parameter tag(s), check `mlr_reflections$learner_param_tags`."))
  # }

  for (task in tasks) {
    for (predict_type in predict_types) {
      learner$id = sprintf("%s:%s", id, predict_type)
@@ -415,26 +522,22 @@ run_autotest = function(learner, N = 30L, exclude = NULL, predict_types = learne
    }
  }



  return(TRUE)
}

#' @title Check Parameters of mlr3 Learners
#' @description Checks parameters of mlr3learners against parameters defined in
#' the upstream functions of the respective learners.
#'
#' @description
#' Checks parameters of mlr3learners against parameters defined in the upstream functions of the respective learners.
#'
#' @details
#' Some learners do not have all of their parameters stored within the learner
#' function that is called within `.train()`. Sometimes learners come with a
#' "control" function, e.g. [glmnet::glmnet.control()]. Such need to be checked
#' as well since they make up the full ParamSet of the respective learner.
#' Some learners do not have all of their parameters stored within the learner function that is called within `.train()`.
#' Sometimes learners come with a "control" function, e.g. [glmnet::glmnet.control()].
#' Such functions need to be checked as well since they make up the full ParamSet of the respective learner.
#'
#' To work nicely with the defined ParamSet, certain parameters need to be
#' excluded because these are only present in either the "control" object or the
#' actual top-level function call. Such exclusions should go into argument
#' `exclude` with a comment for the reason of the exclusion. See examples for
#' more information.
#' To work nicely with the defined ParamSet, certain parameters need to be excluded because these are only present in either the "control" object or the actual top-level function call.
#' Such exclusions should go into argument `exclude` with a comment for the reason of the exclusion.
#' See examples for more information.
#'
#' @param learner ([mlr3::Learner])\cr
#' The constructed learner.
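A hedged sketch of a typical `run_paramtest()` call, following the pattern used in mlr3 extension packages; the exclusions and their reasons are illustrative, and the remaining arguments are truncated in this diff:

learner = lrn("classif.rpart")
fun = rpart::rpart
exclude = c(
  "formula", # handled internally via the task
  "data",    # handled internally via the task
  "weights"  # handled via the task's weights column
)
result = run_paramtest(learner, fun, exclude)
testthat::expect_true(result, info = result$error)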
1 change: 1 addition & 0 deletions man/mlr_test_helpers.Rd

Some generated files are not rendered by default.

27 changes: 27 additions & 0 deletions tests/testthat/test_autotest.R
@@ -0,0 +1,27 @@
test_that("autotest catches error in train", {
learner = lrn("classif.debug", error_train = 1)
task = tsk("spam")

result = run_experiment(task, learner)
expect_false(result$ok)
expect_integer(result$seed)
expect_task(result$task)
expect_learner(result$learner)
expect_null(result$prediction)
expect_null(result$score)
expect_string(result$error)
})

test_that("autotest catches error in predict", {
learner = lrn("classif.debug", error_predict = 1)
task = tsk("spam")

result = run_experiment(task, learner)
expect_false(result$ok)
expect_integer(result$seed)
expect_task(result$task)
expect_learner(result$learner)
expect_null(result$prediction)
expect_null(result$score)
expect_string(result$error)
})
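
By the same pattern, the success path can be asserted. A hedged sketch: `classif.debug` with no error flags set trains and predicts cleanly, so the experiment should succeed:

test_that("autotest succeeds on a healthy learner", {
  learner = lrn("classif.debug")
  task = tsk("spam")

  result = run_experiment(task, learner)
  expect_true(result$ok)
  expect_prediction(result$prediction)
})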
