Skip to content

Commit

Permalink
Merge branch 'channel_specific_epoch_rejection' of github.com:CarinaF…
Browse files Browse the repository at this point in the history
…o/mne-python into channel_specific_epoch_rejection
  • Loading branch information
CarinaFo committed Feb 26, 2024
2 parents 7feea14 + f2ea3c4 commit 1856286
Show file tree
Hide file tree
Showing 33 changed files with 435 additions and 158 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
repos:
# Ruff mne
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.2.1
rev: v0.2.2
hooks:
- id: ruff
name: ruff lint mne
Expand Down Expand Up @@ -32,7 +32,7 @@ repos:

# yamllint
- repo: https://github.com/adrienverge/yamllint.git
rev: v1.34.0
rev: v1.35.1
hooks:
- id: yamllint
args: [--strict, -c, .yamllint.yml]
Expand Down
4 changes: 2 additions & 2 deletions azure-pipelines.yml
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ stages:
- bash: |
set -e
python -m pip install --progress-bar off --upgrade pip
python -m pip install --progress-bar off "mne-qt-browser[opengl] @ git+https://github.com/mne-tools/mne-qt-browser.git@main" pyvista scikit-learn pytest-error-for-skips python-picard "PyQt6!=6.5.1,!=6.6.1" "PyQt6-Qt6!=6.6.1" qtpy nibabel sphinx-gallery
python -m pip install --progress-bar off "mne-qt-browser[opengl] @ git+https://github.com/mne-tools/mne-qt-browser.git@main" pyvista scikit-learn pytest-error-for-skips python-picard "PyQt6!=6.5.1,!=6.6.1,!=6.6.2" "PyQt6-Qt6!=6.6.1,!=6.6.2" qtpy nibabel sphinx-gallery
python -m pip uninstall -yq mne
python -m pip install --progress-bar off --upgrade -e .[test]
displayName: 'Install dependencies with pip'
Expand Down Expand Up @@ -183,7 +183,7 @@ stages:
displayName: 'Get test data'
- bash: |
set -e
python -m pip install "PyQt6!=6.6.1" "PyQt6-Qt6!=6.6.1"
python -m pip install "PyQt6!=6.6.1,!=6.6.2" "PyQt6-Qt6!=6.6.1,!=6.6.2"
LD_DEBUG=libs python -c "from PyQt6.QtWidgets import QApplication, QWidget; app = QApplication([]); import matplotlib; matplotlib.use('QtAgg'); import matplotlib.pyplot as plt; plt.figure()"
- bash: |
mne sys_info -pd
Expand Down
1 change: 1 addition & 0 deletions doc/changes/devel/12443.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add option to pass ``image_kwargs`` to :class:`mne.Report.add_epochs` to allow adjusting e.g. ``vmin`` and ``vmax`` of the epochs image in the report, by `Sophie Herbst`_.
1 change: 1 addition & 0 deletions doc/changes/devel/12444.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix validation of ``ch_type`` in :func:`mne.preprocessing.annotate_muscle_zscore`, by `Mathieu Scheltienne`_.
1 change: 1 addition & 0 deletions doc/changes/devel/12445.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add support for multiple raw instances in :func:`mne.preprocessing.compute_average_dev_head_t` by `Eric Larson`_.
1 change: 1 addition & 0 deletions doc/changes/devel/12446.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Support partial pathlength factors for each wavelength in :func:`mne.preprocessing.nirs.beer_lambert_law`, by :newcontrib:`Richard Scholz`.
1 change: 1 addition & 0 deletions doc/changes/devel/12451.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix errant redundant use of ``BIDSPath.split`` when writing split raw and epochs data, by `Eric Larson`_.
1 change: 1 addition & 0 deletions doc/changes/devel/12451.dependency.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
``pytest-harvest`` is no longer used as a test dependency, by `Eric Larson`_.
1 change: 1 addition & 0 deletions doc/changes/devel/12454.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Completing PR 12453. Add option to pass ``image_kwargs`` per channel type to :class:`mne.Report.add_epochs`.
1 change: 1 addition & 0 deletions doc/changes/devel/12456.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Disable config parser interpolation when reading BrainVision files, which allows using the percent sign as a regular character in channel units, by `Clemens Brunner`_.
2 changes: 2 additions & 0 deletions doc/changes/names.inc
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,8 @@

