Skip to content

Commit

Permalink
Merge branch 'main' into resampling_converters
Browse files Browse the repository at this point in the history
  • Loading branch information
mllg committed Sep 4, 2024
2 parents 8d58c54 + 55f4a03 commit 52101aa
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 5 deletions.
4 changes: 2 additions & 2 deletions R/MeasureRegrRSQ.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
#' @template seealso_measure
#' @export
MeasureRegrRSQ = R6Class("MeasureRSQ",
inherit = Measure,
inherit = MeasureRegr,
public = list(
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
Expand All @@ -40,10 +40,10 @@ MeasureRegrRSQ = R6Class("MeasureRSQ",

super$initialize(
id = "rsq",
task_type = "regr",
properties = if (!private$.pred_set_mean) c("requires_task", "requires_train_set") else character(0),
predict_type = "response",
minimize = FALSE,
range = c(-Inf, 1),
man = "mlr3::mlr_measures_regr.rsq"
)
}
Expand Down
2 changes: 2 additions & 0 deletions R/as_prediction.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#' @return [Prediction].
#' @export
as_prediction = function(x, check = FALSE, ...) {
if (is.null(x)) return(list())

UseMethod("as_prediction")
}

Expand Down
4 changes: 4 additions & 0 deletions R/worker.R
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,10 @@ workhorse = function(iteration, task, learner, resampling, param_values = NULL,
lg$debug("Creating Prediction for predict set '%s'", set)
learner_predict(learner, task, row_ids)
}, set = predict_sets, row_ids = pred_data$sets, task = pred_data$tasks)

if (!length(predict_sets)) {
learner$state$predict_time = 0L
}
pdatas = discard(pdatas, is.null)

# set the model slot after prediction so it can be sent back to the main process
Expand Down
6 changes: 3 additions & 3 deletions man/mlr_measures_regr.rsq.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 13 additions & 0 deletions tests/testthat/test_benchmark.R
Original file line number Diff line number Diff line change
Expand Up @@ -567,3 +567,16 @@ test_that("predictions retrieved with as.data.table and predictions method are e
predictions = unlist(map(bmr$resample_results$resample_result, function(rr) rr$predictions(predict_sets = "train")), recursive = FALSE)
expect_equal(tab$prediction, predictions)
})

test_that("score works with predictions and empty predictions", {
learner_1 = lrn("classif.rpart", predict_sets = "train", id = "learner_1")
learner_2 = lrn("classif.rpart", predict_sets = "test", id = "learner_2")
task = tsk("pima")

design = benchmark_grid(task, list(learner_1, learner_2), rsmp("holdout"))

bmr = benchmark(design)

expect_warning({tab = bmr$score(msr("classif.ce", predict_sets = "test"))}, "Measure")
expect_equal(tab$classif.ce[1], NaN)
})
13 changes: 13 additions & 0 deletions tests/testthat/test_resample.R
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,9 @@ test_that("resample result works with not predicted predict set", {
tab = as.data.table(rr)
expect_list(tab$prediction, len = 1)
expect_list(tab$prediction[[1]], len = 0)

expect_warning({tab = rr$score(msr("classif.ce", predict_sets = "test"))}, "Measure")
expect_equal(tab$classif.ce, NaN)
})

test_that("resample results works with no predicted predict set", {
Expand All @@ -489,4 +492,14 @@ test_that("resample results works with no predicted predict set", {
tab = as.data.table(rr)
expect_list(tab$prediction, len = 1)
expect_list(tab$prediction[[1]], len = 0)

expect_warning({tab = rr$score(msr("classif.ce", predict_sets = "test"))}, "Measure")
expect_equal(tab$classif.ce, NaN)
})

test_that("predict_time is 0 if no predict_set is specified", {
learner = lrn("classif.featureless", predict_sets = NULL)
rr = resample(task, learner, resampling)
times = rr$score(msr("time_predict"))$time_predict
expect_true(all(times == 0))
})

0 comments on commit 52101aa

Please sign in to comment.