Skip to content

Commit

Permalink
get_predicted() with categorical: extra Row column
Browse files Browse the repository at this point in the history
Fixes #989
  • Loading branch information
strengejacke committed Jan 9, 2025
1 parent eff47f4 commit 3485f0a
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
11 changes: 11 additions & 0 deletions R/get_predicted.R
Original file line number Diff line number Diff line change
Expand Up @@ -838,6 +838,7 @@ get_predicted.phylolm <- function(x,
.get_predicted_centrality_from_draws <- function(x,
iter,
centrality_function = base::mean,
datagrid = NULL,
...) {
# outcome: ordinal/multinomial/multivariate produce a 3D array of predictions,
# which we stack in "long" format
Expand All @@ -853,6 +854,16 @@ get_predicted.phylolm <- function(x,
Predicted = apply(iter_stacked, 1, centrality_function),
stringsAsFactors = FALSE
)
# for ordinal etc. outcomes, we need to include the data from the grid, too
if (!is.null(datagrid)) {
# due to reshaping predictions into long format, we to repeat the
# datagrid multiple times, to have same number of rows
times <- nrow(predictions) / nrow(datagrid)
if (nrow(predictions) %% times == 0) {
datagrid <- do.call(rbind, replicate(times, datagrid, simplify = FALSE))
predictions <- cbind(predictions[1:2], datagrid, predictions[3])
}
}
iter <- as.data.frame(iter_stacked)
# outcome with a single level
} else {
Expand Down
7 changes: 6 additions & 1 deletion R/get_predicted_bayesian.R
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,12 @@ get_predicted.stanreg <- function(x,
}

# Get predictions (summarize)
predictions <- .get_predicted_centrality_from_draws(x, iter = draws, ...)
predictions <- .get_predicted_centrality_from_draws(
x,
iter = draws,
datagrid = my_args$data,
...
)

# Output
ci_data <- get_predicted_ci(
Expand Down

0 comments on commit 3485f0a

Please sign in to comment.