csSampling with multilevel models? #5

Open
awcm0n opened this issue Nov 9, 2023 · 5 comments

awcm0n commented Nov 9, 2023

I'm interested in using the csSampling package to run multilevel models on complex survey data, but I didn't succeed in fitting a simple random-intercept model. After the Stan model was fit, the process stalled without error message. So my question is: Is there any guidance as to what types of models can and cannot be fit with the csSampling package?

@mrdwill mrdwill self-assigned this Nov 17, 2023
Collaborator

mrdwill commented Nov 17, 2023

Thanks for your patience. Could you provide more information about the model and the data dimensions? We've used this to fit simple random intercept models before. Did you use the brms wrapper or a custom Stan model? If the Stan model fit, then the issue is in the post-processing. There are some data type conversions that could be inefficient for large numbers of samples/draws. I'll start looking into it, so any specifics you can provide would greatly help!

Collaborator

mrdwill commented Nov 17, 2023

I think the crux of the issue is here: https://discourse.mc-stan.org/t/as-matrix-for-unconstrained-parameters/11528/2
The current cs_sampling function spends a lot of effort (nested for-loops) converting a list of constrained Stan parameters to a matrix of unconstrained parameters, row by row. There may be a workaround: reading the values from a diagnostic file CSV output.

Author

awcm0n commented Nov 20, 2023

Thanks for looking into the issue. I created a minimal example of what I'm trying to do. The code below loads a 2-year MEPS longitudinal data file in wide-format that is converted to long. The goal of the analysis is to determine the change in k6sum from 2019 to 2020. Instead of including a person-level fixed effect, as economists are wont to do, I want to include a person-level random intercept, (1|dupersid), in the model. The Stan model fit, but the post-processing appears to be stuck.

if(!require("MEPS")) {
  library(devtools)
  install_github("e-mitchell/meps_r_pkg/MEPS")
}
library(tidyverse)
library(MEPS)
library(janitor)
library(survey)
library(srvyr)
library(csSampling)
library(brms)

# create long data set that contains a person's (dupersid) k6 score in 2019 and 2020 
dat <- read_MEPS(file = "h225") %>% # load panel data from MEPS
  clean_names() %>% 
  dplyr::select(dupersid, varpsu, varstr, lsaqwt, age=age2x, k6sum2, k6sum4) %>% 
  pivot_longer(cols = c(k6sum2, k6sum4), names_to = "k6round", values_to = "k6sum") %>% 
  mutate(year = ifelse(k6round=="k6sum2", 2019, 2020) |> as.factor()) %>% 
  dplyr::select(-k6round) %>% 
  mutate(across(where(is.numeric), \(x) as.numeric(x)))

# I want to analyse respondents 18 years and older. To do so, I calculate the mean
# weight for this subsample

# subset of respondents 18 years and older
dat_stan <- dat %>% 
  filter(age>=18 & !is.na(k6sum) & lsaqwt>0) 

mwgt <- mean(dat_stan$lsaqwt)

# scale weights
dat$wgt <- dat$lsaqwt/mwgt
dat_stan$wgt <- dat_stan$lsaqwt/mwgt

# create the design object
dsgn <- dat %>% 
  as_survey_design(ids = varpsu, strata = varstr, weights = wgt, nest = TRUE) %>% 
  filter(age>=18 & !is.na(k6sum) & lsaqwt>0)

set.seed (12345)
model_formula <- formula("k6sum|weights (wgt) ~ year + (1|dupersid)")

mod.brms <- cs_sampling_brms(svydes = dsgn,
                             brmsmod = brmsformula(model_formula, center = FALSE),
                             data = dat_stan,
                             family = gaussian(),
                             ctrl_stan = list(chains = 1, iter = 2000, warmup = 1000, thin = 1))

Collaborator

mrdwill commented Oct 9, 2024

> I think the crux of the issue is here: https://discourse.mc-stan.org/t/as-matrix-for-unconstrained-parameters/11528/2 the current cs_sampling function spends a lot of effort (nested for-loops) converting a list of constrained stan parameters to a matrix of unconstrained parameters - row by row. There may be a work-around to read from a diagnostic file csv output.

So I was wrong. I tried using the posterior package https://mc-stan.org/posterior/articles/posterior.html but didn't see any efficiency gains. A potentially worse problem was the lazy use of rbind instead of pre-allocating a matrix for the parameters. The cs_sampling version in the testing branch should run faster now that the rbind calls have been replaced: https://github.com/RyanHornby/csSampling/tree/testing
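For readers unfamiliar with the rbind problem mentioned above, here is a minimal sketch of the general pattern. It is not the csSampling code: `draws`, `grow_rbind`, and `fill_prealloc` are hypothetical names, and the list of random vectors only stands in for per-draw parameter vectors.

```r
# Illustrative only: a toy stand-in for stacking per-draw parameter vectors.
# None of these names come from csSampling.
set.seed(1)
n_draws <- 500
n_par <- 20
draws <- lapply(seq_len(n_draws), function(i) rnorm(n_par))

# Growing with rbind(): the accumulated matrix is copied on every iteration,
# so the total work grows roughly quadratically with the number of draws.
grow_rbind <- function(draws) {
  out <- NULL
  for (d in draws) out <- rbind(out, d)
  unname(out)  # drop the row names that rbind() adds
}

# Pre-allocating: one allocation up front, then each row is filled in place.
fill_prealloc <- function(draws) {
  out <- matrix(NA_real_, nrow = length(draws), ncol = length(draws[[1]]))
  for (i in seq_along(draws)) out[i, ] <- draws[[i]]
  out
}

m1 <- grow_rbind(draws)
m2 <- fill_prealloc(draws)
stopifnot(identical(m1, m2))  # same result, different cost
```

Both functions return the same matrix. Timing them with `system.time()` on larger inputs should show the gap widening as the number of draws and parameters grows.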

Collaborator

mrdwill commented Oct 10, 2024

@awcm0n I apologize for the seriously long delay.

Using the example code you provided and the testing branch with default settings, it ran for me in about 8 hours. The Stan part finished in a few minutes. The issue is that there are 5K+ random effects estimated. So even though the model has only 4 global parameters, cs_sampling estimates and adjusts all of the parameters. The bottleneck is that the default adjustment goes through each MCMC draw, estimates the Hessian, and then averages the results. This is a big matrix of derivatives, 5K by 5K. The alternative is to evaluate it only at the posterior mean. The two should be equivalent for large sample sizes, but the MCMC average is more stable. Changing the default let this run in about 25 minutes for me. There are now status messages, and the slowest part is step (4), where we have to invert the H and J matrices and take their eigendecomposition. That's probably 80-90% of the time now.

Here's the updated call. Note the use of the H_estimate argument. The default is "MCMC"; any other value uses the posterior mean.

mod.brms_fast <- cs_sampling_brms(svydes = dsgn,
                                  brmsmod = brmsformula(model_formula, center = FALSE),
                                  data = dat_stan,
                                  family = gaussian(),
                                  ctrl_stan = list(chains = 1, iter = 2000, warmup = 1000, thin = 1),
                                  H_estimate = "PM")

plot(mod.brms_fast, varnames = colnames(mod.brms_fast$adjusted_parms)[1:4])

[Image: Rplot, output of the plot() call above]
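To illustrate the trade-off between the two Hessian estimates described above, here is a small base-R sketch on a toy logistic regression. It is not the csSampling implementation. The data, `beta_hat`, `neg_hessian`, and the normal-approximation "draws" are all assumptions made for the example. It only shows that averaging the Hessian over every draw costs one Hessian evaluation per draw, while the posterior-mean version costs a single evaluation, and that the two tend to agree when the sample is large.

```r
# Toy example, not csSampling: logistic regression with 3 coefficients.
set.seed(42)
n <- 2000
p <- 3
X <- cbind(1, matrix(rnorm(n * (p - 1)), n, p - 1))
beta_hat <- c(-0.5, 1, 0.25)  # hypothetical point estimate

# Negative Hessian of the logistic log-likelihood at beta: X' W X,
# where W is diagonal with entries pr * (1 - pr).
neg_hessian <- function(beta) {
  pr <- plogis(drop(X %*% beta))
  crossprod(X * (pr * (1 - pr)), X)
}

# Stand-in for posterior draws: a normal approximation around beta_hat
# (an assumption for illustration; real draws would come from Stan).
V <- solve(neg_hessian(beta_hat))
n_draws <- 1000
draws <- matrix(rnorm(n_draws * p), n_draws, p) %*% chol(V)
draws <- sweep(draws, 2, beta_hat, "+")

# "MCMC"-style estimate: one Hessian per draw, then average (n_draws evaluations).
H_mcmc <- Reduce(`+`, lapply(seq_len(n_draws),
                             function(i) neg_hessian(draws[i, ]))) / n_draws

# Posterior-mean estimate: a single Hessian evaluation.
H_pm <- neg_hessian(colMeans(draws))

# Relative difference between the two estimates (Frobenius norm).
rel_diff <- norm(H_mcmc - H_pm, "F") / norm(H_pm, "F")
rel_diff
```

With thousands of random effects, each Hessian is a 5K by 5K matrix, so skipping the per-draw loop removes most of that cost, at the price of the extra stability the averaged version provides.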
