Skip to content

Commit

Permalink
Merge pull request #110 from jmaspons/rbind.mshapviz
Browse files Browse the repository at this point in the history
Add rbind and `+` methods for mshapviz objects
  • Loading branch information
mayer79 authored Oct 18, 2023
2 parents 995c1ec + a5936c3 commit fa75551
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 8 deletions.
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Generated by roxygen2: do not edit by hand

S3method("+",mshapviz)
S3method("+",shapviz)
S3method("[",shapviz)
S3method("dimnames<-",shapviz)
Expand All @@ -20,6 +21,7 @@ S3method(get_shap_values,mshapviz)
S3method(get_shap_values,shapviz)
S3method(print,mshapviz)
S3method(print,shapviz)
S3method(rbind,mshapviz)
S3method(rbind,shapviz)
S3method(shapviz,H2OBinomialModel)
S3method(shapviz,H2OModel)
Expand Down
51 changes: 47 additions & 4 deletions R/methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -266,25 +266,68 @@ dimnames.shapviz <- function(x) {
)
}

#' Rowbinds Multiple "shapviz" Objects
#' @rdname plus-.shapviz
#' @examples
#' # mshapviz
#' S <- matrix(c(1, -1, -1, 1), ncol = 2, dimnames = list(NULL, c("x", "y")))
#' X <- data.frame(x = c("a", "b"), y = c(100, 10))
#' s1 <- shapviz(S, X, baseline = 4)[1L]
#' s2 <- shapviz(S, X, baseline = 4)[2L]
#' s <- mshapviz(c(shp1 = s1, shp2 = s2))
#' s + s
#'
#' @export
`+.mshapviz` <- function(e1, e2) {
stopifnot(
is.mshapviz(e1),
is.mshapviz(e2),
length(e1) == length(e2),
names(e1) == names(e2),
mapply(function(x, y) ncol(x) == ncol(y), x = e1, y = e2),
mapply(function(x, y) colnames(x) == colnames(y), x = e1, y = e2)
)

shp_list <- mapply(function(x, y) {
x + y
}, x = e1, y = e2, SIMPLIFY = FALSE)

mshapviz(shp_list)
}

#' Rowbinds Multiple "shapviz" or "mshapviz" Objects
#'
#' Rowbinds multiple "shapviz" objects based on the `+` operator.
#'
#' @param ... Any number of "shapviz" objects.
#' @returns A new object of class "shapviz".
#' @param ... Any number of "shapviz" or "mshapviz" objects.
#' @returns A new object of class "shapviz" or "mshapviz".
#' @examples
#' S <- matrix(c(1, -1, -1, 1), ncol = 2, dimnames = list(NULL, c("x", "y")))
#' X <- data.frame(x = c("a", "b"), y = c(100, 10))
#' s1 <- shapviz(S, X, baseline = 4)[1]
#' s2 <- shapviz(S, X, baseline = 4)[2]
#' s <- rbind(s1, s2)
#' s
#' @seealso [shapviz()]
#' @seealso [shapviz()], [mshapviz()]
#' @export
rbind.shapviz <- function(...) {
Reduce(`+`, list(...))
}

#' @rdname rbind.shapviz
#' @examples
#' # mshapviz
#' S <- matrix(c(1, -1, -1, 1), ncol = 2, dimnames = list(NULL, c("x", "y")))
#' X <- data.frame(x = c("a", "b"), y = c(100, 10))
#' s1 <- shapviz(S, X, baseline = 4)[1L]
#' s2 <- shapviz(S, X, baseline = 4)[2L]
#' s <- mshapviz(c(shp1 = s1, shp2 = s2))
#' rbind(s, s)
#'
#' @export
rbind.mshapviz <- function(...) {
Reduce(`+`, list(...))
}

#' Concatenates "shapviz" Objects
#'
#' This function combines two or more (usually named) "shapviz" objects
Expand Down
11 changes: 11 additions & 0 deletions man/plus-.shapviz.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 15 additions & 4 deletions man/rbind.shapviz.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 23 additions & 0 deletions tests/testthat/test-interface.R
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,17 @@ test_that("concatenating with + works", {
expect_equal((shp + shp)$baseline, shp$baseline)
expect_equal(dim((shp + shp + shp)$S), c(6L, 2L))
expect_equal(dim((shp + shp + shp)$X), c(6L, 2L))

# mshapviz
mapply(function(x, dims) {
expect_equal(dim(x$S), dims)
expect_equal(dim(x$X), dims)
}, x = mshp + mshp, dims = list(c(4L, 2L), c(8L, 2L)))
mapply(function(x, x_sum) expect_equal(x$baseline, x_sum$baseline), x = mshp, x_sum = mshp + mshp)
mapply(function(x, dims) {
expect_equal(dim(x$S), dims)
expect_equal(dim(x$X), dims)
}, x = mshp + mshp + mshp, dims = list(c(6L, 2L), c(12L, 2L)))
})

test_that("concatenating with rbind works", {
Expand All @@ -63,6 +74,18 @@ test_that("concatenating with rbind works", {
expect_equal(rbind(shp, shp)$baseline, shp$baseline)
expect_equal(dim(rbind(shp, shp, shp)$S), c(6L, 2L))
expect_equal(dim(rbind(shp, shp, shp)$X), c(6L, 2L))

# mshapviz
mshp_rbind <- rbind(mshp, mshp)
expect_equal(dim(mshp_rbind$shp$S), c(4L, 2L))
expect_equal(dim(mshp_rbind$shp2$S), c(8L, 2L))
expect_equal(dim(mshp_rbind$shp$X), c(4L, 2L))
expect_equal(dim(mshp_rbind$shp2$X), c(8L, 2L))
mapply(function(x, xbind) expect_equal(x$baseline, xbind$baseline), x = mshp, xbind = mshp_rbind)
mapply(function(x, dims) {
expect_equal(dim(x$S), dims)
expect_equal(dim(x$X), dims)
}, x = rbind(mshp, mshp, mshp), dims = list(c(6L, 2L), c(12L, 2L)))
})

test_that("split() works", {
Expand Down

0 comments on commit fa75551

Please sign in to comment.