.. _Richard Koehler: https://github.com/richardkoehler

.. _Richard Scholz: https://github.com/scholzri

.. _Riessarius Stargardsky: https://github.com/Riessarius

.. _Roan LaPlante: https://github.com/aestrivex
Expand Down
2 changes: 2 additions & 0 deletions mne/channels/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,6 +839,8 @@ def interpolate_bads(
- ``"meg"`` channels support ``"MNE"`` (default) and ``"nan"
- ``"eeg"`` channels support ``"spline"`` (default), ``"MNE"`` and ``"nan"
- ``"fnirs"`` channels support ``"nearest"`` (default) and ``"nan"``
- ``"ecog"`` channels support ``"spline"`` (default) and ``"nan"``
- ``"seeg"`` channels support ``"spline"`` (default) and ``"nan"``
None is an alias for::
Expand Down
31 changes: 11 additions & 20 deletions mne/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import shutil
import sys
import warnings
from collections import defaultdict
from contextlib import contextmanager
from pathlib import Path
from textwrap import dedent
Expand Down Expand Up @@ -900,11 +901,8 @@ def protect_config():


def _test_passed(request):
try:
outcome = request.node.harvest_rep_call
except Exception:
outcome = "passed"
return outcome == "passed"
report = request.node.stash[_phase_report_key]
return "call" in report and report["call"].outcome == "passed"


@pytest.fixture()
Expand All @@ -931,7 +929,6 @@ def brain_gc(request):
ignore = set(id(o) for o in gc.get_objects())
yield
close_func()
# no need to warn if the test itself failed, pytest-harvest helps us here
if not _test_passed(request):
return
_assert_no_instances(Brain, "after")
Expand Down Expand Up @@ -960,16 +957,12 @@ def pytest_sessionfinish(session, exitstatus):
if n is None:
return
print("\n")
try:
import pytest_harvest
except ImportError:
print("Module-level timings require pytest-harvest")
return
# get the number to print
res = pytest_harvest.get_session_synthesis_dct(session)
files = dict()
for key, val in res.items():
parts = Path(key.split(":")[0]).parts
files = defaultdict(lambda: 0.0)
for item in session.items:
report = item.stash[_phase_report_key]
dur = sum(x.duration for x in report.values())
parts = Path(item.nodeid.split(":")[0]).parts
# split mne/tests/test_whatever.py into separate categories since these
# are essentially submodule-level tests. Keeping just [:3] works,
# except for mne/viz where we want level-4 granulatity
Expand All @@ -978,7 +971,7 @@ def pytest_sessionfinish(session, exitstatus):
if not parts[-1].endswith(".py"):
parts = parts + ("",)
file_key = "/".join(parts)
files[file_key] = files.get(file_key, 0) + val["pytest_duration_s"]
files[file_key] += dur
files = sorted(list(files.items()), key=lambda x: x[1])[::-1]
# print
_files[:] = files[:n]
Expand All @@ -999,7 +992,7 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config):
writer.line(f"{timing.ljust(15)}{name}")


def pytest_report_header(config, startdir):
def pytest_report_header(config, startdir=None):
"""Add information to the pytest run header."""
return f"MNE {mne.__version__} -- {str(Path(mne.__file__).parent)}"

Expand Down Expand Up @@ -1122,7 +1115,6 @@ def run(nbexec=nbexec, code=code):
return


