Skip to content

Commit

Permalink
DW initial revisions
Browse files Browse the repository at this point in the history
  • Loading branch information
dominikwelke committed Feb 21, 2024
1 parent b011e6a commit c794710
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 5 deletions.
4 changes: 2 additions & 2 deletions mne/channels/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,8 +836,8 @@ def interpolate_bads(
method : dict | str | None
Method to use for each channel type.
- ``"meg"`` channels support ``"MNE"`` (default) and ``"nan"
- ``"eeg"`` channels support ``"spline"`` (default), ``"MNE"`` and ``"nan"
- ``"meg"`` channels support ``"MNE"`` (default) and ``"nan"``
- ``"eeg"`` channels support ``"spline"`` (default), ``"MNE"`` and ``"nan"``
- ``"fnirs"`` channels support ``"nearest"`` (default) and ``"nan"``
None is an alias for::
Expand Down
5 changes: 4 additions & 1 deletion mne/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,6 +707,9 @@ def set_bad_epochs_to_NaN(self, bad_epochs_indices: list = None, verbose=None):
List of arrays with indices of bad epochs per channel.
verbose : bool, str, int, or None
"""
if not self.preload:
raise ValueError("Data must be preloaded.")

if len(bad_epochs_indices) != self.get_data().shape[1]:
raise RuntimeError(
"The length of the list of bad epochs indices "
Expand Down Expand Up @@ -1209,7 +1212,7 @@ def _compute_aggregate(self, picks, mode="mean"):
n_events += 1

if n_events > 0:
data = np.nanmean(data)
data /= n_events
else:
data.fill(np.nan)

Expand Down
16 changes: 16 additions & 0 deletions mne/tests/test_epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5204,3 +5204,19 @@ def test_empty_error(method, epochs_empty):
pytest.importorskip("pandas")
with pytest.raises(RuntimeError, match="is empty."):
getattr(epochs_empty.copy(), method[0])(**method[1])


def test_set_bad_epochs_to_nan():
"""Test channel specific epoch rejection."""
# preload=False
raw, ev, _ = _get_data(preload=False)
ep = Epochs(raw, ev, tmin=0, tmax=0.1, baseline=(0, 0))
bads = [[]] * ep.info["nchan"]
bads[0] = [1]
with pytest.raises(ValueError, match="must be preloaded"):
ep.set_bad_epochs_to_NaN(bads)

# preload=True
ep.load_data()
ep.set_bad_epochs_to_NaN(bads)
_ = ep.average()
4 changes: 2 additions & 2 deletions mne/utils/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -901,12 +901,12 @@ def _check_combine(mode, valid=("mean", "median", "std"), axis=0):
if mode == "mean":

def fun(data):
return np.mean(data, axis=axis)
return np.nanmean(data, axis=axis)

elif mode == "std":

def fun(data):
return np.std(data, axis=axis)
return np.nanstd(data, axis=axis)

elif mode == "median" or mode == np.median:

Expand Down

0 comments on commit c794710

Please sign in to comment.