From d14598af3ccb2504532421290ffd5f2188d87a73 Mon Sep 17 00:00:00 2001 From: be-marc Date: Fri, 1 Nov 2024 11:00:19 +0100 Subject: [PATCH] chore: check if resampling was instantiated on task --- R/benchmark.R | 4 ++++ R/resample.R | 5 +++++ tests/testthat/test_benchmark.R | 12 ++++++++++++ tests/testthat/test_resample.R | 10 ++++++++++ 4 files changed, 31 insertions(+) diff --git a/R/benchmark.R b/R/benchmark.R index 5f26c38a6..c132e1fd9 100644 --- a/R/benchmark.R +++ b/R/benchmark.R @@ -128,6 +128,10 @@ benchmark = function(design, store_models = FALSE, store_backends = TRUE, encaps # learner = assert_learner(as_learner(learner, clone = TRUE)) assert_learnable(task, learner) + if (resampling$task_hash != task$hash) { + stopf("Resampling '%s' was not instantiated with task '%s'", resampling$id, task$id) + } + iters = resampling$iters n_params = max(1L, length(param_values)) # insert constant values diff --git a/R/resample.R b/R/resample.R index 0ef5b41fe..6d455d925 100644 --- a/R/resample.R +++ b/R/resample.R @@ -70,6 +70,11 @@ resample = function(task, learner, resampling, store_models = FALSE, store_backe if (!resampling$is_instantiated) { resampling = resampling$instantiate(task) } + + if (resampling$task_hash != task$hash) { + stopf("Resampling '%s' was not instantiated with task '%s'", resampling$id, task$id) + } + n = resampling$iters pb = if (isNamespaceLoaded("progressr")) { # NB: the progress bar needs to be created in this env diff --git a/tests/testthat/test_benchmark.R b/tests/testthat/test_benchmark.R index d6e4a3e94..d4f712e89 100644 --- a/tests/testthat/test_benchmark.R +++ b/tests/testthat/test_benchmark.R @@ -581,3 +581,15 @@ test_that("score works with predictions and empty predictions", { expect_warning({tab = bmr$score(msr("classif.ce", predict_sets = "test"))}, "Measure") expect_equal(tab$classif.ce[1], NaN) }) + +test_that("resampling was instantiated on the task", { + learner = lrn("classif.rpart") + task = tsk("pima") + resampling = rsmp("cv", folds = 5) + resampling$instantiate(task) + task = tsk("spam") + + design = data.table(task = list(task), learner = list(learner), resampling = list(resampling)) + + expect_error(benchmark(design), "not instantiated") +}) diff --git a/tests/testthat/test_resample.R b/tests/testthat/test_resample.R index 77b047ad3..beae14fdb 100644 --- a/tests/testthat/test_resample.R +++ b/tests/testthat/test_resample.R @@ -510,3 +510,13 @@ test_that("predict_time is 0 if no predict_set is specified", { times = rr$score(msr("time_predict"))$time_predict expect_true(all(times == 0)) }) + +test_that("resampling was instantiated on the task", { + learner = lrn("classif.rpart") + task = tsk("pima") + resampling = rsmp("cv", folds = 5) + resampling$instantiate(task) + task = tsk("spam") + + expect_error(resample(task, learner, resampling), "not instantiated") +})