From c7947108b44c1580617325ac419c6c88d3333201 Mon Sep 17 00:00:00 2001 From: dominikwelke Date: Wed, 21 Feb 2024 20:57:56 +0000 Subject: [PATCH] DW initial revisions --- mne/channels/channels.py | 4 ++-- mne/epochs.py | 5 ++++- mne/tests/test_epochs.py | 16 ++++++++++++++++ mne/utils/check.py | 4 ++-- 4 files changed, 24 insertions(+), 5 deletions(-) diff --git a/mne/channels/channels.py b/mne/channels/channels.py index 9f8462aa956..bfeeaecbba7 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -836,8 +836,8 @@ def interpolate_bads( method : dict | str | None Method to use for each channel type. - - ``"meg"`` channels support ``"MNE"`` (default) and ``"nan" - - ``"eeg"`` channels support ``"spline"`` (default), ``"MNE"`` and ``"nan" + - ``"meg"`` channels support ``"MNE"`` (default) and ``"nan"`` + - ``"eeg"`` channels support ``"spline"`` (default), ``"MNE"`` and ``"nan"`` - ``"fnirs"`` channels support ``"nearest"`` (default) and ``"nan"`` None is an alias for:: diff --git a/mne/epochs.py b/mne/epochs.py index 545b941f008..f178b60b8f8 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -707,6 +707,9 @@ def set_bad_epochs_to_NaN(self, bad_epochs_indices: list = None, verbose=None): List of arrays with indices of bad epochs per channel. verbose : bool, str, int, or None """ + if not self.preload: + raise ValueError("Data must be preloaded.") + if len(bad_epochs_indices) != self.get_data().shape[1]: raise RuntimeError( "The length of the list of bad epochs indices " @@ -1209,7 +1212,7 @@ def _compute_aggregate(self, picks, mode="mean"): n_events += 1 if n_events > 0: - data = np.nanmean(data) + data /= n_events else: data.fill(np.nan) diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index 015974e89cc..86a7778e1f7 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -5204,3 +5204,19 @@ def test_empty_error(method, epochs_empty): pytest.importorskip("pandas") with pytest.raises(RuntimeError, match="is empty."): getattr(epochs_empty.copy(), method[0])(**method[1]) + + +def test_set_bad_epochs_to_nan(): + """Test channel specific epoch rejection.""" + # preload=False + raw, ev, _ = _get_data(preload=False) + ep = Epochs(raw, ev, tmin=0, tmax=0.1, baseline=(0, 0)) + bads = [[]] * ep.info["nchan"] + bads[0] = [1] + with pytest.raises(ValueError, match="must be preloaded"): + ep.set_bad_epochs_to_NaN(bads) + + # preload=True + ep.load_data() + ep.set_bad_epochs_to_NaN(bads) + _ = ep.average() diff --git a/mne/utils/check.py b/mne/utils/check.py index 13eca1e0ba0..70b88639132 100644 --- a/mne/utils/check.py +++ b/mne/utils/check.py @@ -901,12 +901,12 @@ def _check_combine(mode, valid=("mean", "median", "std"), axis=0): if mode == "mean": def fun(data): - return np.mean(data, axis=axis) + return np.nanmean(data, axis=axis) elif mode == "std": def fun(data): - return np.std(data, axis=axis) + return np.nanstd(data, axis=axis) elif mode == "median" or mode == np.median: