Skip to content

Commit

Permalink
Merge branch 'set_fallback' into featureless_quantiles
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc committed Aug 31, 2024
2 parents 65cc2c5 + de532f7 commit 00cd22e
Show file tree
Hide file tree
Showing 25 changed files with 486 additions and 173 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ Collate:
'benchmark.R'
'benchmark_grid.R'
'bibentries.R'
'default_fallback.R'
'default_measures.R'
'fix_factor_levels.R'
'helper.R'
Expand All @@ -196,7 +197,6 @@ Collate:
'predict.R'
'reexports.R'
'resample.R'
'set_fallback.R'
'set_threads.R'
'set_validate.R'
'task_converters.R'
Expand Down
9 changes: 9 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ S3method(col_info,DataBackend)
S3method(col_info,data.table)
S3method(create_empty_prediction_data,TaskClassif)
S3method(create_empty_prediction_data,TaskRegr)
S3method(default_fallback,Learner)
S3method(default_fallback,LearnerClassif)
S3method(default_fallback,LearnerRegr)
S3method(default_values,Learner)
S3method(default_values,LearnerClassifRpart)
S3method(default_values,LearnerRegrRpart)
Expand Down Expand Up @@ -108,6 +111,11 @@ S3method(set_threads,list)
S3method(set_validate,Learner)
S3method(summary,Task)
S3method(tail,Task)
S3method(task_check_col_roles,Task)
S3method(task_check_col_roles,TaskClassif)
S3method(task_check_col_roles,TaskRegr)
S3method(task_check_col_roles,TaskSupervised)
S3method(task_check_col_roles,TaskUnsupervised)
S3method(unmarshal_model,classif.debug_model_marshaled)
S3method(unmarshal_model,default)
S3method(unmarshal_model,learner_state_marshaled)
Expand Down Expand Up @@ -241,6 +249,7 @@ export(rsmp)
export(rsmps)
export(set_threads)
export(set_validate)
export(task_check_col_roles)
export(tgen)
export(tgens)
export(tsk)
Expand Down
5 changes: 3 additions & 2 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
* refactor: Optimize runtime of marshalling.
* refactor: Optimize runtime of `Task$col_info`.
* fix: column info is now checked for compatibility during `Learner$predict` (#943).
* BREAKING CHANGE: the predict time of the learner now stores the cumulative duration for all predict sets (#992).
* BREAKING CHANGE: The predict time of the learner now stores the cumulative duration for all predict sets (#992).
* feat: `$internal_valid_task` can now be set to an `integer` vector.
* feat: Measures can now have an empty `$predict_sets` (#1094).
this is relevant for measures that only extract information from
Expand All @@ -20,7 +20,8 @@
* feat: Add option to calculate the mean of the true values on the train set in `msr("regr.rsq")`.
* feat: Default fallback learner is set when encapsulation is activated.
* feat: Learners classif.debug and regr.debug have new methods `$importance()` and `$selected_features()` for testing, also in downstream packages.
* feat: Set default fallback with `set_fallback()`.
* feat: Create default fallback learner with `default_fallback()`.
* feat: Check column roles when using `$set_col_roles()` and `$col_roles`.

# mlr3 0.20.2

Expand Down
9 changes: 7 additions & 2 deletions R/Learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -567,8 +567,13 @@ Learner = R6Class("Learner",
assert_names(names(rhs), subset.of = c("train", "predict"))
private$.encapsulate = insert_named(default, rhs)

# if there is no fallback, we get a default one from the reflections table
if (is.null(private$.fallback)) set_fallback(self)
# if there is no fallback, we get a default one
if (is.null(private$.fallback)) {
fallback = default_fallback(self)
if (!is.null(fallback)) {
self$fallback = fallback
}
}
},

#' @field fallback ([Learner])\cr
Expand Down
93 changes: 79 additions & 14 deletions R/Task.R
Original file line number Diff line number Diff line change
Expand Up @@ -670,9 +670,10 @@ Task = R6Class("Task",
#' Other column roles are preserved.
#'
#' @details
#' Roles are first set exclusively (argument `roles`), then added (argument `add_to`) and finally
#' removed (argument `remove_from`) from different roles.
#' Roles are first set exclusively (argument `roles`), then added (argument `add_to`) and finally removed (argument `remove_from`) from different roles.
#' Duplicated columns are removed from the same role.
#' For tasks that only allow one target, the target column cannot be set with `$set_col_roles()`.
#' Use the `$col_roles` field to swap the target column.
#'
#' @return
#' Returns the object itself, but modified **by reference**.
Expand Down Expand Up @@ -1188,7 +1189,25 @@ task_set_roles = function(li, elements, roles = NULL, add_to = NULL, remove_from
li
}

task_check_col_roles = function(self, new_roles) {
#' @title Check Column Roles
#'
#' @description
#' Internal function to check column roles.
#' This is an S3 generic which dispatches on the class of `task`, so that each task class can
#' contribute its own checks.
#'
#' @param task ([Task])\cr
#' Task.
#' @param new_roles (`list()`)\cr
#' Column roles.
#' @param ... (any)\cr
#' Additional arguments, forwarded to the methods by the S3 dispatch.
#'
#' @return
#' The base method for [Task] returns `new_roles`; methods for subclasses run their own checks and
#' then delegate via `NextMethod()`.
#' An error is raised if a check fails.
#'
#' @keywords internal
#' @export
task_check_col_roles = function(task, new_roles, ...) {
UseMethod("task_check_col_roles")
}

#' @rdname task_check_col_roles
#' @export
task_check_col_roles.Task = function(task, new_roles, ...) {
for (role in c("group", "weight", "name")) {
if (length(new_roles[[role]]) > 1L) {
stopf("There may only be up to one column with role '%s'", role)
Expand All @@ -1197,29 +1216,75 @@ task_check_col_roles = function(self, new_roles) {

# check weights
if (length(new_roles[["weight"]])) {
weights = self$backend$data(self$backend$rownames, cols = new_roles[["weight"]])
weights = task$backend$data(task$backend$rownames, cols = new_roles[["weight"]])
assert_numeric(weights[[1L]], lower = 0, any.missing = FALSE, .var.name = names(weights))
}

# check name
if (length(new_roles[["name"]])) {
row_names = self$backend$data(self$backend$rownames, cols = new_roles[["name"]])
row_names = task$backend$data(task$backend$rownames, cols = new_roles[["name"]])
if (!is.character(row_names[[1L]]) && !is.factor(row_names[[1L]])) {
stopf("Assertion on '%s' failed: Must be of type 'character' or 'factor', not %s", names(row_names), class(row_names[[1]]))
}
}

if (inherits(self, "TaskSupervised")) {
if (length(new_roles$target) == 0L) {
stopf("Supervised tasks need at least one target column")
}
} else if (inherits(self, "TaskUnsupervised")) {
if (length(new_roles$target) != 0L) {
stopf("Unsupervised tasks may not have a target column")
}
return(new_roles)
}

#' @rdname task_check_col_roles
#' @export
task_check_col_roles.TaskClassif = function(task, new_roles, ...) {
  target = new_roles[["target"]]

  # classification tasks accept at most one target column
  if (length(target) > 1L) {
    stopf("There may only be up to one column with role 'target'")
  }

  # a target that is present must be stored as a factor or an ordered factor
  if (length(target)) {
    target_types = fget(task$col_info, target, "type", key = "id")
    if (any(target_types %nin% c("factor", "ordered"))) {
      stopf("Target column(s) %s must be a factor or ordered factor", paste0("'", target, "'", collapse = ","))
    }
  }

  # hand over to the checks of the parent classes
  NextMethod()
}

#' @rdname task_check_col_roles
#' @export
task_check_col_roles.TaskRegr = function(task, new_roles, ...) {

  # regression tasks accept at most one target column
  if (length(new_roles[["target"]]) > 1L) {
    stopf("There may only be up to one column with role 'target'")
  }

  # a target that is present must be numeric or integer according to the task's column info
  if (length(new_roles[["target"]]) && any(fget(task$col_info, new_roles[["target"]], "type", key = "id") %nin% c("numeric", "integer"))) {
    # the column name is quoted by paste0() already, so the format string must not add quotes again
    stopf("Target column %s must be a numeric or integer column", paste0("'", new_roles[["target"]], "'", collapse = ","))
  }

  # hand over to the checks of the parent classes
  NextMethod()
}

#' @rdname task_check_col_roles
#' @export
task_check_col_roles.TaskSupervised = function(task, new_roles, ...) {
  # a supervised task without any target column is invalid
  n_targets = length(new_roles$target)
  if (n_targets == 0L) {
    stopf("Supervised tasks need at least one target column")
  }

  # hand over to the checks of the parent class
  NextMethod()
}

#' @rdname task_check_col_roles
#' @export
task_check_col_roles.TaskUnsupervised = function(task, new_roles, ...) {

  # unsupervised tasks must not declare any target column
  if (length(new_roles$target) != 0L) {
    stopf("Unsupervised tasks may not have a target column")
  }

  # hand over to the checks of the parent class; its result is the return value
  NextMethod()
}

#' @title Column Information for Backend
Expand Down
64 changes: 64 additions & 0 deletions R/default_fallback.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#' @title Create a Fallback Learner
#'
#' @description
#' Create a fallback learner for a given learner.
#' This is an S3 generic which dispatches on the class of `learner`.
#' The methods for classification and regression learners create a featureless learner and copy the
#' predict type of `learner` to it.
#' An error is raised if the fallback learner does not support the predict type of `learner`.
#'
#' @param learner [Learner]\cr
#' The learner for which a fallback learner should be created.
#' @param ... `any`\cr
#' ignored.
#'
#' @return [Learner] or `NULL`.
#' The base method for [Learner] returns `NULL`, i.e. no default fallback is available.
default_fallback = function(learner, ...) {
UseMethod("default_fallback")
}

#' @rdname default_fallback
#' @export
default_fallback.Learner = function(learner, ...) {
  # FIXME: remove when new encapsulate/fallback system is in place
  # no default fallback is known for a generic learner
  NULL
}

#' @rdname default_fallback
#' @export
default_fallback.LearnerClassif = function(learner, ...) {
  # `...` is accepted (and ignored) to stay consistent with the generic's signature
  fallback = lrn("classif.featureless")

  # the fallback must be able to produce the same kind of prediction as `learner`
  if (learner$predict_type %nin% fallback$predict_types) {
    stopf("Fallback learner '%s' does not support predict type '%s'.", fallback$id, learner$predict_type)
  }

  fallback$predict_type = learner$predict_type

  fallback
}

#' @rdname default_fallback
#' @export
default_fallback.LearnerRegr = function(learner, ...) {
  # `...` is accepted (and ignored) to stay consistent with the generic's signature
  fallback = lrn("regr.featureless")

  # the fallback must be able to produce the same kind of prediction as `learner`
  if (learner$predict_type %nin% fallback$predict_types) {
    stopf("Fallback learner '%s' does not support predict type '%s'.", fallback$id, learner$predict_type)
  }

  fallback$predict_type = learner$predict_type

  # quantile predictions additionally need the quantile configuration of `learner`
  if (learner$predict_type == "quantiles") {

    if (is.null(learner$quantiles) || is.null(learner$quantile_response)) {
      stopf("Cannot set quantiles for fallback learner. Set `$quantiles` and `$quantile_response` in %s.", learner$id)
    }

    fallback$quantiles = learner$quantiles
    fallback$quantile_response = learner$quantile_response
  }

  fallback
}
5 changes: 0 additions & 5 deletions R/mlr_reflections.R
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,6 @@ local({
regr = list(response = "response", se = c("response", "se"), quantiles = c("response", "quantiles"), distr = c("response", "se", "distr"))
)

mlr_reflections$learner_fallback = list(
classif = "classif.featureless",
regr = "regr.featureless"
)

# Allowed tags for parameters
mlr_reflections$learner_param_tags = c("train", "predict", "hotstart", "importance", "threads", "required", "internal_tuning")

Expand Down
3 changes: 2 additions & 1 deletion R/mlr_test_helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#' the task, learner and prediction of the returned `result`.
#'
#' For example usages you can look at the autotests in various mlr3 source repositories such as mlr3learners.
#' More information can be found in the `inst/testthat/autotest.R` file.
#'
#' **Parameters**:
#'
Expand All @@ -42,7 +43,7 @@
#' Whether to check that running the learner twice with the same seed should result in identical predictions.
#' Default is `TRUE`.
#' * `configure_learner` (`function(learner, task)`)\cr
#' Before running a `learner` on a `task`, this function allows to change its parameter values depending on the input task.
#' Before running a `learner` on a `task`, this function allows changing its parameter values depending on the input task.
#'
#' @section run_paramtest():
#'
Expand Down
45 changes: 0 additions & 45 deletions R/set_fallback.R

This file was deleted.

14 changes: 7 additions & 7 deletions R/task_converters.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@ convert_task = function(intask, target = NULL, new_type = NULL, drop_original_ta
# copy row_roles / col_roles / properties
newtask$row_roles = intask$row_roles
props = intersect(mlr_reflections$task_col_roles[[intask$task_type]], mlr_reflections$task_col_roles[[new_type]])
newtask$col_roles[props] = intask$col_roles[props]
newtask$set_col_roles(target, "target")

# Add the original target(s) as features, only keeping 'new_target'.
if (!all(intask$target_names == target)) {
newtask$set_col_roles(setdiff(intask$col_roles$target, target), "feature")
}
col_roles = intask$col_roles[props]
# add the original target(s) as features, only keeping 'new_target'
col_roles$feature = c(col_roles$feature, setdiff(intask$col_roles$target, target))
col_roles$target = target
# remove new target from features
col_roles$feature = setdiff(col_roles$feature, target)
newtask$col_roles[props] = col_roles

# during prediction, when target is NA, we do not call droplevels
if (assert_flag(drop_levels)) {
Expand Down
Loading

0 comments on commit 00cd22e

Please sign in to comment.