Skip to content

Commit

Permalink
Remove seed argument (#431)
Browse files Browse the repository at this point in the history
Closes #415

---------

Signed-off-by: Craig Gower-Page <[email protected]>
Co-authored-by: Isaac Gravestock <[email protected]>
  • Loading branch information
gowerc and gravesti authored Oct 4, 2024
1 parent 06ca118 commit c6d0b82
Show file tree
Hide file tree
Showing 12 changed files with 147 additions and 181 deletions.
4 changes: 0 additions & 4 deletions R/draws.R
Original file line number Diff line number Diff line change
Expand Up @@ -478,10 +478,6 @@ extract_data_nmar_as_na <- function(longdata) {
#' @export
draws.bayes <- function(data, data_ice = NULL, vars, method, ncores = 1, quiet = FALSE) {

if (!is.na(method$seed)) {
set.seed(method$seed)
}

longdata <- longDataConstructor$new(data, vars)
longdata$set_strategies(data_ice)

Expand Down
9 changes: 1 addition & 8 deletions R/mcmc.R
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ fit_mcmc <- function(

n_imputations <- method$n_samples
burn_in <- method$burn_in
seed <- method$seed
burn_between <- method$burn_between
same_cov <- method$same_cov

Expand Down Expand Up @@ -114,13 +113,7 @@ fit_mcmc <- function(
)
)

assert_that(
!is.na(seed),
!is.null(seed),
is.numeric(seed),
msg = "mcmc seed is invalid"
)
sampling_args$seed <- seed
sampling_args$seed <- sample.int(.Machine$integer.max, 1)

stan_fit <- record({
do.call(sampling, sampling_args)
Expand Down
19 changes: 12 additions & 7 deletions R/methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,7 @@
#' @param type a character string that specifies the resampling method used to perform inference
#' when a conditional mean imputation approach (set via `method_condmean()`) is used. Must be one of `"bootstrap"` or `"jackknife"`.
#'
#' @param seed a numeric that specifies the seed to be used in the call to Stan. This
#' argument is passed onto the `seed` argument of [rstan::sampling()]. Note that
#' this is only required for `method_bayes()`, for all other methods you can achieve
#' reproducible results by setting the seed via `set.seed()`. See details.
#' @param seed deprecated. Please use `set.seed()` instead.
#'
#' @details
#'
Expand Down Expand Up @@ -93,14 +90,22 @@ method_bayes <- function(
burn_between = 50,
same_cov = TRUE,
n_samples = 20,
seed = sample.int(.Machine$integer.max, 1)
seed = NULL
) {
assertthat::assert_that(
is.null(seed),
msg = paste(
"The `seed` argument to `method_bayes()` has been deprecated;",
"please use `set.seed()` instead.",
collapse = " "
)
)

x <- list(
burn_in = burn_in,
burn_between = burn_between,
same_cov = same_cov,
n_samples = n_samples,
seed = seed
n_samples = n_samples
)
return(as_class(x, c("method", "bayes")))
}
Expand Down
4 changes: 2 additions & 2 deletions data-raw/create_print_test_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,14 @@ set.seed(413)
dobj <- get_data(40)

suppressWarnings({
set.seed(859)
drawobj_b <- draws(
data = dobj$dat,
data_ice = dobj$dat_ice,
vars = dobj$vars,
method = method_bayes(
n_samples = 50,
burn_between = 1,
seed = 859
burn_between = 1
)
)
})
Expand Down
7 changes: 2 additions & 5 deletions man/method.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion tests/testthat/_snaps/print.md
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,6 @@
burn_between: 1
same_cov: TRUE
n_samples: 50
seed: 859

---
Expand Down
44 changes: 12 additions & 32 deletions tests/testthat/test-mcmc.R
Original file line number Diff line number Diff line change
Expand Up @@ -529,8 +529,7 @@ test_that("fit_mcmc can recover known values with same_cov = FALSE", {
n_samples = 250,
burn_in = 100,
burn_between = 3,
same_cov = FALSE,
seed = 8931
same_cov = FALSE
)

### No missingness
Expand Down Expand Up @@ -604,36 +603,17 @@ test_that("fit_mcmc can recover known values with same_cov = FALSE", {
})


test_that("invalid seed throws an error", {

set.seed(301)
sigma <- as_vcov(c(6, 4, 4), c(0.5, 0.2, 0.3))
dat <- get_sim_data(50, sigma)

dat_ice <- dat %>%
group_by(id) %>%
arrange(desc(visit)) %>%
slice(1) %>%
ungroup() %>%
mutate(strategy = "MAR")

vars <- set_vars(
visit = "visit",
subjid = "id",
group = "group",
covariates = "sex",
strategy = "strategy",
outcome = "outcome"
)

test_that("seed argument to method_bayes is deprecated", {
expect_error(
draws(
dat,
dat_ice,
vars,
method_bayes(n_samples = 2, seed = NA),
quiet = TRUE
),
regexp = "mcmc seed is invalid"
{
method <- method_bayes(
n_samples = 250,
burn_in = 100,
burn_between = 3,
same_cov = FALSE,
seed = 1234
)
},
regexp = "seed.*deprecated"
)
})
3 changes: 1 addition & 2 deletions tests/testthat/test-print.R
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,7 @@ test_that("print - bayesian", {
vars = dobj$vars,
method = method_bayes(
n_samples = 50,
burn_between = 1,
seed = 859
burn_between = 1
),
quiet = TRUE
)
Expand Down
9 changes: 4 additions & 5 deletions tests/testthat/test-reproducibility.R
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ test_that("Results are Reproducible", {



test_that("bayes - seed argument works without set.seed", {
test_that("bayes - set.seed produces identical results", {

sigma <- as_vcov(c(2, 1, 0.7), c(0.5, 0.3, 0.2))
dat <- get_sim_data(200, sigma, trt = 8) %>%
Expand All @@ -111,17 +111,16 @@ test_that("bayes - seed argument works without set.seed", {
)

meth <- method_bayes(
seed = 1482,
burn_between = 5,
burn_in = 200,
n_samples = 2
n_samples = 6
)

set.seed(49812)
set.seed(1234)
x <- suppressWarnings({
draws(dat, dat_ice, vars, meth, quiet = TRUE)
})
set.seed(2414)
set.seed(1234)
y <- suppressWarnings({
draws(dat, dat_ice, vars, meth, quiet = TRUE)
})
Expand Down
2 changes: 1 addition & 1 deletion vignettes/advanced.html
Original file line number Diff line number Diff line change
Expand Up @@ -714,7 +714,7 @@ <h1><span class="header-section-number">6</span> Custom imputation strategies</h
<span id="cb6-17"><a href="#cb6-17" tabindex="-1"></a><span class="co">#&gt; pars &lt;- list(mu = mu, sigma = sigma)</span></span>
<span id="cb6-18"><a href="#cb6-18" tabindex="-1"></a><span class="co">#&gt; return(pars)</span></span>
<span id="cb6-19"><a href="#cb6-19" tabindex="-1"></a><span class="co">#&gt; }</span></span>
<span id="cb6-20"><a href="#cb6-20" tabindex="-1"></a><span class="co">#&gt; &lt;bytecode: 0x7ff37e6af218&gt;</span></span>
<span id="cb6-20"><a href="#cb6-20" tabindex="-1"></a><span class="co">#&gt; &lt;bytecode: 0x7f86686ebac0&gt;</span></span>
<span id="cb6-21"><a href="#cb6-21" tabindex="-1"></a><span class="co">#&gt; &lt;environment: namespace:rbmi&gt;</span></span></code></pre></div>
<p>To further illustrate this for a simple example, assume that a new strategy is to be implemented as follows:
- The marginal mean of the imputation distribution is equal to the marginal mean trajectory for the subject according to their assigned group and covariates up to the ICE.
Expand Down
6 changes: 2 additions & 4 deletions vignettes/quickstart.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,7 @@ vars <- set_vars(
method <- method_bayes(
burn_in = 200,
burn_between = 5,
n_samples = 150,
seed = 675442751
n_samples = 150
)
# Create samples for the imputation parameters by running the draws() function
Expand Down Expand Up @@ -347,8 +346,7 @@ vars <- set_vars(
method <- method_bayes(
burn_in = 200,
burn_between = 5,
n_samples = 150,
seed = 675442751
n_samples = 150
)
Expand Down
Loading

0 comments on commit c6d0b82

Please sign in to comment.