Skip to content

Commit

Permalink
Merge pull request #112 from mayer79/mlr3-simplification
Browse files Browse the repository at this point in the history
Mlr3 simplification
  • Loading branch information
mayer79 authored Dec 26, 2023
2 parents 69e798b + a86cd74 commit fc764de
Show file tree
Hide file tree
Showing 17 changed files with 23 additions and 270 deletions.
8 changes: 1 addition & 7 deletions .github/workflows/test-coverage.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,22 +33,16 @@ jobs:
clean = FALSE,
install_path = file.path(Sys.getenv("RUNNER_TEMP"), "package"),
function_exclusions = c(
"partial_dep\\.Learner",
"partial_dep\\.ranger",
"partial_dep\\.explainer",
"ice\\.Learner",
"ice\\.ranger",
"ice\\.explainer",
"hstats\\.Learner",
"hstats\\.ranger",
"hstats\\.explainer",
"perm_importance\\.Learner",
"perm_importance\\.ranger",
"perm_importance\\.explainer",
"average_loss\\.Learner",
"average_loss\\.ranger",
"average_loss\\.explainer",
"mlr3_pred_fun"
"average_loss\\.explainer"
)
)
shell: Rscript {0}
Expand Down
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: hstats
Title: Interaction Statistics
Version: 1.1.1
Version: 1.1.2
Authors@R:
person("Michael", "Mayer", , "[email protected]", role = c("aut", "cre"))
Description: Fast, model-agnostic implementation of different H-statistics
Expand Down
5 changes: 0 additions & 5 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

S3method("[",hstats_matrix)
S3method("dimnames<-",hstats_matrix)
S3method(average_loss,Learner)
S3method(average_loss,default)
S3method(average_loss,explainer)
S3method(average_loss,ranger)
Expand All @@ -16,21 +15,17 @@ S3method(h2_pairwise,default)
S3method(h2_pairwise,hstats)
S3method(h2_threeway,default)
S3method(h2_threeway,hstats)
S3method(hstats,Learner)
S3method(hstats,default)
S3method(hstats,explainer)
S3method(hstats,ranger)
S3method(ice,Learner)
S3method(ice,default)
S3method(ice,explainer)
S3method(ice,ranger)
S3method(partial_dep,Learner)
S3method(partial_dep,default)
S3method(partial_dep,explainer)
S3method(partial_dep,ranger)
S3method(pd_importance,default)
S3method(pd_importance,hstats)
S3method(perm_importance,Learner)
S3method(perm_importance,default)
S3method(perm_importance,explainer)
S3method(perm_importance,ranger)
Expand Down
7 changes: 7 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
# hstats 1.1.2

## API

- {mlr3}: Non-probabilistic classification now works.
- {mlr3}: For *probabilistic* classification, you now have to pass `predict_type = "prob"`.

# hstats 1.1.1

## Performance improvements
Expand Down
25 changes: 0 additions & 25 deletions R/average_loss.R
Original file line number Diff line number Diff line change
Expand Up @@ -135,31 +135,6 @@ average_loss.ranger <- function(object, X, y,
)
}

#' @describeIn average_loss Method for "mlr3" models.
#' @export
average_loss.Learner <- function(object, X, y,
                                 pred_fun = NULL,
                                 loss = "squared_error",
                                 agg_cols = FALSE,
                                 BY = NULL, by_size = 4L,
                                 w = NULL, ...) {
  # When no prediction function is supplied, derive one from the Learner.
  pf <- if (is.null(pred_fun)) mlr3_pred_fun(object, X = X) else pred_fun
  # Delegate all real work to the default method.
  average_loss.default(
    object = object, X = X, y = y, pred_fun = pf, loss = loss,
    agg_cols = agg_cols, BY = BY, by_size = by_size, w = w, ...
  )
}

