Skip to content

Commit

Permalink
Merge pull request #3 from dominikwelke/carinafo/channel_specific_epo…
Browse files Browse the repository at this point in the history
…ch_rejection

DW initial revisions
  • Loading branch information
CarinaFo authored Feb 26, 2024
2 parents 1856286 + c794710 commit caac9f4
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 5 deletions.
4 changes: 2 additions & 2 deletions mne/channels/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,8 +836,8 @@ def interpolate_bads(
method : dict | str | None
Method to use for each channel type.
- ``"meg"`` channels support ``"MNE"`` (default) and ``"nan"
- ``"eeg"`` channels support ``"spline"`` (default), ``"MNE"`` and ``"nan"
- ``"meg"`` channels support ``"MNE"`` (default) and ``"nan"``
- ``"eeg"`` channels support ``"spline"`` (default), ``"MNE"`` and ``"nan"``
- ``"fnirs"`` channels support ``"nearest"`` (default) and ``"nan"``
- ``"ecog"`` channels support ``"spline"`` (default) and ``"nan"``
- ``"seeg"`` channels support ``"spline"`` (default) and ``"nan"``
Expand Down
5 changes: 4 additions & 1 deletion mne/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,6 +707,9 @@ def set_bad_epochs_to_NaN(self, bad_epochs_indices: list = None, verbose=None):
List of arrays with indices of bad epochs per channel.
verbose : bool, str, int, or None
"""
if not self.preload:
raise ValueError("Data must be preloaded.")

if len(bad_epochs_indices) != self.get_data().shape[1]:
raise RuntimeError(
"The length of the list of bad epochs indices "
Expand Down Expand Up @@ -1209,7 +1212,7 @@ def _compute_aggregate(self, picks, mode="mean"):
n_events += 1

if n_events > 0:
data = np.nanmean(data)
data /= n_events
else:
data.fill(np.nan)

Expand Down
16 changes: 16 additions & 0 deletions mne/tests/test_epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5240,3 +5240,19 @@ def test_empty_error(method, epochs_empty):
pytest.importorskip("pandas")
with pytest.raises(RuntimeError, match="is empty."):
getattr(epochs_empty.copy(), method[0])(**method[1])


def test_set_bad_epochs_to_nan():
"""Test channel specific epoch rejection."""
# preload=False
raw, ev, _ = _get_data(preload=False)
ep = Epochs(raw, ev, tmin=0, tmax=0.1, baseline=(0, 0))
bads = [[]] * ep.info["nchan"]
bads[0] = [1]
with pytest.raises(ValueError, match="must be preloaded"):
ep.set_bad_epochs_to_NaN(bads)

# preload=True
ep.load_data()
ep.set_bad_epochs_to_NaN(bads)
_ = ep.average()
4 changes: 2 additions & 2 deletions mne/utils/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -919,12 +919,12 @@ def _check_combine(mode, valid=("mean", "median", "std"), axis=0):
if mode == "mean":

def fun(data):
return np.mean(data, axis=axis)
return np.nanmean(data, axis=axis)

elif mode == "std":

def fun(data):
return np.std(data, axis=axis)
return np.nanstd(data, axis=axis)

elif mode == "median" or mode == np.median:

Expand Down

0 comments on commit caac9f4

Please sign in to comment.