Skip to content

Commit

Permalink
Merge branch 'main' into callback
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc committed Dec 19, 2024
2 parents 9be497f + aae5342 commit 2b3dc85
Show file tree
Hide file tree
Showing 26 changed files with 151 additions and 42 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ jobs:

- name: Deploy
if: github.event_name != 'pull_request'
uses: JamesIves/github-pages-deploy-action@v4.6.9
uses: JamesIves/github-pages-deploy-action@v4.7.2
with:
clean: false
branch: gh-pages
Expand Down
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ Suggests:
remotes,
RhpcBLASctl,
rpart,
testthat (>= 3.1.0)
testthat (>= 3.2.0)
Encoding: UTF-8
Config/testthat/edition: 3
Config/testthat/parallel: false
Expand Down
44 changes: 44 additions & 0 deletions R/Learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,50 @@ Learner = R6Class("Learner",
private$.encapsulation = c(train = method, predict = method)
private$.fallback = fallback

return(invisible(self))
},

#' @description
#' Sets parameter values and fields of the learner.
#' All arguments whose names match the name of a parameter of the [paradox::ParamSet] are set as parameters.
#' All remaining arguments are assumed to be regular fields.
#'
#' @param ... (named `any`)\cr
#' Named arguments to set parameter values and fields.
#' @param .values (named `any`)\cr
#' Named list of parameter values and fields.
configure = function(..., .values = list()) {
dots = list(...)
assert_list(dots, names = "unique")
assert_list(.values, names = "unique")
assert_disjunct(names(dots), names(.values))
new_values = insert_named(dots, .values)

# set params in ParamSet
if (length(new_values)) {
param_ids = self$param_set$ids()
ii = names(new_values) %in% param_ids
if (any(ii)) {
self$param_set$values = insert_named(self$param_set$values, new_values[ii])
new_values = new_values[!ii]
}
} else {
param_ids = character()
}

# remaining args go into fields
if (length(new_values)) {
ndots = names(new_values)
for (i in seq_along(new_values)) {
nn = ndots[[i]]
if (!exists(nn, envir = self, inherits = FALSE)) {
stopf("Cannot set argument '%s' for '%s' (not a parameter, not a field).%s",
nn, class(self)[1L], did_you_mean(nn, c(param_ids, setdiff(names(self), ".__enclos_env__")))) # nolint
}
self[[nn]] = new_values[[i]]
}
}

return(invisible(self))
}
),
Expand Down
2 changes: 1 addition & 1 deletion inst/testthat/helper_expectations.R
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ expect_measure = function(m) {
testthat::expect_output(print(m), "Measure")

if ("requires_no_prediction" %in% m$properties) {
testthat::expect_true(is.null(m$predict_sets))
testthat::expect_null(m$predict_sets)
}

expect_id(m$id)
Expand Down
24 changes: 24 additions & 0 deletions man/Learner.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/LearnerClassif.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/LearnerRegr.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_learners_classif.debug.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_learners_classif.featureless.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_learners_classif.rpart.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_learners_regr.debug.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_learners_regr.featureless.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_learners_regr.rpart.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion tests/testthat/teardown.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
options(old_opts)
lg$set_threshold(old_threshold)
future::plan(old_plan)
file.remove("tests/testthat/Rplots.pdf")
6 changes: 3 additions & 3 deletions tests/testthat/test_DataBackendRbind.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ test_that("DataBackendRbind", {


# all col-hashes are mutually disjoint
expect_true(length(intersect(b1$col_hashes, b2$col_hashes)) == 0)
expect_true(length(intersect(b$col_hashes, b1$col_hashes)) == 0)
expect_true(length(intersect(b$col_hashes, b2$col_hashes)) == 0)
expect_length(intersect(b1$col_hashes, b2$col_hashes), 0)
expect_length(intersect(b$col_hashes, b1$col_hashes), 0)
expect_length(intersect(b$col_hashes, b2$col_hashes), 0)

})

Expand Down
56 changes: 45 additions & 11 deletions tests/testthat/test_Learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ test_that("Learners are called with invoke / small footprint of call", {
learner$train(task)
call = as.character(learner$model$call)
expect_character(call, min.len = 1L, any.missing = FALSE)
expect_true(any(grepl("task$formula()", call, fixed = TRUE)))
expect_true(any(grepl("task$data", call, fixed = TRUE)))
expect_match(call, "task$formula()", fixed = TRUE, all = FALSE)
expect_match(call, "task$data", fixed = TRUE, all = FALSE)
expect_lt(sum(nchar(call)), 1000)
})

Expand Down Expand Up @@ -236,7 +236,7 @@ test_that("empty predict set (#421)", {
learner$train(task, hout$train_set(1))
pred = learner$predict(task, hout$test_set(1))
expect_prediction(pred)
expect_true(any(grepl("No data to predict on", learner$log$msg)))
expect_match(learner$log$msg, "No data to predict on", all = FALSE)
})

test_that("fallback learner is deep cloned (#511)", {
Expand Down Expand Up @@ -330,7 +330,7 @@ test_that("validation task's backend is removed", {
task = tsk("mtcars")
task$internal_valid_task = 1:10
learner$train(task)
expect_true(is.null(learner$state$train_task$internal_valid_task$backend))
expect_null(learner$state$train_task$internal_valid_task$backend)
})

test_that("manual $train() stores validation hash and validation ids", {
Expand All @@ -348,7 +348,7 @@ test_that("manual $train() stores validation hash and validation ids", {
# nothing is stored for learners that don't do it
l2 = lrn("classif.featureless")
l2$train(task)
expect_true(is.null(l2$state$internal_valid_task_hash))
expect_null(l2$state$internal_valid_task_hash)
})

test_that("error when training a learner that sets valiadte to 'predefined' on a task without a validation task", {
Expand Down Expand Up @@ -421,15 +421,15 @@ test_that("internal_valid_task is created correctly", {
task$internal_valid_task = partition(task)$test
learner$train(task)
learner$validate = NULL
expect_true(is.null(learner$internal_valid_scores))
expect_true(is.null(learner$task$internal_valid_task))
expect_null(learner$internal_valid_scores)
expect_null(learner$task$internal_valid_task)

# validate = NULL (but task has none)
learner1 = LearnerClassifTest$new()
task1 = tsk("iris")
learner1$train(task1)
expect_true(is.null(learner1$internal_valid_scores))
expect_true(is.null(learner1$task$internal_valid_task))
expect_null(learner1$internal_valid_scores)
expect_null(learner1$task$internal_valid_task)

# validate = "test"
LearnerClassifTest2 = R6Class("LearnerClassifTest2", inherit = LearnerClassifDebug,
Expand All @@ -455,7 +455,7 @@ test_that("internal_valid_task is created correctly", {
resampling = rsmp("holdout")$instantiate(task2)
learner2$expected_valid_ids = resampling$test_set(1)
learner2$expected_train_ids = resampling$train_set(1)
expect_error(resample(task2, learner2, resampling), regexp = NA)
expect_no_error(resample(task2, learner2, resampling))

# ratio works
LearnerClassifTest3 = R6Class("LearnerClassifTest3", inherit = LearnerClassifDebug,
Expand All @@ -477,7 +477,7 @@ test_that("internal_valid_task is created correctly", {
learner4 = lrn("classif.debug", validate = 0.2)
task = tsk("iris")
learner4$train(task)
expect_true(is.null(task$internal_valid_task))
expect_null(task$internal_valid_task)
})

test_that("compatability check on validation task", {
Expand Down Expand Up @@ -629,3 +629,37 @@ test_that("predict time is cumulative", {
t2 = learner$timings["predict"]
expect_true(t1 > t2)
})

test_that("configure method works", {
learner = lrn("classif.rpart")

expect_learner(learner$configure())
expect_learner(learner$configure(.values = list()))

# set new hyperparameter value
learner$configure(cp = 0.1)
expect_equal(learner$param_set$values$cp, 0.1)

# overwrite existing hyperparameter value
learner$configure(xval = 10)
expect_equal(learner$param_set$values$xval, 10)

# set field
learner$configure(predict_sets = "train")
expect_equal(learner$predict_sets, "train")

# hyperparameter and field
learner$configure(minbucket = 2, parallel_predict = TRUE)
expect_equal(learner$param_set$values$minbucket, 2)
expect_true(learner$parallel_predict)

# unknown hyperparameter and field
expect_error(learner$configure(xvald = 1), "Cannot set argument")

# use .values
learner = lrn("classif.rpart")
learner$configure(.values = list(cp = 0.1, xval = 10, predict_sets = "train"))
expect_equal(learner$param_set$values$cp, 0.1)
expect_equal(learner$param_set$values$xval, 10)
expect_equal(learner$predict_sets, "train")
})
4 changes: 2 additions & 2 deletions tests/testthat/test_Measure.R
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ test_that("primary iters are respected", {

jaccard = msr("sim.jaccard")
expect_error(rr1$aggregate(jaccard), "primary_iters")
expect_error(rr2$aggregate(jaccard), NA)
expect_no_error(rr2$aggregate(jaccard))
jaccard$properties = c(jaccard$properties, "primary_iters")
x1 = rr1$aggregate(jaccard)
x2 = rr3$aggregate(jaccard)
Expand All @@ -176,7 +176,7 @@ test_that("primary iters are respected", {

test_that("no predict_sets required (#1094)", {
m = msr("internal_valid_score")
expect_equal(m$predict_sets, NULL)
expect_null(m$predict_sets)
rr = resample(tsk("iris"), lrn("classif.debug", validate = 0.3, predict_sets = NULL), rsmp("holdout"))
expect_double(rr$aggregate(m))
expect_warning(rr$aggregate(msr("classif.ce")), "needs predict sets")
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test_Task.R
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,7 @@ test_that("internal_valid_task is printed", {
task = tsk("iris")
task$internal_valid_task = c(1:10, 51:60, 101:110)
out = capture_output(print(task))
expect_true(grepl(pattern = "* Validation Task: (30x5)", fixed = TRUE, x = out))
expect_match(out, "* Validation Task: (30x5)", fixed = TRUE)
})

test_that("task hashes during resample", {
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/test_as_learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ test_that("as_learner conversion", {
test_that("discard_state", {
learner = lrn("classif.rpart")$train(tsk("iris"))
learner2 = as_learner(learner, clone = TRUE, discard_state = TRUE)
expect_true(is.null(learner2$state))
expect_null(learner2$state)
expect_false(is.null(learner$state))

learner3 = lrn("classif.rpart")
as_learner(learner3, clone = FALSE, discard_state = TRUE)
expect_true(is.null(learner3$state))
expect_null(learner3$state)
})
Loading

0 comments on commit 2b3dc85

Please sign in to comment.