From 0aae72dda030242d992866ce6d551206749474b2 Mon Sep 17 00:00:00 2001 From: Alex lepauvre Date: Fri, 23 Aug 2024 05:48:50 +0200 Subject: [PATCH] Sensor scales (#12805) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Larson --- doc/changes/devel/12805.newfeature.rst | 1 + doc/changes/names.inc | 1 + mne/decoding/tests/test_base.py | 1 + mne/utils/docs.py | 13 +++ mne/viz/_3d.py | 76 +++++++++++- mne/viz/_brain/_brain.py | 5 + mne/viz/_brain/tests/test_brain.py | 111 ++++++++++++++++++ mne/viz/tests/test_3d.py | 156 +++++++++++++++++++++++++ 8 files changed, 359 insertions(+), 5 deletions(-) create mode 100644 doc/changes/devel/12805.newfeature.rst diff --git a/doc/changes/devel/12805.newfeature.rst b/doc/changes/devel/12805.newfeature.rst new file mode 100644 index 00000000000..2c77d55d3ba --- /dev/null +++ b/doc/changes/devel/12805.newfeature.rst @@ -0,0 +1 @@ +Added support for ``sensor_scales`` to :meth:`mne.viz.Brain.add_sensors` and :func:`mne.viz.plot_alignment`, by :newcontrib:`Alex Lepauvre`. \ No newline at end of file diff --git a/doc/changes/names.inc b/doc/changes/names.inc index d8e33e02d3d..0939ac8b29b 100644 --- a/doc/changes/names.inc +++ b/doc/changes/names.inc @@ -10,6 +10,7 @@ .. _Alex Ciok: https://github.com/alexCiok .. _Alex Gramfort: https://alexandre.gramfort.net .. _Alex Kiefer: https://home.alexk101.dev +.. _Alex Lepauvre: https://github.com/AlexLepauvre .. _Alex Rockhill: https://github.com/alexrockhill/ .. _Alexander Rudiuk: https://github.com/ARudiuk .. _Alexandre Barachant: https://alexandre.barachant.org diff --git a/mne/decoding/tests/test_base.py b/mne/decoding/tests/test_base.py index c1a992c416c..10d9950bbf7 100644 --- a/mne/decoding/tests/test_base.py +++ b/mne/decoding/tests/test_base.py @@ -67,6 +67,7 @@ def _make_data(n_samples=1000, n_features=5, n_targets=3): return X, Y, A +@pytest.mark.filterwarnings("ignore:invalid value encountered in cast.*:RuntimeWarning") def test_get_coef(): """Test getting linear coefficients (filters/patterns) from estimators.""" from sklearn import svm diff --git a/mne/utils/docs.py b/mne/utils/docs.py index 0fa9288bec2..3e59d751d93 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -3997,6 +3997,19 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): shape ``(n_eeg, 3)`` or ``(n_eeg, 4)``. """ +docdict["sensor_scales"] = """ +sensor_scales : int | float | array-like | dict | None + Scale to use for the sensor glyphs. Can be None (default) to use default scale. + A dict should provide the Scale (values) for each channel type (keys), e.g.:: + + dict(eeg=eeg_scales) + + Where the value (``eeg_scales`` above) can be broadcast to an array of values with + length that matches the number of channels of that type. A few examples of this + for the case above are the value ``10e-3``, a list of ``n_eeg`` values, or an NumPy + ndarray of shape ``(n_eeg,)``. +""" + docdict["sensors_topomap"] = """ sensors : bool | str Whether to add markers for sensor locations. If :class:`str`, should be a diff --git a/mne/viz/_3d.py b/mne/viz/_3d.py index a27baae16e0..7dbabfa1ef6 100644 --- a/mne/viz/_3d.py +++ b/mne/viz/_3d.py @@ -526,6 +526,8 @@ def plot_alignment( fig=None, interaction="terrain", sensor_colors=None, + *, + sensor_scales=None, verbose=None, ): """Plot head, sensor, and source space alignment in 3D. @@ -617,6 +619,9 @@ def plot_alignment( .. versionchanged:: 1.6 Support for passing a ``dict`` was added. + %(sensor_scales)s + + .. versionadded:: 1.9 %(verbose)s Returns @@ -906,6 +911,7 @@ def plot_alignment( "m", sensor_alpha=sensor_alpha, sensor_colors=sensor_colors, + sensor_scales=sensor_scales, ) if src is not None: @@ -1480,6 +1486,7 @@ def _plot_sensors_3d( check_inside=None, nearest=None, sensor_colors=None, + sensor_scales=None, ): """Render sensors in a 3D scene.""" from matplotlib.colors import to_rgba_array @@ -1535,23 +1542,44 @@ def _plot_sensors_3d( sensor_colors = { list(locs)[0]: to_rgba_array(sensor_colors), } + if sensor_scales is not None and not isinstance(sensor_scales, dict): + sensor_scales = { + list(locs)[0]: sensor_scales, + } else: extra = f"when more than one channel type ({list(locs)}) is plotted" _validate_type(sensor_colors, types, "sensor_colors", extra=extra) + _validate_type(sensor_scales, types, "sensor_scales", extra=extra) del extra, types if sensor_colors is None: sensor_colors = dict() + if sensor_scales is None: + sensor_scales = dict() assert isinstance(sensor_colors, dict) + assert isinstance(sensor_scales, dict) for ch_type, sens_loc in locs.items(): logger.debug(f"Drawing {ch_type} sensors") assert len(sens_loc) # should be guaranteed above colors = to_rgba_array(sensor_colors.get(ch_type, defaults[ch_type + "_color"])) + scales = np.atleast_1d( + sensor_scales.get(ch_type, defaults[ch_type + "_scale"] * unit_scalar) + ) _check_option( f"len(sensor_colors[{repr(ch_type)}])", colors.shape[0], (len(sens_loc), 1), ) - scale = defaults[ch_type + "_scale"] * unit_scalar + _check_option( + f"len(sensor_scales[{repr(ch_type)}])", + scales.shape[0], + (len(sens_loc), 1), + ) + # Check that the scale is numerical + assert np.issubdtype(scales.dtype, np.number), ( + f"scales for {ch_type} must contain only numerical values, " + f"got {scales} instead." + ) + this_alpha = sensor_alpha[ch_type] if isinstance(sens_loc[0], dict): # meg coil if len(colors) == 1: @@ -1567,13 +1595,13 @@ def _plot_sensors_3d( else: sens_loc = np.array(sens_loc, float) mask = ~np.isnan(sens_loc).any(axis=1) - if len(colors) == 1: + if len(colors) == 1 and len(scales) == 1: # Single color mode (one actor) actor, _ = _plot_glyphs( renderer=renderer, loc=sens_loc[mask] * unit_scalar, color=colors[0, :3], - scale=scale, + scale=scales[0], opacity=this_alpha * colors[0, 3], orient_glyphs=orient_glyphs, scale_by_distance=scale_by_distance, @@ -1583,9 +1611,47 @@ def _plot_sensors_3d( nearest=nearest, ) actors[ch_type].append(actor) - else: - # Multi-color mode (multiple actors) + elif len(colors) == len(sens_loc) and len(scales) == 1: + # Multi-color single scale mode (multiple actors) for loc, color, usable in zip(sens_loc, colors, mask): + if not usable: + continue + actor, _ = _plot_glyphs( + renderer=renderer, + loc=loc * unit_scalar, + color=color[:3], + scale=scales[0], + opacity=this_alpha * color[3], + orient_glyphs=orient_glyphs, + scale_by_distance=scale_by_distance, + project_points=project_points, + surf=surf, + check_inside=check_inside, + nearest=nearest, + ) + actors[ch_type].append(actor) + elif len(colors) == 1 and len(scales) == len(sens_loc): + # Multi-scale single color mode (multiple actors) + for loc, scale, usable in zip(sens_loc, scales, mask): + if not usable: + continue + actor, _ = _plot_glyphs( + renderer=renderer, + loc=loc * unit_scalar, + color=colors[0, :3], + scale=scale, + opacity=this_alpha * colors[0, 3], + orient_glyphs=orient_glyphs, + scale_by_distance=scale_by_distance, + project_points=project_points, + surf=surf, + check_inside=check_inside, + nearest=nearest, + ) + actors[ch_type].append(actor) + else: + # Multi-color multi-scale mode (multiple actors) + for loc, color, scale, usable in zip(sens_loc, colors, scales, mask): if not usable: continue actor, _ = _plot_glyphs( diff --git a/mne/viz/_brain/_brain.py b/mne/viz/_brain/_brain.py index b9e88230b97..247c0840858 100644 --- a/mne/viz/_brain/_brain.py +++ b/mne/viz/_brain/_brain.py @@ -2690,6 +2690,7 @@ def add_sensors( max_dist=0.004, *, sensor_colors=None, + sensor_scales=None, verbose=None, ): """Add mesh objects to represent sensor positions. @@ -2708,6 +2709,9 @@ def add_sensors( %(sensor_colors)s .. versionadded:: 1.6 + %(sensor_scales)s + + .. versionadded:: 1.9 %(verbose)s Notes @@ -2764,6 +2768,7 @@ def add_sensors( self._units, sensor_alpha=sensor_alpha, sensor_colors=sensor_colors, + sensor_scales=sensor_scales, ) # sensors_actors can still be None for item, actors in (sensors_actors or {}).items(): diff --git a/mne/viz/_brain/tests/test_brain.py b/mne/viz/_brain/tests/test_brain.py index 3963d3e085b..2a1c943250b 100644 --- a/mne/viz/_brain/tests/test_brain.py +++ b/mne/viz/_brain/tests/test_brain.py @@ -5,6 +5,7 @@ import os import platform +from contextlib import nullcontext from pathlib import Path from shutil import copyfile @@ -544,6 +545,116 @@ def __init__(self): brain.close() +# TODO: Figure out why brain_gc is problematic here on PyQt5 +@pytest.mark.allow_unclosed +@testing.requires_testing_data +@pytest.mark.parametrize( + "sensor_colors, sensor_scales, expectation", + [ + ( + {"seeg": ["k"] * 5}, + {"seeg": [2] * 6}, + pytest.raises( + ValueError, + match=r"Invalid value for the 'len\(sensor_colors\['seeg'\]\)' " + r"parameter. Allowed values are \d+ and \d+, but got \d+ instead", + ), + ), + ( + {"seeg": ["k"] * 6}, + {"seeg": [2] * 5}, + pytest.raises( + ValueError, + match=r"Invalid value for the 'len\(sensor_scales\['seeg'\]\)' " + r"parameter. Allowed values are \d+ and \d+, but got \d+ instead", + ), + ), + ( + "NotAColor", + 2, + pytest.raises( + ValueError, + match=r".* is not a valid color value", + ), + ), + ( + "k", + "k", + pytest.raises( + AssertionError, + match=r"scales for .* must contain only numerical values, got .* " + r"instead.", + ), + ), + ( + "k", + 2, + nullcontext(), + ), + ( + ["k"] * 6, + [2] * 6, + nullcontext(), + ), + ( + {"seeg": ["k"] * 6}, + {"seeg": [2] * 6}, + nullcontext(), + ), + ], +) +def test_add_sensors_scales( + renderer_interactive_pyvistaqt, + sensor_colors, + sensor_scales, + expectation, +): + """Test sensor_scales parameter.""" + kwargs = dict(subject=subject, subjects_dir=subjects_dir) + hemi = "lh" + surf = "white" + cortex = "low_contrast" + title = "test" + size = (300, 300) + + brain = Brain( + hemi=hemi, + surf=surf, + size=size, + title=title, + cortex=cortex, + units="m", + silhouette=dict(decimate=0.95), + **kwargs, + ) + + proj_info = create_info([f"Ch{i}" for i in range(1, 7)], 1000, "seeg") + pos = ( + np.array( + [ + [25.85, 9.04, -5.38], + [33.56, 9.04, -5.63], + [40.44, 9.04, -5.06], + [46.75, 9.04, -6.78], + [-30.08, 9.04, 28.23], + [-32.95, 9.04, 37.99], + ] + ) + / 1000 + ) + proj_info.set_montage( + make_dig_montage(ch_pos=dict(zip(proj_info.ch_names, pos)), coord_frame="head") + ) + with expectation: + brain.add_sensors( + proj_info, + trans=fname_trans, + sensor_colors=sensor_colors, + sensor_scales=sensor_scales, + ) + brain.close() + + def _assert_view_allclose( brain, roll, diff --git a/mne/viz/tests/test_3d.py b/mne/viz/tests/test_3d.py index 24f6718a5a2..9c398f94dec 100644 --- a/mne/viz/tests/test_3d.py +++ b/mne/viz/tests/test_3d.py @@ -2,6 +2,7 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. +from contextlib import nullcontext from pathlib import Path import matplotlib.pyplot as plt @@ -262,6 +263,161 @@ def _assert_n_actors(fig, renderer, n_actors): assert len(fig.plotter.renderer.actors) == n_actors +@pytest.mark.slowtest # can be slow on OSX +@testing.requires_testing_data +@pytest.mark.parametrize( + "test_ecog, test_seeg, sensor_colors, sensor_scales, expectation", + [ + ( + True, + True, + "k", + 2, + pytest.raises( + TypeError, + match="sensor_colors must be an instance of dict or " + "None when more than one channel type", + ), + ), + ( + True, + True, + {"ecog": "k", "seeg": "k"}, + 2, + pytest.raises( + TypeError, + match="sensor_scales must be an instance of dict or " + "None when more than one channel type", + ), + ), + ( + True, + True, + {"ecog": ["k"] * 2, "seeg": "k"}, + {"ecog": 2, "seeg": 2}, + pytest.raises( + ValueError, + match=r"Invalid value for the 'len\(sensor_colors\['ecog'\]\)' " + r"parameter. Allowed values are \d+ and \d+, but got \d+ instead", + ), + ), + ( + True, + True, + {"ecog": "k", "seeg": ["k"] * 2}, + {"ecog": 2, "seeg": 2}, + pytest.raises( + ValueError, + match=r"Invalid value for the 'len\(sensor_colors\['seeg'\]\)' " + r"parameter. Allowed values are \d+ and \d+, but got \d+ instead", + ), + ), + ( + True, + True, + {"ecog": "k", "seeg": "k"}, + {"ecog": [2] * 2, "seeg": 2}, + pytest.raises( + ValueError, + match=r"Invalid value for the 'len\(sensor_scales\['ecog'\]\)' " + r"parameter. Allowed values are \d+ and \d+, but got \d+ instead", + ), + ), + ( + True, + True, + {"ecog": "k", "seeg": "k"}, + {"ecog": 2, "seeg": [2] * 2}, + pytest.raises( + ValueError, + match=r"Invalid value for the 'len\(sensor_scales\['seeg'\]\)' " + r"parameter. Allowed values are \d+ and \d+, but got \d+ instead", + ), + ), + ( + True, + True, + {"ecog": "NotAColor", "seeg": "NotAColor"}, + {"ecog": 2, "seeg": 2}, + pytest.raises( + ValueError, + match=r".* is not a valid color value", + ), + ), + ( + True, + True, + {"ecog": "k", "seeg": "k"}, + {"ecog": "k", "seeg": 2}, + pytest.raises( + AssertionError, + match=r"scales for .* must contain only numerical values, got .* " + r"instead.", + ), + ), + ( + True, + True, + {"ecog": "k", "seeg": "k"}, + {"ecog": 2, "seeg": 2}, + nullcontext(), + ), + ( + True, + True, + {"ecog": [0, 0, 0], "seeg": [0, 0, 0]}, + {"ecog": 2, "seeg": 2}, + nullcontext(), + ), + ( + True, + True, + {"ecog": ["k"] * 10, "seeg": ["k"] * 10}, + {"ecog": [2] * 10, "seeg": [2] * 10}, + nullcontext(), + ), + ( + True, + False, + "k", + 2, + nullcontext(), + ), + ], +) +def test_plot_alignment_ieeg( + renderer, test_ecog, test_seeg, sensor_colors, sensor_scales, expectation +): + """Test plotting of iEEG sensors.""" + # Load evoked: + evoked = read_evokeds(evoked_fname)[0] + # EEG only + evoked_eeg = evoked.copy().pick_types(eeg=True) + with evoked_eeg.info._unlock(): + evoked_eeg.info["projs"] = [] # "remove" avg proj + eeg_channels = pick_types(evoked_eeg.info, eeg=True) + # Set 10 EEG channels to ecog, 10 to seeg + evoked_eeg.set_channel_types( + {evoked_eeg.ch_names[ch]: "ecog" for ch in eeg_channels[:10]} + ) + evoked_eeg.set_channel_types( + {evoked_eeg.ch_names[ch]: "seeg" for ch in eeg_channels[10:20]} + ) + evoked_ecog_seeg = evoked_eeg.pick_types(seeg=True, ecog=True) + this_info = evoked_ecog_seeg.info + # Test plot: + with expectation: + fig = plot_alignment( + this_info, + ecog=test_ecog, + seeg=test_seeg, + sensor_colors=sensor_colors, + sensor_scales=sensor_scales, + ) + assert isinstance(fig, Figure3D) + renderer.backend._close_all() + + @pytest.mark.slowtest # Slow on Azure @testing.requires_testing_data # all use trans + head surf @pytest.mark.parametrize(