Skip to content

Commit

Permalink
BREAKING CHANGE: remove $loglik method from all learners (#1207)
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc authored Nov 13, 2024
1 parent 26376c3 commit 7c3c74b
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 15 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

* fix: Quantiles must not ascend with probabilities.
* refactor: Replace `tsk("boston_housing")` with `tsk("california_housing")`.
* BREAKING CHANGE: Remove ``$loglik()`` method from all learners.

# mlr3 0.21.1

Expand Down
3 changes: 0 additions & 3 deletions R/Learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,6 @@
#' * `oob_error(...)`: Returns the out-of-bag error of the model as `numeric(1)`.
#' The learner must be tagged with property `"oob_error"`.
#'
#' * `loglik(...)`: Extracts the log-likelihood (c.f. [stats::logLik()]).
#' This can be used in measures like [mlr_measures_aic] or [mlr_measures_bic].
#'
#' * `internal_valid_scores`: Returns the internal validation score(s) of the model as a named `list()`.
#' Only available for [`Learner`]s with the `"validation"` property.
#' If the learner is not trained yet, this returns `NULL`.
Expand Down
12 changes: 7 additions & 5 deletions R/MeasureAIC.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,14 @@ MeasureAIC = R6Class("MeasureAIC",
private = list(
.score = function(prediction, learner, ...) {
learner = learner$base_learner()
if ("loglik" %nin% learner$properties) {
return(NA_real_)
}

k = self$param_set$values$k %??% 2
return(stats::AIC(learner$loglik(), k = k))

tryCatch({
return(stats::AIC(stats::logLik(learner$model), k = k))
}, error = function(e) {
warningf("Learner '%s' does not support AIC calculation", learner$id)
return(NA_real_)
})
}
)
)
Expand Down
10 changes: 6 additions & 4 deletions R/MeasureBIC.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,13 @@ MeasureBIC = R6Class("MeasureBIC",
private = list(
.score = function(prediction, learner, ...) {
learner = learner$base_learner()
if ("loglik" %nin% learner$properties) {
return(NA_real_)
}

return(stats::BIC(learner$loglik()))
tryCatch({
return(stats::BIC(stats::logLik(learner$model)))
}, error = function(e) {
warningf("Learner '%s' does not support BIC calculation", learner$id)
return(NA_real_)
})
}
)
)
Expand Down
2 changes: 1 addition & 1 deletion R/mlr_reflections.R
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ local({
)

### Learner
tmp = c("featureless", "missings", "weights", "importance", "selected_features", "oob_error", "loglik", "hotstart_forward", "hotstart_backward", "validation", "internal_tuning", "marshal")
tmp = c("featureless", "missings", "weights", "importance", "selected_features", "oob_error", "hotstart_forward", "hotstart_backward", "validation", "internal_tuning", "marshal")
mlr_reflections$learner_properties = list(
classif = c(tmp, "twoclass", "multiclass"),
regr = tmp
Expand Down
2 changes: 0 additions & 2 deletions man/Learner.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 7c3c74b

Please sign in to comment.