Skip to content

Commit

Permalink
Merge pull request #78 from mayer79/hstats_matrix_methods
Browse files Browse the repository at this point in the history
Add dimnames, dim, and subsetting
  • Loading branch information
mayer79 authored Oct 17, 2023
2 parents 4313a43 + 80cc9b7 commit e8d9437
Show file tree
Hide file tree
Showing 14 changed files with 237 additions and 34 deletions.
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
# Generated by roxygen2: do not edit by hand

S3method("[",hstats_matrix)
S3method(average_loss,Learner)
S3method(average_loss,default)
S3method(average_loss,explainer)
S3method(average_loss,ranger)
S3method(dim,hstats_matrix)
S3method(dimnames,hstats_matrix)
S3method(h2,default)
S3method(h2,hstats)
S3method(h2_overall,default)
Expand Down
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
- `hstats()`: `n_max` has been increased from 300 to 500 rows. This will make estimates of H statistics more stable at the price of longer run time. Reduce to 300 for the old behaviour.
- `hstats()`: By default, three-way interactions are not calculated anymore. Set `threeway_m` to 5 for the old behaviour.
- Revised plots: The colors and color palettes have changed and can (also) be controlled via global options. For instance, to change the fill color of all bars, set `options(hstats.fill = new value)`. Value labels are more clear, and there are more options. Varying color/fill scales now use viridis (inferno). This can be modified on the fly or via `options(hstats.viridis_args = list(...))`.
- "hstats_matrix" object: All statistics functions, e.g., `h2_pairwise()` or `perm_importance()`, now return a "hstats_matrix". The values are stored in `$M` and can be plotted via `plot()`.
- "hstats_matrix" object: All statistics functions, e.g., `h2_pairwise()` or `perm_importance()`, now return a "hstats_matrix". The values are stored in `$M` and can be plotted via `plot()`. Other methods are: `dimnames()`, `rownames()`, `colnames()`, `dim()`, `nrow()`, `ncol()`, `head()`, `tail()`, and subsetting like a normal matrix. This allows, e.g, to select and plot only one column of the results.
- `perm_importance()`: The `perms` argument has been changed to `m_rep`.
- `print()` and `summary()` methods have been revised.
- The arguments `w` (case weights) and `y` (response) can now also be passed as column *names*.
Expand Down
5 changes: 3 additions & 2 deletions R/H2_pairwise.R
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,9 @@
#' # MODEL 2: Multi-response linear regression
#' fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width * Species, data = iris)
#' s <- hstats(fit, X = iris[3:5], verbose = FALSE)
#' h2_pairwise(s)
#' plot(h2_pairwise(s))
#' x <- h2_pairwise(s)
#' plot(x)
#' plot(x[, "Sepal.Length"])
h2_pairwise <- function(object, ...) {
UseMethod("h2_pairwise")
}
Expand Down
1 change: 1 addition & 0 deletions R/perm_importance.R
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
#' s
#' plot(s)
#' plot(s, swap_dim = TRUE, top_m = 2)
#' plot(s[, "Sepal.Length"])
perm_importance <- function(object, ...) {
UseMethod("perm_importance")
}
Expand Down
2 changes: 1 addition & 1 deletion R/utils_calculate.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ align_pred <- function(x) {
#' @keywords internal
#'
#' @param x Vector or matrix.
#' @param ngroups Number of groups of fixed length `NROW(x) / ngroups`.
#' @param ngroups Number of groups (`x` was stacked that many times).
#' @param w Optional vector with case weights of length `NROW(x) / ngroups`.
#' @returns A (g x K) matrix, where g is the number of groups, and K = NCOL(x).
wrowmean <- function(x, ngroups = 1L, w = NULL) {
Expand Down
67 changes: 67 additions & 0 deletions R/utils_statistics.R
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,73 @@ print.hstats_matrix <- function(x, top_m = Inf, ...) {
invisible(x)
}

#' Dimensions of "hstats_matrix" Object
#'
#' Implies `nrow()` and `ncol()`.
#'
#' @param x An object of class "hstats_matrix".
#' @returns
#' A numeric vector of length two providing the number of rows and columns
#' of "M" object stored in `x`.
#' @examples
#' fit <- lm(Sepal.Length ~ . + Petal.Width:Species, data = iris)
#' s <- hstats(fit, X = iris[-1])
#' x <- h2_pairwise(s)
#' dim(x)
#' nrow(x)
#' ncol(x)
#' @export
dim.hstats_matrix <- function(x) {
dim(x[["M"]])
}

#' Dimnames of "hstats_matrix" Object
#'
#' Extracts dimnames of the "M" matrix in `x`. Implies `rownames()` and `colnames()`.
#'
#' @param x An object of class "hstats_matrix".
#' @returns Dimnames of the statistics matrix.
#' @examples
#' fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width * Species, data = iris)
#' s <- hstats(fit, X = iris[3:5], verbose = FALSE)
#' x <- h2_pairwise(s)
#' dimnames(x)
#' rownames(x)
#' colnames(x)
#' @export
dimnames.hstats_matrix <- function(x) {
dimnames(x[["M"]])
}

#' Subsets "hstats_matrix" Object
#'
#' Use standard square bracket subsetting to select rows and/or columns of
#' statistics "M" (and "SE" in case of permutation importance statistics).
#' Implies `head()` and `tail()`.
#'
#' @param x An object of class "hstats_matrix".
#' @param i Row subsetting.
#' @param j Column subsetting.
#' @param ... Currently unused.
#' @returns A new object of class "hstats_matrix".
#' @examples
#' fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width * Species, data = iris)
#' imp <- perm_importance(fit, X = iris, y = c("Sepal.Length", "Sepal.Width"))
#' head(imp, 1)
#' tail(imp, 2)
#' imp[1, "Sepal.Length"]
#' imp[1]
#' imp[, "Sepal.Width"]$SE
#' plot(imp[, "Sepal.Width"])
#' @export
`[.hstats_matrix` <- function(x, i, j, ...) {
x$M <- x$M[i, j, drop = FALSE]
if (!is.null(x$SE)) {
x$SE <- x$SE[i, j, drop = FALSE]
}
x
}

#' Plots "hstats_matrix" Object
#'
#' Plot method for objects of class "hstats_matrix".
Expand Down
34 changes: 20 additions & 14 deletions backlog/benchmark.R
Original file line number Diff line number Diff line change
@@ -1,23 +1,16 @@
library(hstats)
library(iml)
library(DALEX)
library(ingredients)
library(flashlight)

library(shapviz)
library(xgboost)
library(ggplot2)
library(microbenchmark)

# future::plan(multisession, workers = 1)

# Data preparation
colnames(miami) <- tolower(colnames(miami))
miami <- transform(miami, log_price = log(sale_prc))
x <- c("tot_lvg_area", "lnd_sqfoot", "latitude", "longitude",
"structure_quality", "age", "month_sold")
coord <- c("longitude", "latitude")

# Train/valid split
# Modeling
set.seed(1)
ix <- sample(nrow(miami), 0.8 * nrow(miami))
train <- data.frame(miami[ix, ])
Expand Down Expand Up @@ -50,9 +43,12 @@ fit <- xgb.train(
callbacks = list(cb.print.evaluation(period = 100))
)

# Interpret via {hstats}
average_loss(fit, X = X_valid, y = y_valid) # 0.0247 MSE -> 0.157 RMSE

perm_importance(fit, X = X_valid, y = y_valid) |>
plot()

# Or combining some features
v_groups <- list(
coord = c("longitude", "latitude"),
Expand All @@ -61,6 +57,7 @@ v_groups <- list(
)
perm_importance(fit, v = v_groups, X = X_valid, y = y_valid) |>
plot()

H <- hstats(fit, v = x, X = X_valid)
H
plot(H)
Expand All @@ -78,6 +75,8 @@ g <- unique(X_valid[, coord])
pp <- partial_dep(fit, v = coord, X = X_valid, grid = g)
plot(pp, d2_geom = "point", alpha = 0.5, size = 1) +
coord_equal()

# Takes some seconds because it generates the last plot per structure quality
partial_dep(fit, v = coord, X = X_valid, grid = g, BY = "structure_quality") |>
plot(pp, d2_geom = "point", alpha = 0.5) +
coord_equal()
Expand All @@ -86,6 +85,14 @@ partial_dep(fit, v = coord, X = X_valid, grid = g, BY = "structure_quality") |>
# Naive benchmark
#=====================================

library(iml) # Might benefit of multiprocessing, but on Windows with XGB models, this is not easy
library(DALEX)
library(ingredients)
library(flashlight)
library(microbenchmark)

set.seed(1)

# iml
predf <- function(object, newdata) predict(object, data.matrix(newdata[x]))
mod <- Predictor$new(fit, data = as.data.frame(X_valid), y = y_valid,
Expand Down Expand Up @@ -150,17 +157,16 @@ microbenchmark(

# H-Stats -> we use a subset of 500 rows
X_v500 <- X_valid[1:500, ]
mod500 <- Predictor$new(fit, data = as.data.frame(X_v500), y = y_valid[1:500],
predict.function = predf)
mod500 <- Predictor$new(fit, data = as.data.frame(X_v500), predict.function = predf)
fl500 <- flashlight(fl, data = as.data.frame(valid[1:500, ]))

# iml # 77 s (no pairwise possible)
# iml # 90 s (no pairwise possible)
system.time(
iml_overall <- Interaction$new(mod500, grid.size = 500)
)

# flashlight: 12s total, doing only one pairwise calculation, otherwise would take 63s
system.time( # 10s
# flashlight: 14s total, doing only one pairwise calculation, otherwise would take 63s
system.time( # 12s
fl_overall <- light_interaction(fl500, v = x, grid_size = Inf, n_max = Inf)
)
system.time( # 2s
Expand Down
5 changes: 3 additions & 2 deletions man/H2_pairwise.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

26 changes: 26 additions & 0 deletions man/dim.hstats_matrix.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 25 additions & 0 deletions man/dimnames.hstats_matrix.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/perm_importance.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

35 changes: 35 additions & 0 deletions man/sub-.hstats_matrix.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

52 changes: 51 additions & 1 deletion tests/testthat/test_statistics.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

test_that("postprocess() works for matrix input", {
num <- cbind(a = 1:3, b = c(1, 1, 1))
denom <- cbind(a = 1:3, b = 1:3)
Expand Down Expand Up @@ -31,3 +30,54 @@ test_that("postprocess() works for vector input", {
expect_null(postprocess(num = 0, zero = FALSE))
})


test_that(".zap_small() works for vector input", {
expect_equal(.zap_small(1:3), 1:3)
expect_equal(.zap_small(c(1:3, NA)), c(1:3, 0))
expect_equal(.zap_small(c(0.001, 1), eps = 0.01), c(0, 1))
})

test_that(".zap_small() works for matrix input", {
expect_equal(
.zap_small(cbind(c(0.001, 1), c(0, 0)), eps = 0.01),
cbind(c(0, 1), c(0, 0))
)
})

fit <- lm(cbind(up = uptake, up2 = 2 * uptake) ~ Type * Treatment * conc, data = CO2)
H <- hstats(fit, X = CO2[2:4], verbose = FALSE)
s <- h2_pairwise(H)

test_that("print() method does not give error", {
capture_output(expect_no_error(print(s)))
})

test_that("dim() is correct", {
expect_equal(dim(s), c(3L, 2L))
})

test_that("dimnames() is correct", {
expect_equal(dimnames(s), list(rownames(s$M), colnames(s$M)))
})

test_that("subsetting works", {
expect_equal(dim(s[, "up2"]), c(3L, 1L))
expect_equal(dim(s[1, "up2"]), c(1L, 1L))
expect_equal(dim(s[1:2, ]), c(2L, 2L))
})

fit <- lm(uptake ~ Type * Treatment * conc, data = CO2)
set.seed(1L)
s <- perm_importance(fit, X = CO2[2:4], y = CO2$uptake)

test_that("print() method does not give error", {
capture_output(expect_no_error(print(s)))
})

test_that("dim() is correct", {
expect_equal(dim(s), c(3L, 1L))
})

test_that("rownames() is correct", {
expect_equal(rownames(s[1L, ]), rownames(s$M[1L, , drop = FALSE]))
})
Loading

0 comments on commit e8d9437

Please sign in to comment.