Skip to content

Commit

Permalink
[ENH, MRG] Allow epoch construction from annotations (#12311)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Richard Höchenberger <[email protected]>
  • Loading branch information
3 people authored Dec 31, 2023
1 parent 6790426 commit c73b8af
Show file tree
Hide file tree
Showing 13 changed files with 109 additions and 55 deletions.
1 change: 1 addition & 0 deletions doc/changes/devel/12311.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
:class:`mne.Epochs` can now be constructed using :class:`mne.Annotations` stored in the ``raw`` object, by specifying ``events=None``. By `Alex Rockhill`_.
1 change: 1 addition & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1310,6 +1310,7 @@ def reset_warnings(gallery_conf, fname):
for key in (
"invalid version and will not be supported", # pyxdf
"distutils Version classes are deprecated", # seaborn and neo
"is_categorical_dtype is deprecated", # seaborn
"`np.object` is a deprecated alias for the builtin `object`", # pyxdf
# nilearn, should be fixed in > 0.9.1
"In future, it will be an error for 'np.bool_' scalars to",
Expand Down
13 changes: 5 additions & 8 deletions examples/decoding/decoding_csp_eeg.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from sklearn.model_selection import ShuffleSplit, cross_val_score
from sklearn.pipeline import Pipeline

from mne import Epochs, events_from_annotations, pick_types
from mne import Epochs, pick_types
from mne.channels import make_standard_montage
from mne.datasets import eegbci
from mne.decoding import CSP
Expand All @@ -41,7 +41,6 @@
# avoid classification of evoked responses by using epochs that start 1s after
# cue onset.
tmin, tmax = -1.0, 4.0
event_id = dict(hands=2, feet=3)
subject = 1
runs = [6, 10, 14] # motor imagery: hands vs feet

Expand All @@ -50,22 +49,20 @@
eegbci.standardize(raw) # set channel names
montage = make_standard_montage("standard_1005")
raw.set_montage(montage)
raw.annotations.rename(dict(T1="hands", T2="feet"))

# Apply band-pass filter
raw.filter(7.0, 30.0, fir_design="firwin", skip_by_annotation="edge")

events, _ = events_from_annotations(raw, event_id=dict(T1=2, T2=3))

picks = pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False, exclude="bads")

# Read epochs (train will be done only between 1 and 2s)
# Testing will be done with a running classifier
epochs = Epochs(
raw,
events,
event_id,
tmin,
tmax,
event_id=["hands", "feet"],
tmin=tmin,
tmax=tmax,
proj=True,
picks=picks,
baseline=None,
Expand Down
19 changes: 8 additions & 11 deletions examples/decoding/decoding_csp_timefreq.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,22 @@
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import LabelEncoder

from mne import Epochs, create_info, events_from_annotations
from mne import Epochs, create_info
from mne.datasets import eegbci
from mne.decoding import CSP
from mne.io import concatenate_raws, read_raw_edf
from mne.time_frequency import AverageTFR

# %%
# Set parameters and read data
event_id = dict(hands=2, feet=3) # motor imagery: hands vs feet
subject = 1
runs = [6, 10, 14]
raw_fnames = eegbci.load_data(subject, runs)
raw = concatenate_raws([read_raw_edf(f) for f in raw_fnames])
raw.annotations.rename(dict(T1="hands", T2="feet"))

# Extract information from the raw file
sfreq = raw.info["sfreq"]
events, _ = events_from_annotations(raw, event_id=dict(T1=2, T2=3))
raw.pick(picks="eeg", exclude="bads")
raw.load_data()

Expand Down Expand Up @@ -95,10 +94,9 @@
# Extract epochs from filtered data, padded by window size
epochs = Epochs(
raw_filter,
events,
event_id,
tmin - w_size,
tmax + w_size,
event_id=["hands", "feet"],
tmin=tmin - w_size,
tmax=tmax + w_size,
proj=False,
baseline=None,
preload=True,
Expand Down Expand Up @@ -148,10 +146,9 @@
# Extract epochs from filtered data, padded by window size
epochs = Epochs(
raw_filter,
events,
event_id,
tmin - w_size,
tmax + w_size,
event_id=["hands", "feet"],
tmin=tmin - w_size,
tmax=tmax + w_size,
proj=False,
baseline=None,
preload=True,
Expand Down
11 changes: 5 additions & 6 deletions examples/time_frequency/time_frequency_erds.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@
raw = concatenate_raws([read_raw_edf(f, preload=True) for f in fnames])

raw.rename_channels(lambda x: x.strip(".")) # remove dots from channel names

events, _ = mne.events_from_annotations(raw, event_id=dict(T1=2, T2=3))
# rename descriptions to be more easily interpretable
raw.annotations.rename(dict(T1="hands", T2="feet"))

# %%
# Now we can create 5-second epochs around events of interest.
Expand All @@ -64,10 +64,9 @@

epochs = mne.Epochs(
raw,
events,
event_ids,
tmin - 0.5,
tmax + 0.5,
event_id=["hands", "feet"],
tmin=tmin - 0.5,
tmax=tmax + 0.5,
picks=("C3", "Cz", "C4"),
baseline=None,
preload=True,
Expand Down
6 changes: 1 addition & 5 deletions examples/visualization/eyetracking_plot_heatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,8 @@

mne.preprocessing.eyetracking.interpolate_blinks(raw, interpolate_gaze=True)
raw.annotations.rename({"dvns": "natural"}) # more intuitive
event_ids = {"natural": 1}
events, event_dict = mne.events_from_annotations(raw, event_id=event_ids)

epochs = mne.Epochs(
raw, events=events, event_id=event_dict, tmin=0, tmax=20, baseline=None
)
epochs = mne.Epochs(raw, event_id=["natural"], tmin=0, tmax=20, baseline=None)


# %%
Expand Down
60 changes: 54 additions & 6 deletions mne/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
EpochAnnotationsMixin,
_read_annotations_fif,
_write_annotations,
events_from_annotations,
)
from .baseline import _check_baseline, _log_rescale, rescale
from .bem import _check_origin
Expand Down Expand Up @@ -487,10 +488,7 @@ def __init__(
if events is not None: # RtEpochs can have events=None
for key, val in self.event_id.items():
if val not in events[:, 2]:
msg = "No matching events found for %s " "(event id %i)" % (
key,
val,
)
msg = f"No matching events found for {key} (event id {val})"
_on_missing(on_missing, msg)

# ensure metadata matches original events size
Expand Down Expand Up @@ -3104,14 +3102,57 @@ def _ensure_list(x):
return metadata, events, event_id


def _events_from_annotations(raw, events, event_id, annotations, on_missing):
"""Generate events and event_ids from annotations."""
events, event_id_tmp = events_from_annotations(raw)
if events.size == 0:
raise RuntimeError(
"No usable annotations found in the raw object. "
"Either `events` must be provided or the raw "
"object must have annotations to construct epochs"
)
if any(raw.annotations.duration > 0):
logger.info(
"Ignoring annotation durations and creating fixed-duration epochs "
"around annotation onsets."
)
if event_id is None:
event_id = event_id_tmp
# if event_id is the names of events, map to events integers
if isinstance(event_id, str):
event_id = [event_id]
if isinstance(event_id, (list, tuple, set)):
if not set(event_id).issubset(set(event_id_tmp)):
msg = (
"No matching annotations found for event_id(s) "
f"{set(event_id) - set(event_id_tmp)}"
)
_on_missing(on_missing, msg)
# remove extras if on_missing not error
event_id = set(event_id) & set(event_id_tmp)
event_id = {my_id: event_id_tmp[my_id] for my_id in event_id}
# remove any non-selected annotations
annotations.delete(~np.isin(raw.annotations.description, list(event_id)))
return events, event_id, annotations


@fill_doc
class Epochs(BaseEpochs):
"""Epochs extracted from a Raw instance.
Parameters
----------
%(raw_epochs)s
.. note::
If ``raw`` contains annotations, ``Epochs`` can be constructed around
``raw.annotations.onset``, but note that the durations of the annotations
are ignored in this case.
%(events_epochs)s
.. versionchanged:: 1.7
Allow ``events=None`` to use ``raw.annotations.onset`` as the source of
epoch times.
%(event_id)s
%(epochs_tmin_tmax)s
%(baseline_epochs)s
Expand Down Expand Up @@ -3212,7 +3253,7 @@ class Epochs(BaseEpochs):
def __init__(
self,
raw,
events,
events=None,
event_id=None,
tmin=-0.2,
tmax=0.5,
Expand Down Expand Up @@ -3240,6 +3281,7 @@ def __init__(
"instance of mne.io.BaseRaw"
)
info = deepcopy(raw.info)
annotations = raw.annotations.copy()

# proj is on when applied in Raw
proj = proj or raw.proj
Expand All @@ -3249,6 +3291,12 @@ def __init__(
# keep track of original sfreq (needed for annotations)
raw_sfreq = raw.info["sfreq"]

# get events from annotations if no events given
if events is None:
events, event_id, annotations = _events_from_annotations(
raw, events, event_id, annotations, on_missing
)

# call BaseEpochs constructor
super(Epochs, self).__init__(
info,
Expand All @@ -3273,7 +3321,7 @@ def __init__(
event_repeated=event_repeated,
verbose=verbose,
raw_sfreq=raw_sfreq,
annotations=raw.annotations,
annotations=annotations,
)

@verbose
Expand Down
20 changes: 20 additions & 0 deletions mne/tests/test_epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -992,6 +992,26 @@ def test_filter(tmp_path):
assert_allclose(epochs.get_data(), data_filt, atol=1e-17)


def test_epochs_from_annotations():
"""Test epoch instantiation using annotations."""
raw, events = _get_data()[:2]
with pytest.raises(
RuntimeError, match="No usable annotations found in the raw object"
):
Epochs(raw)
raw.set_annotations(
mne.annotations_from_events(
events, raw.info["sfreq"], first_samp=raw.first_samp
)
)
# test on_missing
with pytest.raises(ValueError, match="No matching annotations"):
Epochs(raw, event_id="foo")
# test on_missing warn
with pytest.warns(match="No matching annotations"):
Epochs(raw, event_id=["1", "foo"], on_missing="warn")


def test_epochs_hash():
"""Test epoch hashing."""
raw, events = _get_data()[:2]
Expand Down
8 changes: 5 additions & 3 deletions mne/utils/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1107,12 +1107,14 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
"""

docdict["event_id"] = """
event_id : int | list of int | dict | None
event_id : int | list of int | dict | str | list of str | None
The id of the :term:`events` to consider. If dict, the keys can later be
used to access associated :term:`events`. Example:
dict(auditory=1, visual=3). If int, a dict will be created with the id as
string. If a list, all :term:`events` with the IDs specified in the list
are used. If None, all :term:`events` will be used and a dict is created
string. If a list of int, all :term:`events` with the IDs specified in the list
are used. If a str or list of str, ``events`` must be ``None`` to use annotations
and then the IDs must be the name(s) of the annotations to use.
If None, all :term:`events` will be used and a dict is created
with string integer names corresponding to the event id integers."""

docdict["event_id_ecg"] = """
Expand Down
2 changes: 1 addition & 1 deletion tools/setup_xvfb.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,5 @@ done

# This also includes the libraries necessary for PyQt5/PyQt6
sudo apt update
sudo apt install -yqq xvfb libxkbcommon-x11-0 libxcb-icccm4 libxcb-image0 libxcb-keysyms1 libxcb-randr0 libxcb-render-util0 libxcb-xinerama0 libxcb-xfixes0 libopengl0 libegl1 libosmesa6 mesa-utils libxcb-shape0 libxcb-cursor0
sudo apt install -yqq xvfb libxkbcommon-x11-0 libxcb-icccm4 libxcb-image0 libxcb-keysyms1 libxcb-randr0 libxcb-render-util0 libxcb-xinerama0 libxcb-xfixes0 libopengl0 libegl1 libosmesa6 mesa-utils libxcb-shape0 libxcb-cursor0 libxml2
/sbin/start-stop-daemon --start --quiet --pidfile /tmp/custom_xvfb_99.pid --make-pidfile --background --exec /usr/bin/Xvfb -- :99 -screen 0 1400x900x24 -ac +extension GLX +render -noreset
3 changes: 1 addition & 2 deletions tutorials/clinical/20_seeg.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,7 @@

raw = mne.io.read_raw(misc_path / "seeg" / "sample_seeg_ieeg.fif")

events, event_id = mne.events_from_annotations(raw)
epochs = mne.Epochs(raw, events, event_id, detrend=1, baseline=None)
epochs = mne.Epochs(raw, detrend=1, baseline=None)
epochs = epochs["Response"][0] # just process one epoch of data for speed

# %%
Expand Down
6 changes: 1 addition & 5 deletions tutorials/clinical/30_ecog.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,11 @@
# at the posterior commissure)
raw.set_montage(montage)

# Find the annotated events
events, event_id = mne.events_from_annotations(raw)

# Make a 25 second epoch that spans before and after the seizure onset
epoch_length = 25 # seconds
epochs = mne.Epochs(
raw,
events,
event_id=event_id["onset"],
event_id="onset",
tmin=13,
tmax=13 + epoch_length,
baseline=None,
Expand Down
14 changes: 6 additions & 8 deletions tutorials/time-freq/50_ssvep.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,12 @@
raw.filter(l_freq=0.1, h_freq=None, fir_design="firwin", verbose=False)

# Construct epochs
event_id = {"12hz": 255, "15hz": 155}
events, _ = mne.events_from_annotations(raw, verbose=False)
raw.annotations.rename({"Stimulus/S255": "12hz", "Stimulus/S155": "15hz"})
tmin, tmax = -1.0, 20.0 # in s
baseline = None
epochs = mne.Epochs(
raw,
events=events,
event_id=[event_id["12hz"], event_id["15hz"]],
event_id=["12hz", "15hz"],
tmin=tmin,
tmax=tmax,
baseline=baseline,
Expand Down Expand Up @@ -356,8 +354,8 @@ def snr_spectrum(psd, noise_n_neighbor_freqs=1, noise_skip_neighbor_freqs=1):
# Get indices for the different trial types
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

i_trial_12hz = np.where(epochs.events[:, 2] == event_id["12hz"])[0]
i_trial_15hz = np.where(epochs.events[:, 2] == event_id["15hz"])[0]
i_trial_12hz = np.where(epochs.annotations.description == "12hz")[0]
i_trial_15hz = np.where(epochs.annotations.description == "15hz")[0]

# %%
# Get indices of EEG channels forming the ROI
Expand Down Expand Up @@ -604,7 +602,7 @@ def snr_spectrum(psd, noise_n_neighbor_freqs=1, noise_skip_neighbor_freqs=1):
window_snrs = [[]] * len(window_lengths)
for i_win, win in enumerate(window_lengths):
# compute spectrogram
this_spectrum = epochs[str(event_id["12hz"])].compute_psd(
this_spectrum = epochs["12hz"].compute_psd(
"welch",
n_fft=int(sfreq * win),
n_overlap=0,
Expand Down Expand Up @@ -688,7 +686,7 @@ def snr_spectrum(psd, noise_n_neighbor_freqs=1, noise_skip_neighbor_freqs=1):

for i_win, win in enumerate(window_starts):
# compute spectrogram
this_spectrum = epochs[str(event_id["12hz"])].compute_psd(
this_spectrum = epochs["12hz"].compute_psd(
"welch",
n_fft=int(sfreq * window_length) - 1,
n_overlap=0,
Expand Down

0 comments on commit c73b8af

Please sign in to comment.