Skip to content

Commit

Permalink
Add an inflation factor to correct for multiple contrasts in Stouffer…
Browse files Browse the repository at this point in the history
…'s combination test (#117)

* Add correction term for multiple contrasts in Stouffer's combination test

* Update RTD yml

* Update .readthedocs.yml

* Update combination.py

* Update .readthedocs.yml

* Update setup.cfg

* Update testing.yml

* Update testing.yml

* Run black

* Make sure solutions and symbols match

* Update combination.py
  • Loading branch information
JulioAPeraza authored Apr 9, 2024
1 parent c38b0bb commit 7465939
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 8 deletions.
76 changes: 68 additions & 8 deletions pymare/estimators/combination.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,17 +111,77 @@ class StoufferCombinationTest(CombinationTest):
"""

# Maps Dataset attributes onto fit() args; see BaseEstimator for details.
_dataset_attr_map = {"z": "y", "w": "v"}

def fit(self, z, w=None):
"""Fit the estimator to z-values, optionally with weights."""
return super().fit(z, w=w)

def p_value(self, z, w=None):
_dataset_attr_map = {"z": "y", "w": "n", "g": "v"}

def _inflation_term(self, z, w, g):
"""Calculate the variance inflation term for each group.
This term is used to adjust the variance of the combined z-score when
multiple sample come from the same study.
Parameters
----------
z : :obj:`numpy.ndarray` of shape (n, d)
Array of z-values.
w : :obj:`numpy.ndarray` of shape (n, d)
Array of weights.
g : :obj:`numpy.ndarray` of shape (n, d)
Array of group labels.
Returns
-------
sigma : float
The variance inflation term.
"""
# Only center if the samples are not all the same, to prevent division by zero
# when calculating the correlation matrix.
# This centering is problematic for N=2
all_samples_same = np.all(np.equal(z, z[0]), axis=0).all()
z = z if all_samples_same else z - z.mean(0)

# Use the value from one feature, as all features have the same groups and weights
groups = g[:, 0]
weights = w[:, 0]

# Loop over groups
unique_groups = np.unique(groups)

sigma = 0
for group in unique_groups:
group_indices = np.where(groups == group)[0]
group_z = z[group_indices]

# For groups with only one sample the contribution to the summand is 0
n_samples = len(group_indices)
if n_samples < 2:
continue

# Calculate the within group correlation matrix and sum the non-diagonal elements
corr = np.corrcoef(group_z, rowvar=True)
upper_indices = np.triu_indices(n_samples, k=1)
non_diag_corr = corr[upper_indices]
w_i, w_j = weights[upper_indices[0]], weights[upper_indices[1]]

sigma += (2 * w_i * w_j * non_diag_corr).sum()

return sigma

def fit(self, z, w=None, g=None):
"""Fit the estimator to z-values, optionally with weights and groups."""
return super().fit(z, w=w, g=g)

def p_value(self, z, w=None, g=None):
"""Calculate p-values."""
if w is None:
w = np.ones_like(z)
cz = (z * w).sum(0) / np.sqrt((w**2).sum(0))

# Calculate the variance inflation term, sum of non-diagonal elements of sigma.
sigma = self._inflation_term(z, w, g) if g is not None else 0

# The sum of diagonal elements of sigma is given by (w**2).sum(0).
variance = (w**2).sum(0) + sigma

cz = (z * w).sum(0) / np.sqrt(variance)
return ss.norm.sf(cz)


Expand Down
37 changes: 37 additions & 0 deletions pymare/tests/test_combination_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,40 @@ def test_combination_test_from_dataset(Cls, data, mode, expected):
results = est.summary()
z = ss.norm.isf(results.p)
assert np.allclose(z, expected, atol=1e-5)


def test_stouffer_adjusted():
"""Test StoufferCombinationTest with weights and groups."""
# Test with weights and groups
data = np.array(
[
[2.1, 0.7, -0.2, 4.1, 3.8],
[1.1, 0.2, 0.4, 1.3, 1.5],
[-0.6, -1.6, -2.3, -0.8, -4.0],
[2.5, 1.7, 2.1, 2.3, 2.5],
[3.1, 2.7, 3.1, 3.3, 3.5],
[3.6, 3.2, 3.6, 3.8, 4.0],
]
)
weights = np.tile(np.array([4, 3, 4, 10, 15, 10]), (data.shape[1], 1)).T
groups = np.tile(np.array([0, 0, 1, 2, 2, 2]), (data.shape[1], 1)).T

results = StoufferCombinationTest("directed").fit(z=data, w=weights, g=groups).params_
z = ss.norm.isf(results["p"])

z_expected = np.array([5.00088912, 3.70356943, 4.05465924, 5.4633001, 5.18927878])
assert np.allclose(z, z_expected, atol=1e-5)

# Test with weights and no groups. Limiting cases.
# Limiting case 1: all correlations are one.
n_maps_l1 = 5
common_sample = np.array([2.1, 0.7, -0.2])
data_l1 = np.tile(common_sample, (n_maps_l1, 1))
groups_l1 = np.tile(np.array([0, 0, 0, 0, 0]), (data_l1.shape[1], 1)).T

results_l1 = StoufferCombinationTest("directed").fit(z=data_l1, g=groups_l1).params_
z_l1 = ss.norm.isf(results_l1["p"])

sigma_l1 = n_maps_l1 * (n_maps_l1 - 1) # Expected inflation term
z_expected_l1 = n_maps_l1 * common_sample / np.sqrt(n_maps_l1 + sigma_l1)
assert np.allclose(z_l1, z_expected_l1, atol=1e-5)

0 comments on commit 7465939

Please sign in to comment.