Skip to content

Commit

Permalink
Fix StanMetaRegression estimator (#108)
Browse files Browse the repository at this point in the history
* Fix StanMetaRegression estimator

* add pystan and arviz to required dependencies

* Revert "add pystan and arviz to required dependencies"

This reverts commit 28ef19d.

* Keeping pystan and arviz optional

* Skip PyStan on Python 3.6

* Skip PyStan test for Python < 3.6

* Include warning to the documentation and raise an error in __init__

* Update estimators.py

* Fix `black` issues

* Update test_stan_estimators.py

* declare array

* try adding additional prereqs

* add additional import statement to get error

* fix scipy issue

* add the naked import back in

* fix scipy version

* undo naked import

---------

Co-authored-by: James Kent <[email protected]>
  • Loading branch information
JulioAPeraza and jdkent authored Apr 4, 2024
1 parent 1402199 commit c38b0bb
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 23 deletions.
1 change: 0 additions & 1 deletion pymare/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ class Dataset:
def __init__(
self, y=None, v=None, X=None, n=None, data=None, X_names=None, add_intercept=True
):

if y is None and data is None:
raise ValueError(
"If no y values are provided, a pandas DataFrame "
Expand Down
2 changes: 0 additions & 2 deletions pymare/effectsize/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ class EffectSizeConverter(metaclass=ABCMeta):
"""Base class for effect size converters."""

def __init__(self, data=None, **kwargs):

kwargs = {k: v for k, v in kwargs.items() if v is not None}

if data is not None:
Expand Down Expand Up @@ -509,7 +508,6 @@ def compute_measure(

# Select or infer converter class
if comparison == "infer":

one_samp_inputs = {"m", "sd", "n", "r"}
two_samp_inputs = {"m1", "m2", "sd1", "sd2", "n1", "n2"}

Expand Down
1 change: 0 additions & 1 deletion pymare/effectsize/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ def df_search(sym, exprs, known, visited):
results = []

for exp in exp_dict[sym]:

candidates = []

sym_names = set(s.name for s in exp.symbols)
Expand Down
32 changes: 20 additions & 12 deletions pymare/estimators/estimators.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Meta-regression estimator classes."""

import sys
from abc import ABCMeta, abstractmethod
from inspect import getfullargspec
from warnings import warn
Expand Down Expand Up @@ -553,8 +554,8 @@ class StanMetaRegression(BaseEstimator):
Warning
-------
With changes to Stan in version 3, which requires Python 3.7, this class no longer works for
Python 3.7+. We will try to fix it in the future.
:obj:`~pymare.estimators.StanMetaRegression` uses Pystan 3, which requires Python 3.7.
Pystan 3 should not be used with PyMARE and Python 3.6 or earlier.
"""

_result_cls = BayesianMetaRegressionResults
Expand All @@ -564,6 +565,13 @@ def __init__(self, **sampling_kwargs):
self.model = None
self.result_ = None

if sys.version_info < (3, 7):
raise RuntimeError(
"StanMetaRegression uses Pystan 3, which requires python 3.7 or higher. "
f"You are running Python {sys.version_info.major}.{sys.version_info.minor}. "
"Pystan 3 should not be used with PyMARE and Python 3.6 or earlier."
)

def compile(self):
"""Compile the Stan model."""
# Note: we deliberately use a centered parameterization for the
Expand All @@ -575,7 +583,7 @@ def compile(self):
int<lower=1> N;
int<lower=1> K;
vector[N] y;
int<lower=1,upper=K> id[N];
array[N] int<lower=1,upper=K> id;
int<lower=1> C;
matrix[K, C] X;
vector[N] sigma;
Expand All @@ -595,13 +603,11 @@ def compile(self):
}
"""
try:
from pystan import StanModel
import stan
except ImportError:
raise ImportError(
"Please install pystan or, if using Python 3.7+, switch to Python 3.6."
)
raise ImportError("Please install pystan.")

self.model = StanModel(model_code=spec)
self.model = stan.build(spec, data=self.data)

def fit(self, y, v, X, groups=None):
"""Run the Stan sampler and return results.
Expand Down Expand Up @@ -645,9 +651,6 @@ def fit(self, y, v, X, groups=None):
"shape {}.".format(y.shape)
)

if self.model is None:
self.compile()

N = y.shape[0]
groups = groups or np.arange(1, N + 1, dtype=int)
K = len(np.unique(groups))
Expand All @@ -662,7 +665,12 @@ def fit(self, y, v, X, groups=None):
"sigma": v.ravel(),
}

self.result_ = self.model.sampling(data=data, **self.sampling_kwargs)
self.data = data

if self.model is None:
self.compile()

self.result_ = self.model.sample(**self.sampling_kwargs)
return self

def summary(self, ci=95):
Expand Down
2 changes: 0 additions & 2 deletions pymare/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,6 @@ def permutation_test(self, n_perm=1000):

# Loop over parallel datasets
for i in range(n_datasets):

y = self.dataset.y[:, i]
y_perm = np.repeat(y[:, None], n_perm, axis=1)

Expand Down Expand Up @@ -472,7 +471,6 @@ def permutation_test(self, n_perm=1000):

# Loop over parallel datasets
for i in range(n_datasets):

y = self.dataset.y[:, i]
y_perm = np.repeat(y[:, None], n_perm, axis=1)

Expand Down
18 changes: 14 additions & 4 deletions pymare/tests/test_stan_estimators.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
"""Tests for estimators that use stan."""

import sys

import pytest

from pymare.estimators import StanMetaRegression


@pytest.mark.skip(reason="StanMetaRegression won't work with Python 3.7+.")
@pytest.mark.skipif(sys.version_info < (3, 7), reason="requires python 3.7 or higher")
def test_stan_estimator(dataset):
"""Run smoke test for StanMetaRegression."""
# no ground truth here, so we use sanity checks and rough bounds
est = StanMetaRegression(iter=3000).fit_dataset(dataset)
est = StanMetaRegression(num_samples=3000).fit_dataset(dataset)
results = est.summary()
assert "BayesianMetaRegressionResults" == results.__class__.__name__
summary = results.summary(["beta", "tau2"])
Expand All @@ -19,9 +21,17 @@ def test_stan_estimator(dataset):
assert 3 < tau2 < 5


@pytest.mark.skip(reason="StanMetaRegression won't work with Python 3.7+.")
@pytest.mark.skipif(sys.version_info < (3, 7), reason="requires python 3.7 or higher")
def test_stan_2d_input_failure(dataset_2d):
"""Run smoke test for StanMetaRegression on 2D data."""
with pytest.raises(ValueError) as exc:
StanMetaRegression(iter=500).fit_dataset(dataset_2d)
StanMetaRegression(num_samples=500).fit_dataset(dataset_2d)
assert str(exc.value).startswith("The StanMetaRegression")


def test_stan_python_36_failure(dataset):
"""Run smoke test for StanMetaRegression with Python 3.6."""
if sys.version_info < (3, 7):
# Raise error if StanMetaRegression is initialize with python 3.6 or lower
with pytest.raises(RuntimeError):
StanMetaRegression(num_samples=3000).fit_dataset(dataset)
3 changes: 2 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ python_requires = >= 3.8
install_requires =
numpy>=1.8.0
pandas
scipy
scipy<1.13.0 # https://github.com/arviz-devs/arviz/issues/2336
sympy
wrapt
packages = find:
Expand Down Expand Up @@ -75,6 +75,7 @@ stan =
all =
%(doc)s
%(tests)s
%(stan)s

[options.package_data]
* =
Expand Down

0 comments on commit c38b0bb

Please sign in to comment.