Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

docs: autotest #1128

Merged
merged 6 commits into from
Aug 31, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 144 additions & 50 deletions inst/testthat/helper_autotest.R
Original file line number Diff line number Diff line change
@@ -1,17 +1,37 @@
# Learner autotest suite
#
# `run_experiment(task, learner)` runs a single experiment.
# Returns a list with success flag "status" (`logical(1)`),
# "experiment" (partially constructed experiment), and "error"
# (informative error message).
#
# `run_autotest(learner)` generates multiple tasks, depending on the properties of the learner.
# and tests the learner on each task, with each predict type.
# To debug, simply run `result = run_autotest(learner)` and proceed with investigating
# the task, learner and prediction of the returned `result`.
#' @title Learner Autotest Suite
#'
#' @description
#' The autotest suite is a collection of functions to test learners in a standardized way.
#' Extension packages need to specialize the S3 methods in the file.
#
# NB: Extension packages need to specialize the S3 methods in the file.
#' @details
#' `run_autotest(learner)` generates multiple tasks, depending on the properties of the learner and tests the learner on each task, with each predict type.
#' Calls `generate_tasks()` to generate tasks and `run_experiment()` to run the experiments.
#' To debug, simply run `result = run_autotest(learner)` and proceed with investigating he task, learner and prediction of the returned `result`.
#'
#' `run_experiment(task, learner)` runs a single experiment.
#' Calls `train()` and `predict()` on the learner and checks the prediction with `score()`.
#' The prediction is checked with `sanity_check()`.
#'
#' `generate_tasks(learner)` generates multiple tasks for a given learner.
#' Calls `generate_data()` and `generate_generic_tasks()` to generate tasks with different feature types.
#'
#' @noRd
NULL

#' @title Generate Tasks for a Learner
be-marc marked this conversation as resolved.
Show resolved Hide resolved
#'
#' @description
#' Generates multiple tasks for a given [Learner], based on its properties.
#'
#' @param learner [Learner]\cr
#' Learner to generate tasks for.
#' @param proto [Task]\cr
#' Prototype task to generate tasks from.
#'
#' @return (List of [Task]s).
#'
#' @noRd
generate_generic_tasks = function(learner, proto) {
tasks = list()
n = proto$nrow
Expand Down Expand Up @@ -76,6 +96,20 @@ generate_generic_tasks = function(learner, proto) {
})
}