#' @describeIn average_loss Method for DALEX "explainer".
#' @export
average_loss.explainer <- function(object,
Expand Down
28 changes: 0 additions & 28 deletions R/hstats.R
Original file line number Diff line number Diff line change
Expand Up @@ -300,34 +300,6 @@ hstats.ranger <- function(object, X, v = NULL,
)
}

#' @describeIn hstats Method for "mlr3" models.
#' @export
hstats.Learner <- function(object, X, v = NULL,
                           pred_fun = NULL,
                           pairwise_m = 5L, threeway_m = 0L,
                           approx = FALSE, grid_size = 50L,
                           n_max = 500L, eps = 1e-10,
                           w = NULL, verbose = TRUE, ...) {
  # When no prediction function is supplied, derive one from the Learner.
  pf <- if (is.null(pred_fun)) mlr3_pred_fun(object, X = X) else pred_fun
  # Delegate all real work to the default method.
  hstats.default(
    object = object, X = X, v = v, pred_fun = pf,
    pairwise_m = pairwise_m, threeway_m = threeway_m,
    approx = approx, grid_size = grid_size, n_max = n_max,
    eps = eps, w = w, verbose = verbose, ...
  )
}

#' @describeIn hstats Method for DALEX "explainer".
#' @export
hstats.explainer <- function(object, X = object[["data"]],
Expand Down
26 changes: 0 additions & 26 deletions R/ice.R
Original file line number Diff line number Diff line change
Expand Up @@ -173,32 +173,6 @@ ice.ranger <- function(object, v, X,
)
}

#' @describeIn ice Method for "mlr3" models.
#' @export
ice.Learner <- function(object, v, X,
                        pred_fun = NULL,
                        BY = NULL, grid = NULL, grid_size = 49L, trim = c(0.01, 0.99),
                        strategy = c("uniform", "quantile"), na.rm = TRUE,
                        n_max = 100L, ...) {
  # When no prediction function is supplied, derive one from the Learner.
  pf <- if (is.null(pred_fun)) mlr3_pred_fun(object, X = X) else pred_fun
  # Delegate all real work to the default method.
  ice.default(
    object = object, v = v, X = X, pred_fun = pf, BY = BY,
    grid = grid, grid_size = grid_size, trim = trim,
    strategy = strategy, na.rm = na.rm, n_max = n_max, ...
  )
}

#' @describeIn ice Method for DALEX "explainer".
#' @export
ice.explainer <- function(object, v = v, X = object[["data"]],
Expand Down
29 changes: 0 additions & 29 deletions R/partial_dep.R
Original file line number Diff line number Diff line change
Expand Up @@ -213,35 +213,6 @@ partial_dep.ranger <- function(object, v, X,
)
}

#' @describeIn partial_dep Method for "mlr3" models.
#' @export
partial_dep.Learner <- function(object, v, X,
                                pred_fun = NULL,
                                BY = NULL, by_size = 4L, grid = NULL, grid_size = 49L,
                                trim = c(0.01, 0.99),
                                strategy = c("uniform", "quantile"), na.rm = TRUE,
                                n_max = 1000L, w = NULL, ...) {
  # When no prediction function is supplied, derive one from the Learner.
  pf <- if (is.null(pred_fun)) mlr3_pred_fun(object, X = X) else pred_fun
  # Delegate all real work to the default method.
  partial_dep.default(
    object = object, v = v, X = X, pred_fun = pf, BY = BY,
    by_size = by_size, grid = grid, grid_size = grid_size,
    trim = trim, strategy = strategy, na.rm = na.rm,
    n_max = n_max, w = w, ...
  )
}

#' @describeIn partial_dep Method for DALEX "explainer".
#' @export
partial_dep.explainer <- function(object, v, X = object[["data"]],
Expand Down
28 changes: 0 additions & 28 deletions R/perm_importance.R
Original file line number Diff line number Diff line change
Expand Up @@ -228,34 +228,6 @@ perm_importance.ranger <- function(object, X, y, v = NULL,
)
}

