Skip to content

Commit

Permalink
fix: measure checks
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc committed Nov 20, 2024
1 parent b95228a commit b4874a3
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 13 deletions.
17 changes: 5 additions & 12 deletions R/Measure.R
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,9 @@ Measure = R6Class("Measure",
assert_measure(self, task = task, learner = learner, prediction = prediction)
assert_prediction(prediction, null.ok = "requires_no_prediction" %nin% self$properties)

# check should be added to assert_measure()
# except when the checks are superfluous for rr$score() and bmr$score()
# these checks should be added bellow
if ("requires_task" %in% self$properties && is.null(task)) {
stopf("Measure '%s' requires a task", self$id)
}
Expand All @@ -205,21 +208,14 @@ Measure = R6Class("Measure",
stopf("Measure '%s' requires a learner", self$id)
}

if ("requires_model" %in% self$properties && (is.null(learner) || is.null(learner$model))) {
stopf("Measure '%s' requires the trained model", self$id)
}
if ("requires_model" %in% self$properties && is_marshaled_model(learner$model)) {
stopf("Measure '%s' requires the trained model, but model is in marshaled form", self$id)
if (!is_scalar_na(self$task_type) && self$task_type != prediction$task_type) {
stopf("Measure '%s' incompatible with task type '%s'", self$id, prediction$task_type)
}

if ("requires_train_set" %in% self$properties && is.null(train_set)) {
stopf("Measure '%s' requires the train_set", self$id)
}

if (!is_scalar_na(self$task_type) && self$task_type != prediction$task_type) {
stopf("Measure '%s' incompatible with task type '%s'", self$id, prediction$task_type)
}

score_single_measure(self, task, learner, train_set, prediction)
},

Expand Down Expand Up @@ -358,8 +354,6 @@ score_single_measure = function(measure, task, learner, train_set, prediction) {
return(NaN)
}



if (!is_scalar_na(measure$predict_type) && measure$predict_type %nin% prediction$predict_types) {
# TODO lgr$debug()
return(NaN)
Expand All @@ -370,7 +364,6 @@ score_single_measure = function(measure, task, learner, train_set, prediction) {
return(NaN)
}


get_private(measure)$.score(prediction = prediction, task = task, learner = learner, train_set = train_set)
}

Expand Down
2 changes: 1 addition & 1 deletion R/MeasureInternalValidScore.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ MeasureInternalValidScore = R6Class("MeasureInternalValidScore",
super$initialize(
id = select %??% "internal_valid_score",
task_type = NA_character_,
properties = c("na_score", "requires_model", "requires_learner", "requires_no_prediction"),
properties = c("na_score", "requires_learner", "requires_no_prediction"),
predict_sets = NULL,
predict_type = NA_character_,
range = c(-Inf, Inf),
Expand Down
10 changes: 10 additions & 0 deletions R/assertions.R
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ assert_measure = function(measure, task = NULL, learner = NULL, prediction = NUL
assert_class(measure, "Measure", .var.name = .var.name)

if (!is.null(task)) {

if (!is_scalar_na(measure$task_type) && !test_matching_task_type(task$task_type, measure, "measure")) {
stopf("Measure '%s' is not compatible with type '%s' of task '%s'",
measure$id, task$task_type, task$id)
Expand All @@ -221,6 +222,15 @@ assert_measure = function(measure, task = NULL, learner = NULL, prediction = NUL
}

if (!is.null(learner)) {

if ("requires_model" %in% measure$properties && is.null(learner$model)) {
stopf("Measure '%s' requires the trained model", measure$id)
}

if ("requires_model" %in% measure$properties && is_marshaled_model(learner$model)) {
stopf("Measure '%s' requires the trained model, but model is in marshaled form", measure$id)
}

if (!is_scalar_na(measure$task_type) && measure$task_type != learner$task_type) {
stopf("Measure '%s' is not compatible with type '%s' of learner '%s'",
measure$id, learner$task_type, learner$id)
Expand Down
19 changes: 19 additions & 0 deletions tests/testthat/_snaps/Task.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# $characteristics works

Code
task
Output
<TaskClassif:spam> (4601 x 58): HP Spam Detection
* Target: type
* Properties: twoclass
* Features (57):
- dbl (57): address, addresses, all, business, capitalAve,
capitalLong, capitalTotal, charDollar, charExclamation, charHash,
charRoundbracket, charSemicolon, charSquarebracket, conference,
credit, cs, data, direct, edu, email, font, free, george, hp, hpl,
internet, lab, labs, mail, make, meeting, money, num000, num1999,
num3d, num415, num650, num85, num857, order, original, our, over,
parts, people, pm, project, re, receive, remove, report, table,
technology, telnet, will, you, your
* Characteristics: foo=1, bar=a

5 changes: 5 additions & 0 deletions tests/testthat/test_resample.R
Original file line number Diff line number Diff line change
Expand Up @@ -520,3 +520,8 @@ test_that("resampling instantiated on a different task throws an error", {
expect_error(resample(tsk("pima"), lrn("classif.rpart"), resampling), "The resampling was probably instantiated on a different task")

})

test_that("$score() checks for models", {
rr = resample(tsk("mtcars"), lrn("regr.debug"), rsmp("holdout"))
expect_error(rr$score(msr("aic")), "requires the trained model")
})

0 comments on commit b4874a3

Please sign in to comment.