Skip to content

Commit

Permalink
Fix/predict newdata (#1240)
Browse files Browse the repository at this point in the history
* ci: fail on note

* fix(predict): type conversion when predicting on new data

* ...

---------

Co-authored-by: Marc Becker <[email protected]>
  • Loading branch information
sebffischer and be-marc authored Jan 6, 2025
1 parent 9c95317 commit 5c24ba8
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 9 deletions.
4 changes: 3 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# mlr3 (development version)

* fix: the `$predict_newdata()` method of `Learner` now automatically conducts type conversions if the input is a `data.frame` (#685)
* BREAKING CHANGE: Predicting on a `task` with the wrong column information is now an error and not a warning.
* Column names with UTF-8 characters are now allowed by default.
The option `mlr3.allow_utf8_names` is removed.
The option `mlr3.allow_utf8_names` is removed.
* BREAKING CHANGE: `Learner$predict_types` is read-only now.
* docs: Clear up behavior of `Learner$predict_type` after training.

Expand Down
11 changes: 11 additions & 0 deletions R/Learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,8 @@ Learner = R6Class("Learner",
#' `data.frame()` or [DataBackend].
#' If a [DataBackend] is provided as `newdata`, the row ids are preserved,
#' otherwise they are set to the sequence `1:nrow(newdata)`.
#' If the input is a `data.frame`, [`auto_convert`] is used for type-conversions to ensure compatibility
#' of features between `$train()` and `$predict()`.
#'
#' @param task ([Task]).
#'
Expand All @@ -393,6 +395,14 @@ Learner = R6Class("Learner",
task = task_rm_backend(task)
}

if (is.data.frame(newdata)) {
keep_cols = intersect(names(newdata), task$col_info$id)
ci = task$col_info[list(keep_cols), on = "id"]
newdata = do.call(data.table, Map(auto_convert,
value = as.list(newdata)[ci$id],
id = ci$id, type = ci$type, levels = ci$levels))
}

newdata = as_data_backend(newdata)
assert_names(newdata$colnames, must.include = task$feature_names)

Expand All @@ -409,6 +419,7 @@ Learner = R6Class("Learner",

# do some type conversions if necessary
task$backend = newdata
task$col_info = col_info(task$backend)
task$row_roles$use = task$backend$rownames
self$predict(task)
},
Expand Down
2 changes: 1 addition & 1 deletion R/assertions.R
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ assert_predictable = function(task, learner) {
all(pmap_lgl(list(x = ci_train$levels, y = ci_predict$levels), identical))

if (!ok) {
lg$warn("Learner '%s' received task with different column info (feature type or level ordering) during train and predict.", learner$id)
stopf("Learner '%s' received task with different column info (feature type or factor level ordering) during train and predict.", learner$id)
}
}

Expand Down
4 changes: 3 additions & 1 deletion man/Learner.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

27 changes: 21 additions & 6 deletions tests/testthat/test_Learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -255,13 +255,15 @@ test_that("learner cannot be trained with TuneToken present", {

# Trains a featureless regression learner on a task whose feature `x` is an
# integer column, then calls `$predict_newdata()` with integer and double
# versions of `x` to check which inputs predict and which raise an error.
test_that("integer<->numeric conversion in newdata (#533)", {
# training data: `x` is built from `1:10`, i.e. an integer column
data = data.table(y = runif(10), x = 1:10)
# NOTE(review): holds the same `x` values as `newdata2` below; this line looks
# like the pre-change side of the diff — confirm before relying on it
newdata = data.table(y = runif(10), x = 1:10 + 0.1)
# doubles that hold whole numbers only
newdata1 = data.table(y = runif(10), x = as.double(1:10))
# doubles with a fractional part
newdata2 = data.table(y = runif(10), x = 1:10 + 0.1)

task = TaskRegr$new("test", data, "y")
learner = lrn("regr.featureless")
learner$train(task)
# predicting on the training data itself must yield a prediction
expect_prediction(learner$predict_newdata(data))
# NOTE(review): `newdata` has the same `x` values as `newdata2`, which is
# expected to error two lines below — presumably a removed diff line; confirm
expect_prediction(learner$predict_newdata(newdata))
# whole-number doubles are expected to yield a prediction
expect_prediction(learner$predict_newdata(newdata1))
# doubles with a fractional part are expected to fail with a conversion error
expect_error(learner$predict_newdata(newdata2), "failed to convert from class 'numeric'")
})

test_that("weights", {
Expand Down Expand Up @@ -575,10 +577,7 @@ test_that("column info is compared during predict", {
task_other = as_task_classif(dother, target = "y")
l = lrn("classif.rpart")
l$train(task)
old_threshold = lg$threshold
lg$set_threshold("warn")
expect_output(l$predict(task_flip), "task with different column info")
lg$set_threshold(old_threshold)
expect_error(l$predict(task_flip), "task with different column info")
expect_error(l$predict(task_other), "with different columns")
})

Expand Down Expand Up @@ -663,3 +662,19 @@ test_that("configure method works", {
expect_equal(learner$param_set$values$xval, 10)
expect_equal(learner$predict_sets, "train")
})

# Exercises `$predict_newdata()` with data.table input whose column types
# differ from the columns the learner was trained on.
test_that("predict_newdata auto conversion (#685)", {
  task = tsk("iris")$select(c("Sepal.Length", "Sepal.Width"))
  learner = lrn("classif.debug", save_tasks = TRUE)$train(task)

  # a character value in `Sepal.Width` must be rejected with a conversion error
  wrong_type = data.table(Sepal.Length = 1, Sepal.Width = "abc")
  expect_error(
    learner$predict_newdata(wrong_type),
    "Incompatible types during auto-converting column 'Sepal.Width'",
    fixed = TRUE
  )

  # new data without the `Sepal.Width` column must be rejected
  missing_col = data.table(Sepal.Length = 1L)
  expect_error(learner$predict_newdata(missing_col), "but is missing elements")

  # the same observation, once with a double and once with an integer `Sepal.Length`
  pred_double = learner$predict_newdata(data.table(Sepal.Length = 1, Sepal.Width = 2))
  pred_integer = learner$predict_newdata(data.table(Sepal.Length = 1L, Sepal.Width = 2))

  # inspect the predict task kept in the model (presumably the one from the
  # last call, i.e. the integer input — verify against `save_tasks` semantics)
  stored_task = learner$model$task_predict
  expect_equal(stored_task$col_info[list("Sepal.Length")]$type, "numeric")
  expect_double(stored_task$data(cols = "Sepal.Length")[[1]])

  # both inputs must lead to the same prediction
  expect_equal(pred_double, pred_integer)
})

0 comments on commit 5c24ba8

Please sign in to comment.