Skip to content

Commit

Permalink
Convert rstan to suggest (#441)
Browse files Browse the repository at this point in the history
Closes #400
  • Loading branch information
gowerc authored Oct 8, 2024
1 parent 5db898e commit 6368d34
Show file tree
Hide file tree
Showing 27 changed files with 194 additions and 181 deletions.
37 changes: 8 additions & 29 deletions .github/actions/build-src/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,48 +8,27 @@ inputs:
runs:
using: "composite"
steps:

- name: Determine system/package state
run: |
sink(".github/meta.txt")
list(
version = version,
pkgs = installed.packages()[c("Rcpp", "rstan", "rstantools"), c("Version", "Package")]
pkgs = installed.packages()[c("rstan"), c("Version", "Package")]
)
sink()
shell: Rscript {0}

# The default R/stanmodels.R was updated in v2.2.0 (11Apr2022)
# Normally for older versions it would change itself on the fly
# when it compiles the stan model
# but as we are using cached models we are preventing compilation
# and thus need to manually refresh the file
- name: Refresh rstantools config files
shell: bash
run: |
echo "inputs.refresh = ${{ inputs.refresh }}"
if ${{ inputs.refresh }} ; then
Rscript -e "rstantools::rstan_config()"
fi

- name: Cache Compiled Stan Code
id: cache-pkgs
uses: actions/cache@v3
with:
path: src/*
path: local/*
key: ${{ hashFiles('.github/meta.txt') }}-${{ hashFiles('inst/stan/MMRM.stan') }}


# pkgbuild compares time stamps of *.so object to all header files including
# inst/include/stan_meta_header.hpp so we touch the .so object to push its time
# stamp beyond that of the .hpp file

- name: Build if needed
shell: bash
env:
RBMI_CACHE_DIR: local
run: |
if [[ ${{ steps.cache-pkgs.outputs.cache-hit == 'true' }} && "${{ runner.os }}" != "Windows" ]] ; then
echo "No compilation needed!"
touch src/*.so
else
echo "Compilation needed!"
Rscript -e "pkgbuild::compile_dll()"
fi
Rscript -e "pkgload::load_all(); get_stan_model()"
10 changes: 1 addition & 9 deletions .github/actions/rcmdcheck/action.yml
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@

name: 'Build src'
description: 'Build src!'
inputs:
slim:
description: 'Should the reduced set of checks be run'
default: false
runs:
using: "composite"
steps:
Expand All @@ -15,16 +11,12 @@ runs:
- name: Run R CMD check
env:
RBMI_CACHE_DIR: local
_R_CHECK_CRAN_INCOMING_: false
_R_CHECK_CRAN_INCOMING_REMOTE_: false
shell: bash
run: |
echo "inputs.slim = ${{ inputs.slim }}"
if ${{ inputs.slim }} ; then
R CMD check --no-manual --no-build-vignettes --no-vignettes --ignore-vignettes *.tar.gz
else
R CMD check --no-manual --as-cran *.tar.gz
fi
- name: Catch warnings in R CMD check output
id: catch-errors
Expand Down
2 changes: 0 additions & 2 deletions .github/workflows/on_biweekly.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,4 @@ jobs:

- name: Check
uses: ./.github/actions/rcmdcheck
with:
slim: true

4 changes: 2 additions & 2 deletions .github/workflows/on_pr_main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ jobs:

- name: Check
uses: ./.github/actions/rcmdcheck
with:
slim: true


vignettes:
Expand All @@ -53,5 +51,7 @@ jobs:
uses: ./.github/actions/build-src

- name: Build Vignettes
env:
RBMI_CACHE_DIR: local
run: |
Rscript ./vignettes/build.R
2 changes: 2 additions & 0 deletions .github/workflows/on_push.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ jobs:
uses: ./.github/actions/build-src

- name: testthat
env:
RBMI_CACHE_DIR: local
run: |
options(crayon.enabled = TRUE, cli.dynamic = FALSE)
devtools::test(stop_on_failure = TRUE, reporter = testthat::CheckReporter)
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,4 @@ docs


local/
*.rds
5 changes: 3 additions & 2 deletions .lintr
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
linters: with_defaults(
linters: linters_with_defaults(
line_length_linter(120),
object_name_linter = NULL
object_name_linter = NULL,
indentation_linter(indent = 4L)
)
39 changes: 17 additions & 22 deletions .vscode/tasks.json
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,25 @@
}
},
{
"label": "R - testthat (FULL)",
"label": "rbmi - testthat (local cache)",
"problemMatcher": "$testthat",
"command": "Rscript",
"args" : [
"-e",
"devtools::test()"
],
"options": {
"env": {
"RBMI_CACHE_DIR" : "${workspaceFolder}/local"
}
},
},
{
"label": "rbmi - testthat (FULL)",
"command": "Rscript",
"options": {
"env": {
"RBMI_CACHE_DIR" : "${workspaceFolder}/local",
"R_TEST_FULL" : "TRUE"
}
},
Expand All @@ -34,27 +49,7 @@
"clear": true,
"panel": "dedicated"
},
"problemMatcher": {
"owner": "R-testthat",
"severity": "error",
"fileLocation": [
"relative",
"${workspaceFolder}/tests/testthat"
],
"pattern": [
{
"regexp": "^(Failure|Error)\\s\\((.*\\.[Rr]):(\\d+):(\\d+)\\):\\s(.*)",
"file": 2,
"line": 3,
"column": 4,
"message": 5
},
{
"regexp": "^(.*)$",
"message": 1
}
]
}
"problemMatcher": "$testthat"
}
]
}
15 changes: 2 additions & 13 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -33,28 +33,17 @@ Suggests:
lubridate,
purrr,
ggplot2,
rstan (>= 2.26.0),
R.rsp
Biarch: true
Config/testthat/edition: 3
Imports:
mmrm,
pkgload,
Matrix,
tools,
methods,
Rcpp (>= 0.12.0),
RcppParallel (>= 5.0.1),
rstan (>= 2.26.0),
rstantools (>= 2.1.1),
R6,
assertthat
LinkingTo:
BH (>= 1.66.0),
Rcpp (>= 0.12.0),
RcppEigen (>= 0.3.3.3.0),
RcppParallel (>= 5.0.1),
rstan (>= 2.26.0),
StanHeaders (>= 2.26.0)
SystemRequirements: GNU make
Depends:
R (>= 3.4.0)
License: Apache License (>= 2)
Expand Down
5 changes: 0 additions & 5 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,9 @@ export(strategy_MAR)
export(validate)
export(validate_analyse_pars)
import(R6)
import(Rcpp)
import(methods)
importFrom(assertthat,assert_that)
importFrom(mmrm,VarCorr)
importFrom(rstan,extract)
importFrom(rstan,sampling)
importFrom(rstan,summary)
importFrom(stats,aggregate)
importFrom(stats,as.formula)
importFrom(stats,binomial)
Expand All @@ -107,4 +103,3 @@ importFrom(stats,var)
importFrom(stats,vcov)
importFrom(utils,capture.output)
importFrom(utils,relist)
useDynLib(rbmi, .registration = TRUE)
19 changes: 7 additions & 12 deletions R/mcmc.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#' @description
#' `fit_mcmc()` fits the base imputation model using a Bayesian approach.
#' This is done through a MCMC method that is implemented in `stan`
#' and is run by using the function [rstan::sampling()].
#' and is run by using the function `rstan::sampling()`.
#' The function returns the draws from the posterior distribution of the model parameters
#' and the `stanfit` object. Additionally it performs multiple diagnostics checks of the chain
#' and returns warnings in case of any detected issues.
Expand Down Expand Up @@ -45,10 +45,7 @@
#' - `fit`: a `stanfit` object.
#'
#'
#' @import Rcpp
#' @import methods
#' @importFrom rstan sampling
#' @useDynLib rbmi, .registration = TRUE
fit_mcmc <- function(
designmat,
outcome,
Expand Down Expand Up @@ -95,7 +92,7 @@ fit_mcmc <- function(
)

sampling_args <- list(
object = stanmodels$MMRM,
object = get_stan_model(),
data = stan_data,
pars = c("beta", "Sigma"),
chains = 1,
Expand All @@ -116,7 +113,7 @@ fit_mcmc <- function(
sampling_args$seed <- sample.int(.Machine$integer.max, 1)

stan_fit <- record({
do.call(sampling, sampling_args)
do.call(rstan::sampling, sampling_args)
})

if (!is.null(stan_fit$errors)) {
Expand Down Expand Up @@ -218,8 +215,8 @@ split_dim <- function(a, n) {
#' @description
#' Extract draws from a `stanfit` object and convert them into lists.
#'
#' The function [rstan::extract()] returns the draws for a given parameter as an array. This function
#' calls [rstan::extract()] to extract the draws from a `stanfit` object
#' The function `rstan::extract()` returns the draws for a given parameter as an array. This function
#' calls `rstan::extract()` to extract the draws from a `stanfit` object
#' and then convert the arrays into lists.
#'
#' @param stan_fit A `stanfit` object.
Expand All @@ -233,10 +230,9 @@ split_dim <- function(a, n) {
#' of the list is a list with length equal to 1 if `same_cov = TRUE` or equal to the
#' number of groups if `same_cov = FALSE`.
#'
#' @importFrom rstan extract
extract_draws <- function(stan_fit) {

pars <- extract(stan_fit, pars = c("beta", "Sigma"))
pars <- rstan::extract(stan_fit, pars = c("beta", "Sigma"))
names(pars) <- c("beta", "sigma")

##################### from array to list
Expand All @@ -261,7 +257,6 @@ extract_draws <- function(stan_fit) {
#' @return
#' A named vector containing the ESS for each parameter of the model.
#'
#' @importFrom rstan summary
get_ESS <- function(stan_fit) {
return(rstan::summary(stan_fit, pars = c("beta", "Sigma"))$summary[, "n_eff"])
}
Expand Down Expand Up @@ -316,7 +311,7 @@ check_ESS <- function(stan_fit, n_draws, threshold_lowESS = 0.4) {
#' 2. The Bayesian Fraction of Missing Information (BFMI) is sufficiently low.
#' 3. The number of iterations that saturated the max treedepth is zero.
#'
#' Please see [rstan::check_hmc_diagnostics()] for details.
#' Please see `rstan::check_hmc_diagnostics()` for details.
#'
#' @param stan_fit A `stanfit` object.
#'
Expand Down
25 changes: 0 additions & 25 deletions R/stanmodels.R

This file was deleted.

Loading

0 comments on commit 6368d34

Please sign in to comment.