Skip to content

Commit

Permalink
Merge branch 'main' into cli
Browse files Browse the repository at this point in the history
  • Loading branch information
lona-k committed Nov 25, 2024
2 parents 5126f77 + 282b53a commit e01119d
Show file tree
Hide file tree
Showing 20 changed files with 163 additions and 35 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ jobs:

- name: Deploy
if: github.event_name != 'pull_request'
uses: JamesIves/github-pages-deploy-action@v4.6.8
uses: JamesIves/github-pages-deploy-action@v4.6.9
with:
clean: false
branch: gh-pages
Expand Down
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: mlr3
Title: Machine Learning in R - Next Generation
Version: 0.21.1.9000
Version: 0.22.0.9000
Authors@R:
c(
person("Michel", "Lang", , "[email protected]", role = "aut",
Expand Down
6 changes: 6 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
# mlr3 (development version)

# mlr3 0.22.0

* fix: Quantiles must not ascend with probabilities.
* refactor: Replace `tsk("boston_housing")` with `tsk("california_housing")`.
* feat: Require unique learner ids in `benchmark_grid()`.
* BREAKING CHANGE: Remove ``$loglik()`` method from all learners.
* fix: Ignore `future.globals.maxSize` when `future::plan("sequential")` is used.
* feat: Add `$characteristics` field to `Task` to store additional information.

# mlr3 0.21.1

Expand Down
14 changes: 11 additions & 3 deletions R/BenchmarkResult.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#' @template param_measures
#'
#' @section S3 Methods:
#' * `as.data.table(rr, ..., reassemble_learners = TRUE, convert_predictions = TRUE, predict_sets = "test")`\cr
#' * `as.data.table(rr, ..., reassemble_learners = TRUE, convert_predictions = TRUE, predict_sets = "test", task_characteristics = FALSE)`\cr
#' [BenchmarkResult] -> [data.table::data.table()]\cr
#' Returns a tabular view of the internal data.
#' * `c(...)`\cr
Expand Down Expand Up @@ -544,9 +544,17 @@ BenchmarkResult = R6Class("BenchmarkResult",
)

#' @export
as.data.table.BenchmarkResult = function(x, ..., hashes = FALSE, predict_sets = "test") { # nolint
as.data.table.BenchmarkResult = function(x, ..., hashes = FALSE, predict_sets = "test", task_characteristics = FALSE) { # nolint
assert_flag(task_characteristics)
tab = get_private(x)$.data$as_data_table(view = NULL, predict_sets = predict_sets)
tab[, c("uhash", "task", "learner", "resampling", "iteration", "prediction"), with = FALSE]
tab = tab[, c("uhash", "task", "learner", "resampling", "iteration", "prediction"), with = FALSE]

if (task_characteristics) {
set(tab, j = "characteristics", value = map(tab$task, "characteristics"))
tab = unnest(tab, "characteristics")
}

tab[]
}

#' @export
Expand Down
3 changes: 0 additions & 3 deletions R/Learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,6 @@
#' * `oob_error(...)`: Returns the out-of-bag error of the model as `numeric(1)`.
#' The learner must be tagged with property `"oob_error"`.
#'
#' * `loglik(...)`: Extracts the log-likelihood (c.f. [stats::logLik()]).
#' This can be used in measures like [mlr_measures_aic] or [mlr_measures_bic].
#'
#' * `internal_valid_scores`: Returns the internal validation score(s) of the model as a named `list()`.
#' Only available for [`Learner`]s with the `"validation"` property.
#' If the learner is not trained yet, this returns `NULL`.
Expand Down
17 changes: 5 additions & 12 deletions R/Measure.R
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,9 @@ Measure = R6Class("Measure",
assert_measure(self, task = task, learner = learner, prediction = prediction)
assert_prediction(prediction, null.ok = "requires_no_prediction" %nin% self$properties)

# checks should be added to assert_measure(),
# except when they are superfluous for rr$score() and bmr$score();
# such checks should be added below
if ("requires_task" %in% self$properties && is.null(task)) {
stopf("Measure '%s' requires a task", self$id)
}
Expand All @@ -206,21 +209,14 @@ Measure = R6Class("Measure",
stopf("Measure '%s' requires a learner", self$id)
}

if ("requires_model" %in% self$properties && (is.null(learner) || is.null(learner$model))) {
stopf("Measure '%s' requires the trained model", self$id)
}
if ("requires_model" %in% self$properties && is_marshaled_model(learner$model)) {
stopf("Measure '%s' requires the trained model, but model is in marshaled form", self$id)
if (!is_scalar_na(self$task_type) && self$task_type != prediction$task_type) {
stopf("Measure '%s' incompatible with task type '%s'", self$id, prediction$task_type)
}

if ("requires_train_set" %in% self$properties && is.null(train_set)) {
stopf("Measure '%s' requires the train_set", self$id)
}

if (!is_scalar_na(self$task_type) && self$task_type != prediction$task_type) {
stopf("Measure '%s' incompatible with task type '%s'", self$id, prediction$task_type)
}

score_single_measure(self, task, learner, train_set, prediction)
},

Expand Down Expand Up @@ -359,8 +355,6 @@ score_single_measure = function(measure, task, learner, train_set, prediction) {
return(NaN)
}



if (!is_scalar_na(measure$predict_type) && measure$predict_type %nin% prediction$predict_types) {
# TODO lgr$debug()
return(NaN)
Expand All @@ -371,7 +365,6 @@ score_single_measure = function(measure, task, learner, train_set, prediction) {
return(NaN)
}


get_private(measure)$.score(prediction = prediction, task = task, learner = learner, train_set = train_set)
}

Expand Down
12 changes: 7 additions & 5 deletions R/MeasureAIC.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,14 @@ MeasureAIC = R6Class("MeasureAIC",
private = list(
.score = function(prediction, learner, ...) {
learner = learner$base_learner()
if ("loglik" %nin% learner$properties) {
return(NA_real_)
}

k = self$param_set$values$k %??% 2
return(stats::AIC(learner$loglik(), k = k))

tryCatch({
return(stats::AIC(stats::logLik(learner$model), k = k))
}, error = function(e) {
warningf("Learner '%s' does not support AIC calculation", learner$id)
return(NA_real_)
})
}
)
)
Expand Down
10 changes: 6 additions & 4 deletions R/MeasureBIC.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,13 @@ MeasureBIC = R6Class("MeasureBIC",
private = list(
.score = function(prediction, learner, ...) {
learner = learner$base_learner()
if ("loglik" %nin% learner$properties) {
return(NA_real_)
}

return(stats::BIC(learner$loglik()))
tryCatch({
return(stats::BIC(stats::logLik(learner$model)))
}, error = function(e) {
warningf("Learner '%s' does not support BIC calculation", learner$id)
return(NA_real_)
})
}
)
)
Expand Down
2 changes: 1 addition & 1 deletion R/MeasureInternalValidScore.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ MeasureInternalValidScore = R6Class("MeasureInternalValidScore",
super$initialize(
id = select %??% "internal_valid_score",
task_type = NA_character_,
properties = c("na_score", "requires_model", "requires_learner", "requires_no_prediction"),
properties = c("na_score", "requires_learner", "requires_no_prediction"),
predict_sets = NULL,
predict_type = NA_character_,
range = c(-Inf, Inf),
Expand Down
16 changes: 16 additions & 0 deletions R/Task.R
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,10 @@ Task = R6Class("Task",
if (!is.null(private$.internal_valid_task)) {
cli_li(sprintf("Validation Task: (%ix%i)", private$.internal_valid_task$nrow, private$.internal_valid_task$ncol))
}

if (!is.null(self$characteristics)) {
catf(str_indent("* Characteristics: ", as_short_string(self$characteristics)))
}
},

#' @description
Expand Down Expand Up @@ -1137,6 +1141,17 @@ Task = R6Class("Task",
private$.col_hashes = self$backend$col_hashes[setdiff(unlist(private$.col_roles, use.names = FALSE), self$backend$primary_key)]
}
private$.col_hashes
},

#' @field characteristics (`list()`)\cr
#' Additional information about the task stored as a list, e.g. `list(n = 5, p = 7)`.
#' May be `NULL`. Assigning a new value also resets the cached task hash.
characteristics = function(rhs) {
  if (!missing(rhs)) {
    # setter: validate the new value, store it and invalidate the cached hash
    private$.characteristics = assert_list(rhs, null.ok = TRUE)
    private$.hash = NULL
  } else {
    # getter: hand back whatever is currently stored (possibly NULL)
    private$.characteristics
  }
}
),

Expand All @@ -1148,6 +1163,7 @@ Task = R6Class("Task",
.row_roles = NULL,
.hash = NULL,
.col_hashes = NULL,
.characteristics = NULL,

deep_clone = function(name, value) {
# NB: DataBackends are never copied!
Expand Down
10 changes: 10 additions & 0 deletions R/assertions.R
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ assert_measure = function(measure, task = NULL, learner = NULL, prediction = NUL
assert_class(measure, "Measure", .var.name = .var.name)

if (!is.null(task)) {

if (!is_scalar_na(measure$task_type) && !test_matching_task_type(task$task_type, measure, "measure")) {
stopf("Measure '%s' is not compatible with type '%s' of task '%s'",
measure$id, task$task_type, task$id)
Expand All @@ -221,6 +222,15 @@ assert_measure = function(measure, task = NULL, learner = NULL, prediction = NUL
}

if (!is.null(learner)) {

if ("requires_model" %in% measure$properties && is.null(learner$model)) {
stopf("Measure '%s' requires the trained model", measure$id)
}

if ("requires_model" %in% measure$properties && is_marshaled_model(learner$model)) {
stopf("Measure '%s' requires the trained model, but model is in marshaled form", measure$id)
}

if (!is_scalar_na(measure$task_type) && measure$task_type != learner$task_type) {
stopf("Measure '%s' is not compatible with type '%s' of learner '%s'",
measure$id, learner$task_type, learner$id)
Expand Down
7 changes: 7 additions & 0 deletions R/helper_exec.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ future_map = function(n, FUN, ..., MoreArgs = list()) {
}
stdout = if (is_sequential) NA else TRUE

# workaround for sequential plan checking the size of the globals
# see https://github.com/futureverse/future/issues/197
if (is_sequential) {
old_opts = options(future.globals.maxSize = Inf)
on.exit(options(old_opts), add = TRUE)
}

MoreArgs = c(MoreArgs, list(is_sequential = is_sequential))

lg$debug("Running resample() via future with %i iterations", n)
Expand Down
3 changes: 2 additions & 1 deletion R/helper_hashes.R
Original file line number Diff line number Diff line change
Expand Up @@ -48,5 +48,6 @@ task_hash = function(task, use_ids, test_ids = NULL, ignore_internal_valid_task
use_ids,
task$col_roles,
get_private(task)$.properties,
internal_valid_task_hash)
internal_valid_task_hash,
task$characteristics)
}
2 changes: 1 addition & 1 deletion R/mlr_reflections.R
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ local({
)

### Learner
tmp = c("featureless", "missings", "weights", "importance", "selected_features", "oob_error", "loglik", "hotstart_forward", "hotstart_backward", "validation", "internal_tuning", "marshal")
tmp = c("featureless", "missings", "weights", "importance", "selected_features", "oob_error", "hotstart_forward", "hotstart_backward", "validation", "internal_tuning", "marshal")
mlr_reflections$learner_properties = list(
classif = c(tmp, "twoclass", "multiclass"),
regr = tmp
Expand Down
2 changes: 1 addition & 1 deletion man/BenchmarkResult.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 0 additions & 2 deletions man/Learner.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions man/Task.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 19 additions & 0 deletions tests/testthat/_snaps/Task.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# $characteristics works

Code
task
Output
<TaskClassif:spam> (4601 x 58): HP Spam Detection
* Target: type
* Properties: twoclass
* Features (57):
- dbl (57): address, addresses, all, business, capitalAve,
capitalLong, capitalTotal, charDollar, charExclamation, charHash,
charRoundbracket, charSemicolon, charSquarebracket, conference,
credit, cs, data, direct, edu, email, font, free, george, hp, hpl,
internet, lab, labs, mail, make, meeting, money, num000, num1999,
num3d, num415, num650, num85, num857, order, original, our, over,
parts, people, pm, project, re, receive, remove, report, table,
technology, telnet, will, you, your
* Characteristics: foo=1, bar=a

61 changes: 61 additions & 0 deletions tests/testthat/test_Task.R
Original file line number Diff line number Diff line change
Expand Up @@ -668,3 +668,64 @@ test_that("$select changes hash", {
h2 = task$hash
expect_false(h1 == h2)
})

test_that("$characteristics works", {
  # the field round-trips and shows up in the printed task
  task = tsk("spam")
  characteristics = list(foo = 1, bar = "a")
  task$characteristics = characteristics

  expect_snapshot(task)
  expect_equal(task$characteristics, characteristics)

  # two otherwise identical tasks with different characteristics must hash differently
  tsk_1 = tsk("spam")
  tsk_2 = tsk("spam")
  tsk_1$characteristics = list(n = 300)
  tsk_2$characteristics = list(n = 200)

  expect_true(tsk_1$hash != tsk_2$hash)

  learner = lrn("classif.rpart")
  resampling = rsmp("cv", folds = 3)

  # benchmark both tasks in their current state and return the table
  # with the task characteristics unnested into columns
  benchmark_table = function() {
    design = benchmark_grid(
      tasks = list(tsk_1, tsk_2),
      learners = learner,
      resamplings = resampling
    )
    as.data.table(benchmark(design), task_characteristics = TRUE)
  }

  # single characteristic per task
  tab = benchmark_table()
  expect_names(names(tab), must.include = "n")
  expect_subset(tab$n, c(300, 200))

  # two characteristics per task
  tsk_1$characteristics = list(n = 300, f = 3)
  tsk_2$characteristics = list(n = 200, f = 2)

  tab = benchmark_table()
  expect_names(names(tab), must.include = c("n", "f"))
  expect_subset(tab$n, c(300, 200))
  expect_subset(tab$f, c(2, 3))

  # characteristic "f" is only set on the first task; NA is accepted for the other
  tsk_1$characteristics = list(n = 300, f = 2)
  tsk_2$characteristics = list(n = 200)

  tab = benchmark_table()

  expect_names(names(tab), must.include = c("n", "f"))
  expect_subset(tab$n, c(300, 200))
  expect_subset(tab$f, c(2, NA_real_))
})
5 changes: 5 additions & 0 deletions tests/testthat/test_resample.R
Original file line number Diff line number Diff line change
Expand Up @@ -520,3 +520,8 @@ test_that("resampling instantiated on a different task throws an error", {
expect_error(resample(tsk("pima"), lrn("classif.rpart"), resampling), "The resampling was probably instantiated on a different task")

})

test_that("$score() checks for models", {
  # scoring with a measure that needs the trained model must raise an error here
  # NOTE(review): presumably resample() does not keep the models by default - confirm
  task = tsk("mtcars")
  learner = lrn("regr.debug")
  rr = resample(task, learner, rsmp("holdout"))

  expect_error(
    rr$score(msr("aic")),
    "requires the trained model"
  )
})

0 comments on commit e01119d

Please sign in to comment.