Skip to content

Commit

Permalink
Merge pull request #82 from mayer79/fix_args
Browse files Browse the repository at this point in the history
Fix arguments of average_loss()
  • Loading branch information
mayer79 authored Oct 19, 2023
2 parents 1dcdfab + 29297ef commit 7a530da
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 43 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
- `average_loss()` also returns a "hstats_matrix" object with `print()` and `plot()` method. The values can be extracted via `$M`.
- The default `v` of `hstats()` and `perm_importance()` is now `NULL`. Internally, it is set to `colnames(X)` (minus the column names of `w` and `y` if passed as name).
- Missing grid values: `partial_dep()` and `ice()` have received a `na.rm` argument that controls if missing values are dropped during grid creation. The default `TRUE` is compatible with earlier releases.
- The position of some function arguments have changed.

# hstats 0.3.0

Expand Down
25 changes: 17 additions & 8 deletions R/average_loss.R
Original file line number Diff line number Diff line change
Expand Up @@ -115,16 +115,19 @@ average_loss.default <- function(object, X, y,
#' @export
average_loss.ranger <- function(object, X, y,
pred_fun = function(m, X, ...) stats::predict(m, X, ...)$predictions,
loss = "squared_error",
loss = "squared_error",
agg_cols = FALSE,
BY = NULL, by_size = 4L,
w = NULL, ...) {
average_loss.default(
object = object,
X = X,
y = y,
pred_fun = pred_fun,
BY = BY,
pred_fun = pred_fun,
loss = loss,
agg_cols = agg_cols,
BY = BY,
by_size = by_size,
w = w,
...
)
Expand All @@ -134,7 +137,8 @@ average_loss.ranger <- function(object, X, y,
#' @export
average_loss.Learner <- function(object, X, y,
pred_fun = NULL,
loss = "squared_error",
loss = "squared_error",
agg_cols = FALSE,
BY = NULL, by_size = 4L,
w = NULL, ...) {
if (is.null(pred_fun)) {
Expand All @@ -145,8 +149,10 @@ average_loss.Learner <- function(object, X, y,
X = X,
y = y,
pred_fun = pred_fun,
BY = BY,
loss = loss,
loss = loss,
agg_cols = agg_cols,
BY = BY,
by_size = by_size,
w = w,
...
)
Expand All @@ -158,7 +164,8 @@ average_loss.explainer <- function(object,
X = object[["data"]],
y = object[["y"]],
pred_fun = object[["predict_function"]],
loss = "squared_error",
loss = "squared_error",
agg_cols = FALSE,
BY = NULL,
by_size = 4L,
w = object[["weights"]],
Expand All @@ -168,8 +175,10 @@ average_loss.explainer <- function(object,
X = X,
y = y,
pred_fun = pred_fun,
BY = BY,
loss = loss,
agg_cols = agg_cols,
BY = BY,
by_size = by_size,
w = w,
...
)
Expand Down
40 changes: 22 additions & 18 deletions R/hstats.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,6 @@
#' (such as `type = "response"` in a GLM, or `reshape = TRUE` in a multiclass XGBoost
#' model) can be passed via `...`. The default, [stats::predict()], will work in
#' most cases.
#' @param n_max If `X` has more than `n_max` rows, a random sample of `n_max` rows is
#' selected from `X`. In this case, set a random seed for reproducibility.
#' @param w Optional vector of case weights. Can also be a column name of `X`.
#' @param pairwise_m Number of features for which pairwise statistics are to be
#' calculated. The features are selected based on Friedman and Popescu's overall
#' interaction strength \eqn{H^2_j}. Set to to 0 to avoid pairwise calculations.
Expand All @@ -50,6 +47,9 @@
#' speed-up for dense features, mainly for one-way statistics.
#' Note that the quantiles are calculated after subsampling to `n_max` rows.
#' @param eps Threshold below which numerator values are set to 0. Default is 1e-10.
#' @param n_max If `X` has more than `n_max` rows, a random sample of `n_max` rows is
#' selected from `X`. In this case, set a random seed for reproducibility.
#' @param w Optional vector of case weights. Can also be a column name of `X`.
#' @param verbose Should a progress bar be shown? The default is `TRUE`.
#' @param ... Additional arguments passed to `pred_fun(object, X, ...)`,
#' for instance `type = "response"` in a [glm()] model, or `reshape = TRUE` in a
Expand Down Expand Up @@ -139,9 +139,10 @@ hstats <- function(object, ...) {
#' @describeIn hstats Default hstats method.
#' @export
hstats.default <- function(object, X, v = NULL,
pred_fun = stats::predict, n_max = 500L,
w = NULL, pairwise_m = 5L, threeway_m = 0L,
quant_approx = NULL, eps = 1e-10, verbose = TRUE, ...) {
pred_fun = stats::predict,
pairwise_m = 5L, threeway_m = 0L,
quant_approx = NULL, eps = 1e-10,
n_max = 500L, w = NULL, verbose = TRUE, ...) {
stopifnot(
is.matrix(X) || is.data.frame(X),
is.function(pred_fun)
Expand Down Expand Up @@ -275,19 +276,20 @@ hstats.default <- function(object, X, v = NULL,
#' @export
hstats.ranger <- function(object, X, v = NULL,
pred_fun = function(m, X, ...) stats::predict(m, X, ...)$predictions,
n_max = 500L, w = NULL, pairwise_m = 5L, threeway_m = 0L,
quant_approx = NULL, eps = 1e-10, verbose = TRUE, ...) {
pairwise_m = 5L, threeway_m = 0L,
quant_approx = NULL, eps = 1e-10,
n_max = 500L, w = NULL, verbose = TRUE, ...) {
hstats.default(
object = object,
X = X,
v = v,
pred_fun = pred_fun,
n_max = n_max,
w = w,
pairwise_m = pairwise_m,
threeway_m = threeway_m,
quant_approx = quant_approx,
eps = eps,
n_max = n_max,
w = w,
verbose = verbose,
...
)
Expand All @@ -297,8 +299,9 @@ hstats.ranger <- function(object, X, v = NULL,
#' @export
hstats.Learner <- function(object, X, v = NULL,
pred_fun = NULL,
n_max = 500L, w = NULL, pairwise_m = 5L, threeway_m = 0L,
quant_approx = NULL, eps = 1e-10, verbose = TRUE, ...) {
pairwise_m = 5L, threeway_m = 0L,
quant_approx = NULL, eps = 1e-10,
n_max = 500L, w = NULL, verbose = TRUE, ...) {
if (is.null(pred_fun)) {
pred_fun <- mlr3_pred_fun(object, X = X)
}
Expand All @@ -307,12 +310,12 @@ hstats.Learner <- function(object, X, v = NULL,
X = X,
v = v,
pred_fun = pred_fun,
n_max = n_max,
w = w,
pairwise_m = pairwise_m,
threeway_m = threeway_m,
quant_approx = quant_approx,
eps = eps,
n_max = n_max,
w = w,
verbose = verbose,
...
)
Expand All @@ -323,20 +326,21 @@ hstats.Learner <- function(object, X, v = NULL,
hstats.explainer <- function(object, X = object[["data"]],
v = NULL,
pred_fun = object[["predict_function"]],
n_max = 500L, w = object[["weights"]],
pairwise_m = 5L, threeway_m = 0L,
quant_approx = NULL, eps = 1e-10, verbose = TRUE, ...) {
quant_approx = NULL, eps = 1e-10,
n_max = 500L, w = object[["weights"]],
verbose = TRUE, ...) {
hstats.default(
object = object[["model"]],
X = X,
v = v,
pred_fun = pred_fun,
n_max = n_max,
w = w,
pairwise_m = pairwise_m,
threeway_m = threeway_m,
quant_approx = quant_approx,
eps = eps,
n_max = n_max,
w = w,
verbose = verbose,
...
)
Expand Down
4 changes: 2 additions & 2 deletions R/ice.R
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ ice.ranger <- function(object, v, X,
BY = NULL, grid = NULL, grid_size = 49L,
trim = c(0.01, 0.99),
strategy = c("uniform", "quantile"), na.rm = TRUE,
n_max = 100, ...) {
n_max = 100L, ...) {
ice.default(
object = object,
v = v,
Expand Down Expand Up @@ -194,7 +194,7 @@ ice.explainer <- function(object, v = v, X = object[["data"]],
BY = NULL, grid = NULL, grid_size = 49L,
trim = c(0.01, 0.99),
strategy = c("uniform", "quantile"), na.rm = TRUE,
n_max = 100, ...) {
n_max = 100L, ...) {
ice.default(
object = object[["model"]],
v = v,
Expand Down
3 changes: 3 additions & 0 deletions man/average_loss.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

26 changes: 13 additions & 13 deletions man/hstats.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/ice.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 7a530da

Please sign in to comment.