diff --git a/mne/stats/cluster_level.py b/mne/stats/cluster_level.py index 804d035ff51..141f7c299d4 100644 --- a/mne/stats/cluster_level.py +++ b/mne/stats/cluster_level.py @@ -24,13 +24,15 @@ from scipy.stats import f as fstat from scipy.stats import t as tstat -from .. import BaseEpochs, Evoked, EvokedArray +from ..epochs import BaseEpochs, EvokedArray +from ..evoked import Evoked from ..fixes import has_numba, jit from ..parallel import parallel_func from ..source_estimate import MixedSourceEstimate, SourceEstimate, VolSourceEstimate from ..source_space import SourceSpaces from ..time_frequency import BaseTFR from ..utils import ( + GetEpochsMixin, ProgressBar, _check_option, _pl, @@ -1784,7 +1786,11 @@ def _validate_cluster_df(df: pd.DataFrame, dv_name: str, iv_name: str): f"{prologue} consistent shape, but {len(all_shapes)} different " f"shapes were found: {'; '.join(all_shapes)}." ) - return all_types.pop() # return the type of the data column entries + obj_type = all_types.pop() + is_epo = GetEpochsMixin in obj_type.__mro__ + is_tfr = BaseTFR in obj_type.__mro__ + is_arr = np.ndarray in obj_type.__mro__ + return is_epo, is_tfr, is_arr @verbose @@ -1868,7 +1874,7 @@ def cluster_test( iv_name = str(np.array(formula.rhs.root).item()) # validate the input dataframe and return the type of the data column entries - _dtype = _validate_cluster_df(df, dv_name, iv_name) + is_epo, is_tfr, is_arr = _validate_cluster_df(df, dv_name, iv_name) # for within_subject designs, check if each subject has 2 observations _validate_type(within_id, (str, None), "within_id") @@ -1880,23 +1886,18 @@ def cluster_test( raise ValueError("for paired t-test, each subject must have 2 observations") # extract the data from the dataframe - def _extract_data_array(series): + outer_func = np.concatenate if is_epo else np.array + axes = (-3, -1) if is_tfr else (-2, -1) + + def func_arr(series): return np.concatenate(series.values) - def _extract_data_mne(series): # 2D data - return np.array( - series.map(lambda inst: inst.get_data().swapaxes(-2, -1)).to_list() + def func_mne(series): + return outer_func( + series.map(lambda inst: inst.get_data().swapaxes(*axes)).to_list() ) - def _extract_data_tfr(series): - return series.map(lambda inst: inst.get_data().swapaxes(-3, -1)).to_list() - - if _dtype is np.ndarray: - func = _extract_data_array - elif _dtype is BaseTFR: - func = _extract_data_tfr - else: - func = _extract_data_mne + func = func_arr if is_arr else func_mne # convert to a list-like X for clustering X = df.groupby(iv_name).agg({dv_name: func})[dv_name].to_list() @@ -1993,7 +1994,7 @@ def plot_cluster_time_sensor( linestyles: list | dict | None = None, cmap_evokeds: None | str | tuple = None, cmap_topo: None | str | tuple = None, - ci: float | bool | callable() | None = None, + ci: float | bool | callable | None = None, ): """ Plot the cluster with the lowest p-value. diff --git a/mne/stats/tests/test_cluster_level.py b/mne/stats/tests/test_cluster_level.py index e3a701d3691..b4d676abe91 100644 --- a/mne/stats/tests/test_cluster_level.py +++ b/mne/stats/tests/test_cluster_level.py @@ -39,8 +39,8 @@ summarize_clusters_stc, ttest_1samp_no_p, ) -from mne.time_frequency import AverageTFRArray, EpochsTFRArray -from mne.utils import _record_warnings, catch_logging +from mne.time_frequency import AverageTFRArray, BaseTFR, EpochsTFRArray +from mne.utils import GetEpochsMixin, _record_warnings, catch_logging n_space = 50 @@ -911,121 +911,77 @@ def test_new_cluster_api(Inst): """Test handling different MNE objects in the cluster API.""" pd = pytest.importorskip("pandas") - n_subs, n_epo, n_chan, n_freq, n_times = 2, 2, 3, 4, 5 - info = create_info(ch_names=n_chan, sfreq=1000, ch_types="eeg") - # Introduce a significant difference in a specific region, time, and frequency - region_start = 1 - region_end = 2 - time_start = 2 - time_end = 4 - freq_start = 2 - freq_end = 4 - - if Inst == EpochsArray: - # Create random data for EpochsArray - inst1 = Inst(np.random.randn(n_epo, n_chan, n_times), info=info) - # Adding a constant to create a difference - data_copy = inst1.get_data().copy() # no data attribute for EpochsArray - data_copy[:, region_start:region_end, time_start:time_end] += ( - 2 # Modify the copy - ) - inst2 = Inst( - data=data_copy, info=info - ) # Use the modified copy as a new instance - - elif Inst == EvokedArray: - # Create random data for EvokedArray - inst1 = Inst(np.random.randn(n_chan, n_times), info=info) - data_copy = inst1.data.copy() - data_copy[region_start:region_end, time_start:time_end] += 2 - inst2 = Inst(data=data_copy, info=info) - - elif Inst == EpochsTFRArray: - # Create random data for EpochsTFRArray - data_tfr1 = np.random.randn(n_epo, n_chan, n_freq, n_times) - data_tfr2 = np.random.randn(n_epo, n_chan, n_freq, n_times) - inst1 = Inst( - data=data_tfr1, info=info, times=np.arange(n_times), freqs=np.arange(n_freq) - ) - inst2 = Inst( - data=data_tfr2, info=info, times=np.arange(n_times), freqs=np.arange(n_freq) - ) - data_tfr2 = inst2.data.copy() - data_tfr2[ - :, region_start:region_end, freq_start:freq_end, time_start:time_end - ] += 2 - inst2 = Inst( - data=data_tfr2, info=info, times=np.arange(n_times), freqs=np.arange(n_freq) - ) + rng = np.random.default_rng(seed=8675309) + is_epo = GetEpochsMixin in Inst.__mro__ + is_tfr = BaseTFR in Inst.__mro__ - elif Inst == AverageTFRArray: - # Create random data for AverageTFRArray - data_tfr1 = np.random.randn(n_chan, n_freq, n_times) - data_tfr2 = np.random.randn(n_chan, n_freq, n_times) - inst1 = Inst( - data=data_tfr1, info=info, times=np.arange(n_times), freqs=np.arange(n_freq) - ) - inst2 = Inst( - data=data_tfr2, info=info, times=np.arange(n_times), freqs=np.arange(n_freq) - ) - data_tfr2 = inst2.data.copy() - data_tfr2[ - region_start:region_end, freq_start:freq_end, time_start:time_end - ] += 2 - inst2 = Inst( - data=data_tfr2, info=info, times=np.arange(n_times), freqs=np.arange(n_freq) - ) + n_epo, n_chan, n_freq, n_times = 6, 3, 4, 5 - if Inst == EvokedArray or Inst == AverageTFRArray: - # Generate random noise - noise = np.random.normal(loc=0, scale=0.1, size=inst1.data.shape) - # add noise to the data of the second subject - inst1_n = inst1.copy() - inst1_n.data = inst1.data + noise - inst2_n = inst2.copy() - inst2_n.data = inst2.data + noise - data = [inst1, inst2, inst1_n, inst2_n] - conds = ["a", "b"] * n_subs + # prepare the dimensions of the simulated data, then simulate + size = (n_chan,) + if is_epo: + size = (n_epo, *size) + if is_tfr: + size = (*size, n_freq) + size = (*size, n_times) + data = rng.normal(size=size) + + # construct the instance + info = create_info(ch_names=n_chan, sfreq=1000, ch_types="eeg") + kw = dict(times=np.arange(n_times), freqs=np.arange(n_freq)) if is_tfr else dict() + cond_a = Inst(data=data, info=info, **kw) + cond_b = cond_a.copy() + # introduce a significant difference in a specific region, time, and frequency + ch_start, ch_end = 0, 2 # 2 channels + t_start, t_end = 2, 4 # 2 times + f_start, f_end = 2, 4 # 2 freqs + if is_tfr: + cond_b._data[..., ch_start:ch_end, f_start:f_end, t_start:t_end] += 2 else: - data = [inst1, inst2] + cond_b._data[..., ch_start:ch_end, t_start:t_end] += 2 + # for Evokeds/AverageTFRs, we create fake "subjects" as our observations within each + # condition. We add a bit of noise while we do so. + if not is_epo: + insts = list() + for cond in cond_a, cond_b: + for _n in range(n_epo): + if not _n: + insts.append(cond) + continue + _cond = cond.copy() + _cond.data += rng.normal(scale=0.1, size=_cond.data.shape) + insts.append(_cond) + conds = np.repeat(["a", "b"], n_epo).tolist() + else: + # For Epochs(TFR)Array, each epoch is an observation and they're already + # noisy/non-identical, so no duplication / noise-addition necessary. + insts = [cond_a, cond_b] conds = ["a", "b"] - df = pd.DataFrame(dict(data=data, condition=conds)) - + # run new clustering API + df = pd.DataFrame(dict(data=insts, condition=conds)) kwargs = dict( n_permutations=100, seed=42, tail=1, buffer_size=None, out_type="mask" ) - result_new_api = cluster_test(df, "data~condition", **kwargs) # make sure channels are last dimension for old API - if Inst == EpochsArray: - inst1 = inst1.get_data().transpose(0, 2, 1) - inst2 = inst2.get_data().transpose(0, 2, 1) - elif Inst == EpochsTFRArray: - inst1 = inst1.data.transpose(0, 3, 2, 1) - inst2 = inst2.data.transpose(0, 3, 2, 1) - elif Inst == AverageTFRArray: - inst1 = inst1.data.transpose(2, 1, 0) - inst2 = inst2.data.transpose(2, 1, 0) - inst1_n = inst1_n.data.transpose(2, 1, 0) - inst2_n = inst2_n.data.transpose(2, 1, 0) - # combine the data of the two subjects - inst1 = np.concatenate([inst1[np.newaxis, :], inst1_n[np.newaxis, :]], axis=0) - inst2 = np.concatenate([inst2[np.newaxis, :], inst2_n[np.newaxis, :]], axis=0) + if is_epo: + axes = (0, 3, 2, 1) if is_tfr else (0, 2, 1) + X = [cond_a.get_data().transpose(*axes), cond_b.get_data().transpose(*axes)] else: - inst1 = inst1.data.transpose(1, 0) - inst2 = inst2.data.transpose(1, 0) - inst1_n = inst1_n.data.transpose(1, 0) - inst2_n = inst2_n.data.transpose(1, 0) - # combine the data of the two subjects - inst1 = np.concatenate([inst1[np.newaxis, :], inst1_n[np.newaxis, :]], axis=0) - inst2 = np.concatenate([inst2[np.newaxis, :], inst2_n[np.newaxis, :]], axis=0) - - F_obs, clusters, cluster_pvals, H0 = permutation_cluster_test( - [inst1, inst2], **kwargs - ) + axes = (2, 1, 0) if is_tfr else (1, 0) + Xa = list() + Xb = list() + for inst, cond in zip(insts, conds): + container = Xa if cond == "a" else Xb + container.append(inst.get_data().transpose(*axes)) + X = [np.stack(Xa), np.stack(Xb)] + + F_obs, clusters, cluster_pvals, H0 = permutation_cluster_test(X, **kwargs) assert_array_almost_equal(result_new_api.H0, H0) assert_array_almost_equal(result_new_api.stat_obs, F_obs) assert_array_almost_equal(result_new_api.cluster_p_values, cluster_pvals) - assert result_new_api.clusters == clusters + assert len(result_new_api.clusters) == len(clusters) + for clu1, clu2 in zip(result_new_api.clusters, clusters): + assert_array_equal(clu1, clu2)