Skip to content

Commit

Permalink
Merge pull request #5 from drammock/new_cluster_stats_api_GSOC24
Browse files Browse the repository at this point in the history
fixes for test
  • Loading branch information
CarinaFo authored Aug 21, 2024
2 parents b5fce8b + a288d85 commit 81ce0d0
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 122 deletions.
35 changes: 18 additions & 17 deletions mne/stats/cluster_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,15 @@
from scipy.stats import f as fstat
from scipy.stats import t as tstat

from .. import BaseEpochs, Evoked, EvokedArray
from ..epochs import BaseEpochs, EvokedArray
from ..evoked import Evoked
from ..fixes import has_numba, jit
from ..parallel import parallel_func
from ..source_estimate import MixedSourceEstimate, SourceEstimate, VolSourceEstimate
from ..source_space import SourceSpaces
from ..time_frequency import BaseTFR
from ..utils import (
GetEpochsMixin,
ProgressBar,
_check_option,
_pl,
Expand Down Expand Up @@ -1784,7 +1786,11 @@ def _validate_cluster_df(df: pd.DataFrame, dv_name: str, iv_name: str):
f"{prologue} consistent shape, but {len(all_shapes)} different "
f"shapes were found: {'; '.join(all_shapes)}."
)
return all_types.pop() # return the type of the data column entries
obj_type = all_types.pop()
is_epo = GetEpochsMixin in obj_type.__mro__
is_tfr = BaseTFR in obj_type.__mro__
is_arr = np.ndarray in obj_type.__mro__
return is_epo, is_tfr, is_arr


