
Visualization: Partial residual plots #129

Open
mattansb opened this issue Jul 6, 2021 · 15 comments
Labels
Feature idea 🔥 New feature or request Plot 🎇 Something related to plotting

Comments


mattansb commented Jul 6, 2021

Are we allowed to steal from ourselves?

https://strengejacke.github.io/ggeffects/reference/residualize_over_grid.html

@DominiqueMakowski

library(ggeffects)
#> Warning: package 'ggeffects' was built under R version 4.0.5
set.seed(1234)
x <- rnorm(200)
z <- rnorm(200)
# quadratic relationship
y <- 2 * x + x^2 + 4 * z + rnorm(200)

d <- data.frame(x, y, z)
model <- lm(y ~ x + z, data = d)


pr <- ggpredict(model, c("x [all]", "z"))
head(residualize_over_grid(pr, model))
#>          x group predicted
#> 53  -1.207  0.07 -1.797239
#> 402  0.277  1.08  4.888712
#> 518  1.084  0.07  3.232202
#> 9   -2.346  1.08  4.133561
#> 428  0.429  0.07  1.801594
#> 441  0.506  1.08  5.659527

Created on 2021-07-06 by the reprex package (v2.0.0)

I don't understand what it does.


mattansb commented Jul 6, 2021

Here’s a basic explanation: https://en.wikipedia.org/wiki/Partial_residual_plot
(I’m sure @bwiernik would looooveee to elaborate!)

S <- diag(1, 3, 3)
S[1,2] <- S[2,1] <- 0.6
S[1,3] <- S[3,1] <- 0.8
S[2,3] <- S[3,2] <- 0.8

dat <- MASS::mvrnorm(500, c(10,20,30), S, empirical = TRUE)
colnames(dat) <- c("X","Y","Z")
dat <- data.frame(dat)

fit <- lm(Y ~ X + Z, data = dat)

What the hell is this???

plot(ggeffects::ggpredict(fit, "X"), add.data = TRUE, jitter = 0)
#> Loading required namespace: ggplot2

This is because the raw data can only show us the marginal association.
Instead, we can look at the residualized data (where the covariance with Z is “controlled” for):

plot(ggeffects::ggpredict(fit, "X [all]"), residuals = TRUE, jitter = 0)

Created on 2021-07-06 by the reprex package (v2.0.0)

@DominiqueMakowski

Oh I see, so that'd be mostly for plotting of the data points overlay right?


mattansb commented Jul 6, 2021

Yes, that would be only for data overlay.


bwiernik commented Jul 6, 2021

(I’m sure @bwiernik would looooveee to elaborate!)

👀 😝

So, we might consider making four types of plots that relate a single predictor to the response variable in a regression model. Two of them, confusingly, have "partial" in their names.

See them below:

library(ggplot2)

mf <- lm(mpg ~ hp + wt, data = transform(mtcars, cyl = factor(cyl)))
mr <- lm(mpg ~ hp,      data = transform(mtcars, cyl = factor(cyl)))
mx <- lm( wt ~ hp,      data = transform(mtcars, cyl = factor(cyl)))

plot_dat <- mtcars |> 
  subset(select = c(mpg, cyl, hp, wt)) |> 
  transform(
    fit_f = predict(mf),
    res_f =   resid(mf),
    fit_r = predict(mr),
    res_r =   resid(mr),
    fit_x = predict(mx),
    res_x =   resid(mx),
    par_x = coef(mf)['wt'] * wt,
    cmf_y = as.vector(colMeans(model.matrix(mf)[,names(coef(mf)) != "wt"]) %*% 
                        coef(mf)[names(coef(mf)) != "wt"])
  )

base_plot <- ggplot(plot_dat) +
  geom_point() +
  geom_smooth(method = "lm", formula = y ~ x) +
  see::theme_modern(plot.title.space = 1)


fit_resid_plot <- base_plot +
  aes(x = wt, y = res_f) +
  labs(title = "(A) Predictor - Residuals Plot",
       caption = "Is there misspecified nonlinearity/endogeneity?",
       x = "X",
       y = "Residual Y | full")

par_regre_plot <- base_plot + 
  aes(x = res_x, y = res_r) +
  labs(title = "(B) Residual Predictor - Residuals Plot", 
       subtitle = "Partial Regression Plot / Added Variable Plot",
       caption = "Does X add predictive power above and beyond covariates?",
       x = "Residual X | covariates",
       y = "Residual Y | covariates")

fit_effect_plot <- ggplot(plot_dat) +
  geom_point(aes(x = wt, y = mpg)) +
  geom_smooth(
    aes(x = wt, y = res_f + par_x + cmf_y),
    method = "lm", formula = y ~ x
  ) +
  see::theme_modern(plot.title.space = 1) +
  labs(title = "(C) Predictor - Response Plot",
       subtitle = "Predictor effects plot",
       caption = "What is the relationship of X with Y, controlling for covariates?",
       x = "X",
       y = "Y")

par_resid_plot <- base_plot +
  aes(x = wt, y = res_f + par_x) +
  labs(title = "(D) Predictor - Partial Residuals Plot",
       caption = "What is the relationship of X with Y, controlling for covariates?",
       x = "X",
       y = "Residual Y | full  + Effect_X")
fit_resid_plot

par_regre_plot

fit_effect_plot

par_resid_plot

  1. (A) Predictor vs. residuals (full model): the most common diagnostic plot--detects nonlinearity, omitted variables, omitted interactions, etc.
  2. (B) Residual predictor vs. residual response (each residualized on the covariates-only models): the "added-variable plot" (aka partial regression plot)--how does the unique part of X predict the left-over unique part of Y after the covariates?
  3. (C) Predictor vs. response: the "predictor effects plot"--points are the raw response; the fit line is the partial slope for X controlling for covariates, computed as residual + X effect (X * βx) + covariate mean effects. Used to illustrate the partial effect of X on Y.
  4. (D) Predictor vs. partial residuals: points computed as residual + X effect; the fit line is an lm through those points. Used to illustrate the partial effect of X (like C, but more confusingly). Can also be used for misspecification diagnostics: add a loess line and check for nonlinearity or interactions missing from the model (see https://cran.r-project.org/web/packages/effects/vignettes/partial-residuals.pdf).
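For the partial residuals in (D), it's worth noting that base R already offers a shortcut: `stats::residuals(model, type = "partial")`. A minimal sketch (refitting the same `mf` model as above, without the factor transform) showing that the manual computation and the built-in one agree up to centering:

```r
# Sketch: partial residuals for wt, manual vs. stats::residuals(type = "partial").
# Assumes the same mf model as above (mpg ~ hp + wt on mtcars).
mf <- lm(mpg ~ hp + wt, data = mtcars)

manual  <- resid(mf) + coef(mf)["wt"] * mtcars$wt    # residual + term effect
builtin <- residuals(mf, type = "partial")[, "wt"]   # base R equivalent

# type = "partial" centers each term, so the two differ only by a constant shift:
all.equal(manual - mean(manual), builtin - mean(builtin))
#> TRUE
```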

@DominiqueMakowski

I have a hard time understanding what should be a separate function, what should be a plotting option and what should be an option to estimate_predicted ^^ where should we start?

Also, I'm thinking we should add an argument to datawizard::adjust() to be able to pass a pre-specified model to adjust with, rather than letting it build the model.


bwiernik commented Jul 7, 2021

I don't think overloading adjust is a good idea. Instead, I think separate functions make sense here.


mattansb commented Jul 7, 2021

(like C; but more confusingly)

Obviously I completely disagree 😅 - D is the plot that corresponds most closely to the slope's coefficient: it gives the slope (as the conditional regression line) and visualizes the errors from the overall regression hyper-plane along that line/dimension. That is, for each data point, you can see its predicted value (on the regression line) and its error. And it can be used for diagnostics - linearity, missing interactions... A thing of beauty!

@DominiqueMakowski
Copy link
Member

so plot(estimate_predicted) should be able to produce either C (the default, as implemented now) or D? I.e., we should try implementing D?


mattansb commented Jul 7, 2021

I think so, yes.
I can probably get the code I wrote for ggeffects to work here too (the one that takes data + a grid and returns the residualized data).


bwiernik commented Jul 7, 2021

Hmm, I wonder if a separate estimate_partial() function would make more sense? With the option for either B or D above? To do that, it could have partial.x and partial.y arguments.

I admit I don't really get D at all--I'm not really following what you're saying above, Mattan. I mostly use A and B for diagnostics/probing.


mattansb commented Jul 7, 2021

Take the data from plot A - each point is (X, residual).
Now take the regression line and, along X, add each residual to the line - you have plot D.
This is why D has the properties of plot A, and none of the weirdness going on on the x-axis of plot B.
And if you draw a vertical line from each point (X, predicted Y + residual) to the regression line, you can see what the predicted value for that point would be if all other vars were held constant. This is why it's (IMO) the closest visual representation of the data around the regression hyper-plane: if you were to "flatten" all the other (non-X or -Y) dimensions of the multi-variable hyper-space to a point (fix them to a constant), the resulting X-Y plot would be D.

I just find it an elegant* way to present results when there are multiple predictors involved 🤷‍♂️
[*It's also computationally easy, compared to B, which requires fitting extra models.]
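The recipe above (take the points of plot A and add each residual back onto the regression line along X) can be sketched in a few lines of base R; all names here are illustrative, not modelbased API:

```r
# Sketch of the plot-D construction described above (illustrative names only).
set.seed(1)
d   <- data.frame(X = rnorm(100), Z = rnorm(100))
d$Y <- 1 + 2 * d$X + 3 * d$Z + rnorm(100)
fit <- lm(Y ~ X + Z, data = d)

# "Flatten" the non-X dimension by fixing Z at a constant, then add back
# each observation's residual from the full fit:
line_flat <- predict(fit, newdata = transform(d, Z = mean(d$Z)))
points_D  <- line_flat + resid(fit)

# The slope through these points recovers the partial coefficient for X,
# since OLS residuals are orthogonal to X:
coef(lm(points_D ~ d$X))[2]  # equals coef(fit)["X"]
```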


bwiernik commented Jul 7, 2021

(B is a completely different thing than A--not intended for the same purpose--so let's not worry about that comparison)

@DominiqueMakowski

B is indeed different in the sense that it requires multiple models - it's a thing of its own. A is pretty much what performance::check_model() gives, if I'm not mistaken. C is the current modelbased default. Which leaves D:

For the following model, how would you get the partial residuals?

m <- lm(Sepal.Length ~ Petal.Width + Species, data = iris)

pred <- modelbased::estimate_expectation(m, target = "Petal.Width")
pred
#> Model-based Expectation
#> 
#> Petal.Width | Species | Predicted |   SE |       95% CI
#> -------------------------------------------------------
#> 0.10        |  setosa |      4.87 | 0.07 | [4.73, 5.02]
#> 0.37        |  setosa |      5.12 | 0.07 | [4.97, 5.26]
#> 0.63        |  setosa |      5.36 | 0.10 | [5.16, 5.56]
#> 0.90        |  setosa |      5.61 | 0.14 | [5.32, 5.89]
#> 1.17        |  setosa |      5.85 | 0.19 | [5.47, 6.23]
#> 1.43        |  setosa |      6.09 | 0.24 | [5.62, 6.57]
#> 1.70        |  setosa |      6.34 | 0.29 | [5.77, 6.91]
#> 1.97        |  setosa |      6.58 | 0.34 | [5.91, 7.26]
#> 2.23        |  setosa |      6.83 | 0.39 | [6.06, 7.60]
#> 2.50        |  setosa |      7.07 | 0.44 | [6.20, 7.95]
#> 
#> Variable predicted: Sepal.Length
#> Predictors modulated: Petal.Width
#> Predictors controlled: Species

plot(pred, show_data = "none")

Created on 2021-07-09 by the reprex package (v2.0.0)
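For that specific model, one way to get the partial residuals by hand (a base-R sketch, not a proposal for the modelbased API):

```r
# Partial residuals for Petal.Width: residual + the Petal.Width term effect.
m  <- lm(Sepal.Length ~ Petal.Width + Species, data = iris)
pr <- resid(m) + coef(m)["Petal.Width"] * iris$Petal.Width
head(pr)

# Base R computes a centered version of the same quantity:
head(residuals(m, type = "partial")[, "Petal.Width"])
```

Plotting `pr` against `Petal.Width`, with the fitted partial slope through it, gives plot (D).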


mattansb commented Jul 9, 2021

Okay, this is the code for this issue and for #130.

***The functions***
get_data_for_grid <- function(grid, residuals = FALSE, collapse.by = NULL) {
  #' @param grid For modelbased
  #' @param residuals [FALSE (default) | TRUE] Should the partial residuals be
  #'   returned?
  #' @param collapse.by [NULL (default) | TRUE | char] The name of a random
  #'   grouping factor to collapse by. If TRUE will select (the first) from the
  #'   model.
  model <- attr(grid, "model")
  data <- insight::get_data(model)
  pred_name <- attr(grid, "response", exact = TRUE)
  X_names <- attr(grid, "target", exact = TRUE)
  
  if (residuals) {
    data_r <- .residualize_over_grid(
      grid = grid[union("Predicted", X_names)],
      model = model,
      data = data,
      pred_name = "Predicted",
      collapse.by = collapse.by
    )
    colnames(data_r)[colnames(data_r) == "Predicted"] <- pred_name
    if (is.null(collapse.by)) {
      add <- setdiff(colnames(data), colnames(data_r))
      data_r[add] <- data[add] 
    }
    data <- data_r[intersect(colnames(data), colnames(data_r))]
  } else if (!is.null(collapse.by)) {
    data_r <- .collapse_by_group(
      data = data,
      model = model,
      pred_name = pred_name,
      X_names = X_names,
      collapse.by = collapse.by
    )
    data <- data_r[intersect(colnames(data), colnames(data_r))]
  }
  
  return(as.data.frame(data))
}

.residualize_over_grid <- function (grid, model, data, pred_name, collapse.by = NULL, ...) {
  old_d <- insight::get_predictors(model)
  fun_link <- insight::link_function(model)
  inv_fun <- insight::link_inverse(model)
  predicted <- grid[[pred_name]]
  grid[[pred_name]] <- NULL
  is_fixed <- sapply(grid, function(x) length(unique(x))) == 1
  grid <- grid[, !is_fixed, drop = FALSE]
  old_d <- old_d[, colnames(grid)[colnames(grid) %in% colnames(old_d)], drop = FALSE]
  if (!.is_grid(grid)) {
    stop("Grid for partial residuals must be a fully crossed grid.")
  }
  
  best_match <- NULL
  for (p in colnames(old_d)) {
    if (is.factor(old_d[[p]]) || is.logical(old_d[[p]]) || 
        is.character(old_d[[p]])) {
      grid[[p]] <- as.character(grid[[p]])
      old_d[[p]] <- as.character(old_d[[p]])
    }
    else {
      grid[[p]] <- .validate_num(grid[[p]])
    }
    best_match <- .closest(old_d[[p]], grid[[p]], best_match = best_match)
  }
  idx <- apply(best_match, 2, which)
  idx <- sapply(idx, "[", 1)
  res <- tryCatch(stats::residuals(model, type = "working"), 
                  error = function(e) NULL)
  if (is.null(res)) {
    warning("Could not extract residuals.", call. = FALSE)
    return(NULL)
  }
  points <- grid[idx, , drop = FALSE]
  points[[pred_name]] <- fun_link(predicted[idx]) + res
  if (!is.null(collapse.by)) {
    data[[pred_name]] <- points[[pred_name]]
    points <- .collapse_by_group(
      data = data,
      model = model, 
      pred_name = pred_name,
      X_names = colnames(grid), 
      collapse.by = collapse.by
    )
  }
  points[[pred_name]] <- inv_fun(points[[pred_name]])
  points
}


.collapse_by_group <- function (data, model, pred_name, X_names, collapse.by = TRUE) {
  if (!insight::is_mixed_model(model)) {
    stop("This function only works with mixed effects models.", 
         call. = FALSE)
  }
  
  if (isTRUE(collapse.by)) {
    collapse.by <- insight::find_random(model, flatten = TRUE)
  }
  
  if (length(collapse.by) > 1) {
    collapse.by <- collapse.by[1]
    warning("More than one random grouping variable found.", 
            "\n  Using `", collapse.by, "`.", call. = FALSE)
  }
  
  if (!collapse.by %in% colnames(data)) {
    stop("Could not find `", collapse.by, "` column.", call. = FALSE)
  }
  
  agg_data <- stats::aggregate(data[[pred_name]], 
                               by = data[union(collapse.by, X_names)],
                               FUN = mean)
  
  colnames(agg_data)[ncol(agg_data)] <- pred_name
  agg_data
}


# Utils -------------------------------------------------------------------

.is_grid <- function (df) {
  unq <- lapply(df, unique)
  if (prod(sapply(unq, length)) != nrow(df)) {
    return(FALSE)
  }
  df2 <- do.call(expand.grid, args = unq)
  df2$..1 <- 1
  res <- merge(df, df2, by = colnames(df), all = TRUE)
  return(sum(res$..1) == sum(df2$..1))
}

.validate_num <- function(x) {
  if (!is.numeric(x)) {
    x <- as.numeric(as.character(x))
  }
  x
}

.closest <- function (x, target, best_match) {
  if (is.numeric(x)) {
    AD <- abs(outer(x, target, FUN = `-`))
    idx <- apply(AD, 1, function(x) x == min(x))
  }
  else {
    idx <- t(outer(x, target, FUN = `==`))
  }
  if (is.matrix(best_match)) {
    idx <- idx & best_match
  }
  idx
}
library(modelbased)

S <- diag(1, 3, 3)
S[1,2] <- S[2,1] <- 0.6
S[1,3] <- S[3,1] <- 0.8
S[2,3] <- S[3,2] <- 0.8

dat <- MASS::mvrnorm(500, c(10, 20, 30), S, empirical = TRUE)
colnames(dat) <- c("X", "Y", "Z")
dat <- data.frame(dat)
dat$X <- cut(dat$X, breaks = 10)
dat$ID <- rep(letters, length.out = 500)


fit <- lme4::lmer(Y ~ X + Z + (1|ID), data = dat)

pred <- modelbased::estimate_expectation(fit, target = "X")

get_data_for_grid(pred) |> head()
#>          Y           X        Z ID
#> 1 21.34821 (10.9,11.8] 31.57896  a
#> 2 19.20351 (9.22,10.1] 29.67882  b
#> 3 19.42736 (9.22,10.1] 29.84396  c
#> 4 19.07498 (8.37,9.22] 28.65095  d
#> 5 19.90768 (9.22,10.1] 29.47125  e
#> 6 20.33612 (11.8,12.6] 31.29219  f
get_data_for_grid(pred, residuals = T) |> head()
#>            Y           X        Z ID
#> 6   19.98276 (10.9,11.8] 31.57896  a
#> 4   19.48057 (9.22,10.1] 29.67882  b
#> 4.1 19.57408 (9.22,10.1] 29.84396  c
#> 3   20.23562 (8.37,9.22] 28.65095  d
#> 4.2 20.36956 (9.22,10.1] 29.47125  e
#> 7   19.18748 (11.8,12.6] 31.29219  f
get_data_for_grid(pred, collapse.by = T) |> head()
#>          Y           X ID
#> 1 18.47561 (6.67,7.52]  p
#> 2 18.56970 (6.67,7.52]  q
#> 3 18.63299 (6.67,7.52]  t
#> 4 17.10856 (6.67,7.52]  v
#> 5 19.41991 (7.52,8.37]  b
#> 6 19.01859 (7.52,8.37]  e
get_data_for_grid(pred, residuals = T, collapse.by = T) |> head()
#>          Y           X ID
#> 1 20.63882 (6.67,7.52]  p
#> 2 21.43813 (6.67,7.52]  q
#> 3 20.17450 (6.67,7.52]  t
#> 4 18.79272 (6.67,7.52]  v
#> 5 20.14677 (7.52,8.37]  b
#> 6 20.48828 (7.52,8.37]  e

Created on 2021-07-09 by the reprex package (v2.0.0)

@strengejacke strengejacke added Feature idea 🔥 New feature or request Plot 🎇 Something related to plotting labels Jan 16, 2025