@pytest.mark.filterwarnings("ignore:.*Extraction of measurement.*:")
@pytest.fixture(
params=(
[nirsport2, nirsport2_snirf, testing._pytest_param()],
Expand Down Expand Up @@ -1160,8 +1152,7 @@ def qt_windows_closed(request):
if "allow_unclosed_pyside2" in marks and API_NAME.lower() == "pyside2":
return
# Don't check when the test fails
report = request.node.stash[_phase_report_key]
if ("call" not in report) or report["call"].failed:
if not _test_passed(request):
return
widgets = app.topLevelWidgets()
n_after = len(widgets)
Expand Down
2 changes: 1 addition & 1 deletion mne/datasets/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ from . import (
)
from ._fetch import fetch_dataset
from ._fsaverage.base import fetch_fsaverage
from ._infant.base import fetch_infant_template
from ._infant import fetch_infant_template
from ._phantom.base import fetch_phantom
from .utils import (
_download_all_example_data,
Expand Down
1 change: 1 addition & 0 deletions mne/datasets/_infant/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .base import fetch_infant_template
9 changes: 8 additions & 1 deletion mne/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2240,7 +2240,14 @@ def save(
)

# check for file existence and expand `~` if present
fname = str(_check_fname(fname=fname, overwrite=overwrite))
fname = str(
_check_fname(
fname=fname,
overwrite=overwrite,
check_bids_split=True,
name="fname",
)
)

split_size_bytes = _get_split_size(split_size)

Expand Down
5 changes: 2 additions & 3 deletions mne/fixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,8 @@
###############################################################################
# distutils

# distutils has been deprecated since Python 3.10 and is scheduled for removal
# from the standard library with the release of Python 3.12. For version
# comparisons, we use setuptools's `parse_version` if available.
# distutils has been deprecated since Python 3.10 and was removed
# from the standard library with the release of Python 3.12.


def _compare_version(version_a, operator, version_b):
Expand Down
8 changes: 7 additions & 1 deletion mne/io/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1694,7 +1694,13 @@ def save(
endings_err = (".fif", ".fif.gz")

# convert to str, check for overwrite a few lines later
fname = _check_fname(fname, overwrite=True, verbose="error")
fname = _check_fname(
fname,
overwrite=True,
verbose="error",
check_bids_split=True,
name="fname",
)
check_fname(fname, "raw", endings, endings_err=endings_err)

split_size = _get_split_size(split_size)
Expand Down
2 changes: 1 addition & 1 deletion mne/io/brainvision/brainvision.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ def _aux_hdr_info(hdr_fname):
params, settings = settings.split("[Comment]")
else:
params, settings = settings, ""
cfg = configparser.ConfigParser()
cfg = configparser.ConfigParser(interpolation=None)
with StringIO(params) as fid:
cfg.read_file(fid)

Expand Down
28 changes: 28 additions & 0 deletions mne/io/fiff/tests/test_raw_fiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,6 +675,34 @@ def test_split_files(tmp_path, mod, monkeypatch):
assert not fname_3.is_file()


def test_bids_split_files(tmp_path):
"""Test that BIDS split files are written safely."""
mne_bids = pytest.importorskip("mne_bids")
bids_path = mne_bids.BIDSPath(
root=tmp_path,
subject="01",
datatype="meg",
split="01",
suffix="raw",
extension=".fif",
check=False,
)
(tmp_path / "sub-01" / "meg").mkdir(parents=True)
raw = read_raw_fif(test_fif_fname)
save_kwargs = dict(
buffer_size_sec=1.0, split_size="10MB", split_naming="bids", verbose=True
)
with pytest.raises(ValueError, match="Passing a BIDSPath"):
raw.save(bids_path, **save_kwargs)
bids_path.split = None
want_paths = [Path(bids_path.copy().update(split=ii).fpath) for ii in range(1, 3)]
for want_path in want_paths:
assert not want_path.is_file()
raw.save(bids_path, **save_kwargs)
for want_path in want_paths:
assert want_path.is_file()


def _err(*args, **kwargs):
raise RuntimeError("Killed mid-write")

Expand Down
90 changes: 62 additions & 28 deletions mne/preprocessing/artifact_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,15 @@
apply_trans,
quat_to_rot,
)
from ..utils import _mask_to_onsets_offsets, _pl, _validate_type, logger, verbose
from ..utils import (
_check_option,
_mask_to_onsets_offsets,
_pl,
_validate_type,
logger,
verbose,
warn,
)


@verbose
Expand Down Expand Up @@ -94,16 +102,13 @@ def annotate_muscle_zscore(
ch_type = "eeg"
else:
raise ValueError(
"No M/EEG channel types found, please specify a"
" ch_type or provide M/EEG sensor data"
"No M/EEG channel types found, please specify a 'ch_type' or provide "
"M/EEG sensor data."
)
logger.info("Using %s sensors for muscle artifact detection" % (ch_type))

if ch_type in ("mag", "grad"):
raw_copy.pick(ch_type)
logger.info("Using %s sensors for muscle artifact detection", ch_type)
else:
ch_type = {"meg": False, ch_type: True}
raw_copy.pick(**ch_type)
_check_option("ch_type", ch_type, ["mag", "grad", "eeg"])
raw_copy.pick(ch_type)

raw_copy.filter(
filter_freq[0],
Expand Down Expand Up @@ -289,27 +294,68 @@ def annotate_movement(
return annot, disp


def compute_average_dev_head_t(raw, pos):
@verbose
def compute_average_dev_head_t(raw, pos, *, verbose=None):
"""Get new device to head transform based on good segments.
Segments starting with "BAD" annotations are not included for calculating
the mean head position.
Parameters
----------
raw : instance of Raw
Data to compute head position.
pos : array, shape (N, 10)
The position and quaternion parameters from cHPI fitting.
raw : instance of Raw | list of Raw
Data to compute head position. Can be a list containing multiple raw
instances.
pos : array, shape (N, 10) | list of ndarray
The position and quaternion parameters from cHPI fitting. Can be
a list containing multiple position arrays, one per raw instance passed.
%(verbose)s
Returns
-------
dev_head_t : instance of Transform
New ``dev_head_t`` transformation using the averaged good head positions.
Notes
-----
.. versionchanged:: 1.7
Support for multiple raw instances and position arrays was added.
"""
# Get weighted head pos trans and rot
if not isinstance(raw, (list, tuple)):
raw = [raw]
if not isinstance(pos, (list, tuple)):
pos = [pos]
if len(pos) != len(raw):
raise ValueError(
f"Number of head positions ({len(pos)}) must match the number of raw "
f"instances ({len(raw)})"
)
hp = list()
dt = list()
for ri, (r, p) in enumerate(zip(raw, pos)):
_validate_type(r, BaseRaw, f"raw[{ri}]")
_validate_type(p, np.ndarray, f"pos[{ri}]")
hp_, dt_ = _raw_hp_weights(r, p)
hp.append(hp_)
dt.append(dt_)
hp = np.concatenate(hp, axis=0)
dt = np.concatenate(dt, axis=0)
dt /= dt.sum()
best_q = _average_quats(hp[:, 1:4], weights=dt)
trans = np.eye(4)
trans[:3, :3] = quat_to_rot(best_q)
trans[:3, 3] = dt @ hp[:, 4:7]
dist = np.linalg.norm(trans[:3, 3])
if dist > 1: # less than 1 meter is sane
warn(f"Implausible head position detected: {dist} meters from device origin")
dev_head_t = Transform("meg", "head", trans)
return dev_head_t


def _raw_hp_weights(raw, pos):
sfreq = raw.info["sfreq"]
seg_good = np.ones(len(raw.times))
trans_pos = np.zeros(3)
hp = pos.copy()
hp_ts = hp[:, 0] - raw._first_time

Expand Down Expand Up @@ -349,19 +395,7 @@ def compute_average_dev_head_t(raw, pos):
assert (dt >= 0).all()
dt = dt / sfreq
del seg_good, idx

# Get weighted head pos trans and rot
trans_pos += np.dot(dt, hp[:, 4:7])

rot_qs = hp[:, 1:4]
best_q = _average_quats(rot_qs, weights=dt)

trans = np.eye(4)
trans[:3, :3] = quat_to_rot(best_q)
trans[:3, 3] = trans_pos / dt.sum()
assert np.linalg.norm(trans[:3, 3]) < 1 # less than 1 meter is sane
dev_head_t = Transform("meg", "head", trans)
return dev_head_t
return hp, dt


def _annotations_from_mask(times, mask, annot_name, orig_time=None):
Expand Down
Loading

0 comments on commit 1856286

Please sign in to comment.