@verbose
Expand Down Expand Up @@ -1868,7 +1874,7 @@ def cluster_test(
iv_name = str(np.array(formula.rhs.root).item())

# validate the input dataframe and determine what kind of objects the data column holds (epochs-like, TFR, or plain ndarray)
_dtype = _validate_cluster_df(df, dv_name, iv_name)
is_epo, is_tfr, is_arr = _validate_cluster_df(df, dv_name, iv_name)

# for within_subject designs, check if each subject has 2 observations
_validate_type(within_id, (str, None), "within_id")
Expand All @@ -1880,23 +1886,18 @@ def cluster_test(
raise ValueError("for paired t-test, each subject must have 2 observations")

# extract the data from the dataframe
def _extract_data_array(series):
outer_func = np.concatenate if is_epo else np.array
axes = (-3, -1) if is_tfr else (-2, -1)

def func_arr(series):
return np.concatenate(series.values)

def _extract_data_mne(series): # 2D data
return np.array(
series.map(lambda inst: inst.get_data().swapaxes(-2, -1)).to_list()
def func_mne(series):
return outer_func(
series.map(lambda inst: inst.get_data().swapaxes(*axes)).to_list()
)

def _extract_data_tfr(series):
return series.map(lambda inst: inst.get_data().swapaxes(-3, -1)).to_list()

if _dtype is np.ndarray:
func = _extract_data_array
elif _dtype is BaseTFR:
func = _extract_data_tfr
else:
func = _extract_data_mne
func = func_arr if is_arr else func_mne

# convert to a list-like X for clustering
X = df.groupby(iv_name).agg({dv_name: func})[dv_name].to_list()
Expand Down Expand Up @@ -1993,7 +1994,7 @@ def plot_cluster_time_sensor(
linestyles: list | dict | None = None,
cmap_evokeds: None | str | tuple = None,
cmap_topo: None | str | tuple = None,
ci: float | bool | callable() | None = None,
ci: float | bool | callable | None = None,
):
"""
Plot the cluster with the lowest p-value.
Expand Down
166 changes: 61 additions & 105 deletions mne/stats/tests/test_cluster_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@
summarize_clusters_stc,
ttest_1samp_no_p,
)
from mne.time_frequency import AverageTFRArray, EpochsTFRArray
from mne.utils import _record_warnings, catch_logging
from mne.time_frequency import AverageTFRArray, BaseTFR, EpochsTFRArray
from mne.utils import GetEpochsMixin, _record_warnings, catch_logging

n_space = 50

Expand Down Expand Up @@ -911,121 +911,77 @@ def test_new_cluster_api(Inst):
"""Test handling different MNE objects in the cluster API."""
pd = pytest.importorskip("pandas")

n_subs, n_epo, n_chan, n_freq, n_times = 2, 2, 3, 4, 5
info = create_info(ch_names=n_chan, sfreq=1000, ch_types="eeg")
# Introduce a significant difference in a specific region, time, and frequency
region_start = 1
region_end = 2
time_start = 2
time_end = 4
freq_start = 2
freq_end = 4

if Inst == EpochsArray:
# Create random data for EpochsArray
inst1 = Inst(np.random.randn(n_epo, n_chan, n_times), info=info)
# Adding a constant to create a difference
data_copy = inst1.get_data().copy() # no data attribute for EpochsArray
data_copy[:, region_start:region_end, time_start:time_end] += (
2 # Modify the copy
)
inst2 = Inst(
data=data_copy, info=info
) # Use the modified copy as a new instance

elif Inst == EvokedArray:
# Create random data for EvokedArray
inst1 = Inst(np.random.randn(n_chan, n_times), info=info)
data_copy = inst1.data.copy()
data_copy[region_start:region_end, time_start:time_end] += 2
inst2 = Inst(data=data_copy, info=info)

elif Inst == EpochsTFRArray:
# Create random data for EpochsTFRArray
data_tfr1 = np.random.randn(n_epo, n_chan, n_freq, n_times)
data_tfr2 = np.random.randn(n_epo, n_chan, n_freq, n_times)
inst1 = Inst(
data=data_tfr1, info=info, times=np.arange(n_times), freqs=np.arange(n_freq)
)
inst2 = Inst(
data=data_tfr2, info=info, times=np.arange(n_times), freqs=np.arange(n_freq)
)
data_tfr2 = inst2.data.copy()
data_tfr2[
:, region_start:region_end, freq_start:freq_end, time_start:time_end
] += 2
inst2 = Inst(
data=data_tfr2, info=info, times=np.arange(n_times), freqs=np.arange(n_freq)
)
rng = np.random.default_rng(seed=8675309)
is_epo = GetEpochsMixin in Inst.__mro__
is_tfr = BaseTFR in Inst.__mro__

elif Inst == AverageTFRArray:
# Create random data for AverageTFRArray
data_tfr1 = np.random.randn(n_chan, n_freq, n_times)
data_tfr2 = np.random.randn(n_chan, n_freq, n_times)
inst1 = Inst(
data=data_tfr1, info=info, times=np.arange(n_times), freqs=np.arange(n_freq)
)
inst2 = Inst(
data=data_tfr2, info=info, times=np.arange(n_times), freqs=np.arange(n_freq)
)
data_tfr2 = inst2.data.copy()
data_tfr2[
region_start:region_end, freq_start:freq_end, time_start:time_end
] += 2
inst2 = Inst(
data=data_tfr2, info=info, times=np.arange(n_times), freqs=np.arange(n_freq)
)
n_epo, n_chan, n_freq, n_times = 6, 3, 4, 5

if Inst == EvokedArray or Inst == AverageTFRArray:
# Generate random noise
noise = np.random.normal(loc=0, scale=0.1, size=inst1.data.shape)
# add noise to the data of the second subject
inst1_n = inst1.copy()
inst1_n.data = inst1.data + noise
inst2_n = inst2.copy()
inst2_n.data = inst2.data + noise
data = [inst1, inst2, inst1_n, inst2_n]
conds = ["a", "b"] * n_subs
# prepare the dimensions of the simulated data, then simulate
size = (n_chan,)
if is_epo:
size = (n_epo, *size)
if is_tfr:
size = (*size, n_freq)
size = (*size, n_times)
data = rng.normal(size=size)

# construct the instance
info = create_info(ch_names=n_chan, sfreq=1000, ch_types="eeg")
kw = dict(times=np.arange(n_times), freqs=np.arange(n_freq)) if is_tfr else dict()
cond_a = Inst(data=data, info=info, **kw)
cond_b = cond_a.copy()
# introduce a significant difference in a specific region, time, and frequency
ch_start, ch_end = 0, 2 # 2 channels
t_start, t_end = 2, 4 # 2 times
f_start, f_end = 2, 4 # 2 freqs
if is_tfr:
cond_b._data[..., ch_start:ch_end, f_start:f_end, t_start:t_end] += 2
else:
data = [inst1, inst2]
cond_b._data[..., ch_start:ch_end, t_start:t_end] += 2
# for Evokeds/AverageTFRs, we create fake "subjects" as our observations within each
# condition. We add a bit of noise while we do so.
if not is_epo:
insts = list()
for cond in cond_a, cond_b:
for _n in range(n_epo):
if not _n:
insts.append(cond)
continue
_cond = cond.copy()
_cond.data += rng.normal(scale=0.1, size=_cond.data.shape)
insts.append(_cond)
conds = np.repeat(["a", "b"], n_epo).tolist()
else:
# For Epochs(TFR)Array, each epoch is an observation and they're already
# noisy/non-identical, so no duplication / noise-addition necessary.
insts = [cond_a, cond_b]
conds = ["a", "b"]

df = pd.DataFrame(dict(data=data, condition=conds))

# run new clustering API
df = pd.DataFrame(dict(data=insts, condition=conds))
kwargs = dict(
n_permutations=100, seed=42, tail=1, buffer_size=None, out_type="mask"
)

result_new_api = cluster_test(df, "data~condition", **kwargs)

# make sure channels are the last dimension for the old API
if Inst == EpochsArray:
inst1 = inst1.get_data().transpose(0, 2, 1)
inst2 = inst2.get_data().transpose(0, 2, 1)
elif Inst == EpochsTFRArray:
inst1 = inst1.data.transpose(0, 3, 2, 1)
inst2 = inst2.data.transpose(0, 3, 2, 1)
elif Inst == AverageTFRArray:
inst1 = inst1.data.transpose(2, 1, 0)
inst2 = inst2.data.transpose(2, 1, 0)
inst1_n = inst1_n.data.transpose(2, 1, 0)
inst2_n = inst2_n.data.transpose(2, 1, 0)
# combine the data of the two subjects
inst1 = np.concatenate([inst1[np.newaxis, :], inst1_n[np.newaxis, :]], axis=0)
inst2 = np.concatenate([inst2[np.newaxis, :], inst2_n[np.newaxis, :]], axis=0)
if is_epo:
axes = (0, 3, 2, 1) if is_tfr else (0, 2, 1)
X = [cond_a.get_data().transpose(*axes), cond_b.get_data().transpose(*axes)]
else:
inst1 = inst1.data.transpose(1, 0)
inst2 = inst2.data.transpose(1, 0)
inst1_n = inst1_n.data.transpose(1, 0)
inst2_n = inst2_n.data.transpose(1, 0)
# combine the data of the two subjects
inst1 = np.concatenate([inst1[np.newaxis, :], inst1_n[np.newaxis, :]], axis=0)
inst2 = np.concatenate([inst2[np.newaxis, :], inst2_n[np.newaxis, :]], axis=0)

F_obs, clusters, cluster_pvals, H0 = permutation_cluster_test(
[inst1, inst2], **kwargs
)
axes = (2, 1, 0) if is_tfr else (1, 0)
Xa = list()
Xb = list()
for inst, cond in zip(insts, conds):
container = Xa if cond == "a" else Xb
container.append(inst.get_data().transpose(*axes))
X = [np.stack(Xa), np.stack(Xb)]

F_obs, clusters, cluster_pvals, H0 = permutation_cluster_test(X, **kwargs)
assert_array_almost_equal(result_new_api.H0, H0)
assert_array_almost_equal(result_new_api.stat_obs, F_obs)
assert_array_almost_equal(result_new_api.cluster_p_values, cluster_pvals)
assert result_new_api.clusters == clusters
assert len(result_new_api.clusters) == len(clusters)
for clu1, clu2 in zip(result_new_api.clusters, clusters):
assert_array_equal(clu1, clu2)

0 comments on commit 81ce0d0

Please sign in to comment.