#' @describeIn perm_importance Method for "mlr3" models.
#' @export
perm_importance.Learner <- function(object, X, y, v = NULL,
                                    pred_fun = NULL,
                                    loss = "squared_error", m_rep = 4L,
                                    agg_cols = FALSE,
                                    normalize = FALSE, n_max = 10000L,
                                    w = NULL, verbose = TRUE, ...) {
  # When no prediction function is supplied, derive one from the Learner.
  pf <- if (is.null(pred_fun)) mlr3_pred_fun(object, X = X) else pred_fun
  # Delegate all real work to the default method.
  perm_importance.default(
    object = object, X = X, y = y, v = v, pred_fun = pf,
    loss = loss, m_rep = m_rep, agg_cols = agg_cols,
    normalize = normalize, n_max = n_max, w = w,
    verbose = verbose, ...
  )
}

#' @describeIn perm_importance Method for DALEX "explainer".
#' @export
perm_importance.explainer <- function(object,
Expand Down
23 changes: 0 additions & 23 deletions R/utils_input.R
Original file line number Diff line number Diff line change
Expand Up @@ -114,26 +114,3 @@ prepare_y <- function(y, X, ohe = FALSE) {
list(y = prepare_pred(y, ohe = ohe), y_names = y_names)
}

#' mlr3 Helper
#'
#' Returns the prediction function of a mlr3 Learner. For classification
#' tasks, probabilities must be available; otherwise an error is raised.
#'
#' @noRd
#' @keywords internal
#'
#' @param object Learner object.
#' @param X Dataframe like object.
#'
#' @returns A function of the form `function(m, X)`.
mlr3_pred_fun <- function(object, X) {
  # Regression (and other non-classification) tasks: plain response.
  if (!("classif" %in% object$task_type)) {
    return(function(m, X) m$predict_newdata(X)$response)
  }
  # Classification: probe a small prediction to see whether probabilities
  # are supported by this Learner.
  probe <- object$predict_newdata(utils::head(X))
  if (!("prob" %in% probe$predict_types)) {
    stop("Set lrn(..., predict_type = 'prob') to allow for probabilistic classification.")
  }
  function(m, X) m$predict_newdata(X)$prob
}
18 changes: 13 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ Strongest relative interaction shown as ICE plot.

## Multivariate responses

{hstats} works also with multivariate output, see examples with
{hstats} also works with multivariate output; see examples for probabilistic classification with

- ranger,
- LightGBM, and
Expand Down Expand Up @@ -377,7 +377,9 @@ plot(H, normalize = FALSE, squared = FALSE, facet_scales = "free_y", ncol = 1)

![](man/figures/xgboost.svg)

### (Non-probabilistic) classification works as well
### Non-probabilistic classification

When predictions are factor levels, {hstats} uses internal one-hot encoding.

```r
library(ranger)
Expand All @@ -404,7 +406,7 @@ partial_dep(fit, v = "Petal.Length", X = train, BY = "Petal.Width") |>

## Meta-learning packages

Here, we provide some working examples for "tidymodels", "caret", and "mlr3".
Here, we provide examples for {tidymodels}, {caret}, and {mlr3}.

### tidymodels

Expand Down Expand Up @@ -478,8 +480,14 @@ fit_rf$train(task_iris)
s <- hstats(fit_rf, X = iris[, -5])
plot(s)

# Permutation importance
perm_importance(fit_rf, X = iris, y = "Species", loss = "mlogloss") |>
# Permutation importance (probabilistic using multi-logloss)
p <- perm_importance(
fit_rf, X = iris, y = "Species", loss = "mlogloss", predict_type = "prob"
)
plot(p)

# Non-probabilistic using classification error
perm_importance(fit_rf, X = iris, y = "Species", loss = "classification_error") |>
plot()
```

Expand Down
16 changes: 0 additions & 16 deletions man/average_loss.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 0 additions & 19 deletions man/hstats.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit fc764de

Please sign in to comment.