Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DW initial revisions #3

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mne/channels/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,8 +836,8 @@ def interpolate_bads(
method : dict | str | None
Method to use for each channel type.

- ``"meg"`` channels support ``"MNE"`` (default) and ``"nan"
- ``"eeg"`` channels support ``"spline"`` (default), ``"MNE"`` and ``"nan"
- ``"meg"`` channels support ``"MNE"`` (default) and ``"nan"``
- ``"eeg"`` channels support ``"spline"`` (default), ``"MNE"`` and ``"nan"``
- ``"fnirs"`` channels support ``"nearest"`` (default) and ``"nan"``

None is an alias for::
Expand Down
5 changes: 4 additions & 1 deletion mne/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,6 +707,9 @@ def set_bad_epochs_to_NaN(self, bad_epochs_indices: list = None, verbose=None):
List of arrays with indices of bad epochs per channel.
verbose : bool, str, int, or None
"""
if not self.preload:
raise ValueError("Data must be preloaded.")

if len(bad_epochs_indices) != self.get_data().shape[1]:
raise RuntimeError(
"The length of the list of bad epochs indices "
Expand Down Expand Up @@ -1209,7 +1212,7 @@ def _compute_aggregate(self, picks, mode="mean"):
n_events += 1

if n_events > 0:
data = np.nanmean(data)
data /= n_events
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am still a bit confused about this part of the function, to be honest, maybe a comment would be good at that point

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this whole stuff only happens in the case of not-preloaded data, and thats why i guess it has to be somehow hacky and hard to follow ;)
in this case there is no numpy array of data in memory that we can simply access and change (probably only a view or so..)
so they create this mock data object, fill it up by some loop through the epochs object and then average that.

but that's also why it's not really relevant for us - i doubt that we could even set not-preloaded data to nan easily without larger changes (i assume this info would have to be stored in some meta data and then applied when finally loading at some later point?), so it doenst make sense to implement an averaging function for this case.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not super sure about the second part of my reply, maybe it is easier than i think. still, if we ask data to be preloaded, it doesnt apply

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okey, thank you for the explanation, now I finally understand what is going on.

else:
data.fill(np.nan)

Expand Down
16 changes: 16 additions & 0 deletions mne/tests/test_epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5204,3 +5204,19 @@ def test_empty_error(method, epochs_empty):
pytest.importorskip("pandas")
with pytest.raises(RuntimeError, match="is empty."):
getattr(epochs_empty.copy(), method[0])(**method[1])


def test_set_bad_epochs_to_nan():
"""Test channel specific epoch rejection."""
# preload=False
raw, ev, _ = _get_data(preload=False)
ep = Epochs(raw, ev, tmin=0, tmax=0.1, baseline=(0, 0))
bads = [[]] * ep.info["nchan"]
bads[0] = [1]
with pytest.raises(ValueError, match="must be preloaded"):
ep.set_bad_epochs_to_NaN(bads)

# preload=True
ep.load_data()
ep.set_bad_epochs_to_NaN(bads)
_ = ep.average()
4 changes: 2 additions & 2 deletions mne/utils/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -901,12 +901,12 @@ def _check_combine(mode, valid=("mean", "median", "std"), axis=0):
if mode == "mean":

def fun(data):
return np.mean(data, axis=axis)
return np.nanmean(data, axis=axis)

elif mode == "std":

def fun(data):
return np.std(data, axis=axis)
return np.nanstd(data, axis=axis)

elif mode == "median" or mode == np.median:

Expand Down
Loading