#' @title Generate Data for a Learner
#'
#' @description
#' Generates data for a given [Learner], based on its supported feature types.
#' Data is created for logical, integer, numeric, character, factor, ordered, and POSIXct features.
#'
#' @param learner [Learner]\cr
#' Learner to generate data for.
#' @param N `integer(1)`\cr
#' Number of rows of generated data.
#'
#' @return [data.table::data.table()]
#'
#' @noRd
generate_data = function(learner, N) {
generate_feature = function(type) {
switch(type,
Expand All @@ -96,14 +130,14 @@ generate_data = function(learner, N) {
#'
#' @description
#' Generates multiple tasks for a given [Learner], based on its properties.
#' This function is primarily used for unit tests, but can also assist while
#' writing custom learners.
#' This function is primarily used for unit tests, but can also assist while writing custom learners.
#'
#' @param learner :: [Learner].
#' @param N :: `integer(1)`\cr
#' @param learner [Learner]\cr
#' Learner to generate tasks for.
#' @param N `integer(1)`\cr
#' Number of rows of generated tasks.
#'
#' @return (List of [Task]s).
#' @return `list` of [Task]s
#' @keywords internal
#' @export
#' @examples
Expand Down Expand Up @@ -184,6 +218,17 @@ generate_tasks.LearnerRegr = function(learner, N = 30L) {
}
registerS3method("generate_tasks", "LearnerRegr", generate_tasks.LearnerRegr)

#' @title Sanity Check for Predictions
#'
#' @description
#' Checks the sanity of a prediction.
#'
#' @param prediction [Prediction]\cr
#' Prediction to check.
#'
#' @return (`logical(1)`).
#'
#' @noRd
sanity_check = function(prediction, ...) {
UseMethod("sanity_check")
}
Expand All @@ -199,7 +244,34 @@ sanity_check.PredictionRegr = function(prediction, ...) {
}
registerS3method("sanity_check", "LearnerRegr", sanity_check.PredictionRegr)


#' @title Run a Single Learner Test
#'
#' @description
#' Runs a single experiment with a given task and learner.
#'
#' @param task [Task]\cr
#' Task to run the experiment on.
#' @param learner [Learner]\cr
#' Learner to run the experiment with.
#' @param seed `integer(1)`\cr
#' Seed to use for the experiment.
#' If `NULL`, a random seed is generated.
#' @param configure_learner `function(learner, task)`\cr
#' Function to configure the learner before training.
#' Useful when learner settings need to be adjusted for a specific task.
#'
#' @return `list` with the following elements:
#' - `ok` (`logical(1)`): Success flag.
#' - `learner` ([Learner]): Learner used for the experiment.
#' - `prediction` ([Prediction]): Prediction object.
#' - `error` (`character()`): Error message if `ok` is `FALSE`.
#' - `seed` (`integer(1)`): Seed used for the experiment.
#'
#' @noRd
run_experiment = function(task, learner, seed = NULL, configure_learner = NULL) {

# function to collect error message and objects
err = function(info, ...) {
info = sprintf(info, ...)
list(
Expand All @@ -210,6 +282,7 @@ run_experiment = function(task, learner, seed = NULL, configure_learner = NULL)
)
}

# seed handling
if (is.null(seed)) {
seed = sample.int(floor(.Machine$integer.max / 2L), 1L)
}
Expand All @@ -230,31 +303,27 @@ run_experiment = function(task, learner, seed = NULL, configure_learner = NULL)
}
prediction = NULL
score = NULL
learner$encapsulate = c(train = "evaluate", predict = "evaluate")

# check train
stage = "train()"

ok = try(learner$train(task), silent = TRUE)
if (inherits(ok, "try-error")) {
return(err(as.character(ok)))
}
log = learner$log[stage == "train"]
if ("error" %in% log$class) {
return(err("train log has errors: %s", mlr3misc::str_collapse(log[class == "error", msg])))
}
if (is.null(learner$model)) {
return(err("model is NULL"))
}

# check predict
stage = "predict()"

prediction = try(learner$predict(task), silent = TRUE)
if (inherits(ok, "try-error")) {
if (inherits(prediction, "try-error")) {
ok = prediction
prediction = NULL
return(err(as.character(ok)))
}
log = learner$log[stage == "predict"]
if ("error" %in% log$class) {
return(err("predict log has errors: %s", mlr3misc::str_collapse(log[class == "error", msg])))
}
msg = checkmate::check_class(prediction, "Prediction")
if (!isTRUE(msg)) {
return(err(msg))
Expand Down Expand Up @@ -294,28 +363,31 @@ run_experiment = function(task, learner, seed = NULL, configure_learner = NULL)
}
}


# check score
stage = "score()"

score = try(
prediction$score(mlr3::default_measures(learner$task_type),
task = task,
learner = learner,
train_set = task$row_ids
), silent = TRUE)
if (inherits(score, "try-error")) {
return(err(as.character(score)))
ok = score
score = NULL
return(err(as.character(ok)))
}
msg = checkmate::check_numeric(score, any.missing = FALSE)
if (!isTRUE(msg)) {
return(err(msg))
}

# run sanity check on sanity task
if (startsWith(task$id, "sanity") && !
sanity_check(prediction, task = task, learner = learner, train_set = task$row_ids)) {
if (startsWith(task$id, "sanity") && !sanity_check(prediction, task = task, learner = learner, train_set = task$row_ids)) {
return(err("sanity check failed"))
}

# check importance, selected_features and oob_error methods
if (startsWith(task$id, "feat_all")) {
if ("importance" %in% learner$properties) {
importance = learner$importance()
Expand Down Expand Up @@ -352,30 +424,56 @@ run_experiment = function(task, learner, seed = NULL, configure_learner = NULL)
return(list(ok = TRUE, learner = learner, prediction = prediction, error = character(), seed = seed))
}

#' @title Run Autotest for a Learner
#'
#' @description
#' Runs a series of experiments with a given learner on multiple tasks.
#'
#' @param learner ([Learner])\cr
#' The learner to test.
#' @param N (`integer(1)`)\cr
#' Number of rows of generated tasks.
#' @param exclude (`character()`)\cr
#' Regular expression to exclude tasks from the test.
#' Run `generate_tasks(learner)` to see all available tasks.
#' @param predict_types (`character()`)\cr
#' Predict types to test.
#' @param check_replicable (`logical(1)`)\cr
#' Check if the results are replicable.
#' @param configure_learner (`function(learner, task)`)\cr
#' Function to configure the learner before training.
#' Useful when learner settings need to be adjusted for a specific task.
#'
#' @return If the test was successful, `TRUE` is returned.
#' If the test failed, a `list` with the following elements is returned:
#' - `ok` (`logical(1)`): Success flag.
#' - `seed` (`integer(1)`): Seed used for the experiment.
#' - `task` ([Task]): Task used for the experiment.
#' - `learner` ([Learner]): Learner used for the experiment.
#' - `prediction` ([Prediction]): Prediction object.
#' - `score` (`numeric(1)`): Score of the prediction.
#' - `error` (`character()`): Error message if `ok` is `FALSE`.
#
#' @noRd
run_autotest = function(learner, N = 30L, exclude = NULL, predict_types = learner$predict_types, check_replicable = TRUE, configure_learner = NULL) { # nolint
if (!is.null(configure_learner)) {
checkmate::assert_function(configure_learner, args = c("learner", "task"))
}
learner = learner$clone(deep = TRUE)
id = learner$id
tasks = generate_tasks(learner, N = N)

if (!is.null(exclude)) {
tasks = tasks[!grepl(exclude, names(tasks))]
}


sanity_runs = list()
make_err = function(msg, ...) {
run$ok = FALSE
run$error = sprintf(msg, ...)
run
}

# param_tags = unique(unlist(learner$param_set$tags))
# if (!test_subset(param_tags, mlr_reflections$learner_param_tags)) {
# return(list(ok = FALSE, error = "Invalid parameter tag(s), check `mlr_reflections$learner_param_tags`."))
# }

for (task in tasks) {
for (predict_type in predict_types) {
learner$id = sprintf("%s:%s", id, predict_type)
Expand Down Expand Up @@ -415,26 +513,22 @@ run_autotest = function(learner, N = 30L, exclude = NULL, predict_types = learne
}
}



return(TRUE)
}

#' @title Check Parameters of mlr3 Learners
#' @description Checks parameters of mlr3learners against parameters defined in
#' the upstream functions of the respective learners.
#'
#' @description
#' Checks parameters of mlr3learners against parameters defined in the upstream functions of the respective learners.
#'
#' @details
#' Some learners do not have all of their parameters stored within the learner
#' function that is called within `.train()`. Sometimes learners come with a
#' "control" function, e.g. [glmnet::glmnet.control()]. Such need to be checked
#' as well since they make up the full ParamSet of the respective learner.
#' Some learners do not have all of their parameters stored within the learner function that is called within `.train()`.
#' Sometimes learners come with a "control" function, e.g. [glmnet::glmnet.control()].
#' Such need to be checked as well since they make up the full ParamSet of the respective learner.
#'
#' To work nicely with the defined ParamSet, certain parameters need to be
#' excluded because these are only present in either the "control" object or the
#' actual top-level function call. Such exclusions should go into argument
#' `exclude` with a comment for the reason of the exclusion. See examples for
#' more information.
#' To work nicely with the defined ParamSet, certain parameters need to be excluded because these are only present in either the "control" object or the actual top-level function call.
#' Such exclusions should go into argument `exclude` with a comment for the reason of the exclusion.
#' See examples for more information.
#'
#' @param learner ([mlr3::Learner])\cr
#' The constructed learner.
Expand Down
2 changes: 2 additions & 0 deletions man/mlr_test_helpers.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

27 changes: 27 additions & 0 deletions tests/testthat/test_autotest.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
test_that("autotest catches error in train", {
learner = lrn("classif.debug", error_train = 1)
task = tsk("spam")

result = run_experiment(task, learner)
expect_false(result$ok)
expect_integer(result$seed)
expect_task(result$task)
expect_learner(result$learner)
expect_null(result$prediction)
expect_null(result$score)
expect_string(result$error)
})

test_that("autotest catches error in predict", {
learner = lrn("classif.debug", error_predict = 1)
task = tsk("spam")

result = run_experiment(task, learner)
expect_false(result$ok)
expect_integer(result$seed)
expect_task(result$task)
expect_learner(result$learner)
expect_null(result$prediction)
expect_null(result$score)
expect_string(result$error)
})