Skip to content

Commit

Permalink
Sensor scales (mne-tools#12805)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Larson <[email protected]>
  • Loading branch information
3 people authored Aug 23, 2024
1 parent 99dd0e1 commit 0aae72d
Show file tree
Hide file tree
Showing 8 changed files with 359 additions and 5 deletions.
1 change: 1 addition & 0 deletions doc/changes/devel/12805.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added support for ``sensor_scales`` to :meth:`mne.viz.Brain.add_sensors` and :func:`mne.viz.plot_alignment`, by :newcontrib:`Alex Lepauvre`.
1 change: 1 addition & 0 deletions doc/changes/names.inc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
.. _Alex Ciok: https://github.com/alexCiok
.. _Alex Gramfort: https://alexandre.gramfort.net
.. _Alex Kiefer: https://home.alexk101.dev
.. _Alex Lepauvre: https://github.com/AlexLepauvre
.. _Alex Rockhill: https://github.com/alexrockhill/
.. _Alexander Rudiuk: https://github.com/ARudiuk
.. _Alexandre Barachant: https://alexandre.barachant.org
Expand Down
1 change: 1 addition & 0 deletions mne/decoding/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def _make_data(n_samples=1000, n_features=5, n_targets=3):
return X, Y, A


@pytest.mark.filterwarnings("ignore:invalid value encountered in cast.*:RuntimeWarning")
def test_get_coef():
"""Test getting linear coefficients (filters/patterns) from estimators."""
from sklearn import svm
Expand Down
13 changes: 13 additions & 0 deletions mne/utils/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3997,6 +3997,19 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
shape ``(n_eeg, 3)`` or ``(n_eeg, 4)``.
"""

docdict["sensor_scales"] = """
sensor_scales : int | float | array-like | dict | None
Scale to use for the sensor glyphs. Can be None (default) to use default scale.
A dict should provide the Scale (values) for each channel type (keys), e.g.::
dict(eeg=eeg_scales)
Where the value (``eeg_scales`` above) can be broadcast to an array of values with
length that matches the number of channels of that type. A few examples of this
for the case above are the value ``10e-3``, a list of ``n_eeg`` values, or an NumPy
ndarray of shape ``(n_eeg,)``.
"""

docdict["sensors_topomap"] = """
sensors : bool | str
Whether to add markers for sensor locations. If :class:`str`, should be a
Expand Down
76 changes: 71 additions & 5 deletions mne/viz/_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,8 @@ def plot_alignment(
fig=None,
interaction="terrain",
sensor_colors=None,
*,
sensor_scales=None,
verbose=None,
):
"""Plot head, sensor, and source space alignment in 3D.
Expand Down Expand Up @@ -617,6 +619,9 @@ def plot_alignment(
.. versionchanged:: 1.6
Support for passing a ``dict`` was added.
%(sensor_scales)s
.. versionadded:: 1.9
%(verbose)s
Returns
Expand Down Expand Up @@ -906,6 +911,7 @@ def plot_alignment(
"m",
sensor_alpha=sensor_alpha,
sensor_colors=sensor_colors,
sensor_scales=sensor_scales,
)

if src is not None:
Expand Down Expand Up @@ -1480,6 +1486,7 @@ def _plot_sensors_3d(
check_inside=None,
nearest=None,
sensor_colors=None,
sensor_scales=None,
):
"""Render sensors in a 3D scene."""
from matplotlib.colors import to_rgba_array
Expand Down Expand Up @@ -1535,23 +1542,44 @@ def _plot_sensors_3d(
sensor_colors = {
list(locs)[0]: to_rgba_array(sensor_colors),
}
if sensor_scales is not None and not isinstance(sensor_scales, dict):
sensor_scales = {
list(locs)[0]: sensor_scales,
}
else:
extra = f"when more than one channel type ({list(locs)}) is plotted"
_validate_type(sensor_colors, types, "sensor_colors", extra=extra)
_validate_type(sensor_scales, types, "sensor_scales", extra=extra)
del extra, types
if sensor_colors is None:
sensor_colors = dict()
if sensor_scales is None:
sensor_scales = dict()
assert isinstance(sensor_colors, dict)
assert isinstance(sensor_scales, dict)
for ch_type, sens_loc in locs.items():
logger.debug(f"Drawing {ch_type} sensors")
assert len(sens_loc) # should be guaranteed above
colors = to_rgba_array(sensor_colors.get(ch_type, defaults[ch_type + "_color"]))
scales = np.atleast_1d(
sensor_scales.get(ch_type, defaults[ch_type + "_scale"] * unit_scalar)
)
_check_option(
f"len(sensor_colors[{repr(ch_type)}])",
colors.shape[0],
(len(sens_loc), 1),
)
scale = defaults[ch_type + "_scale"] * unit_scalar
_check_option(
f"len(sensor_scales[{repr(ch_type)}])",
scales.shape[0],
(len(sens_loc), 1),
)
# Check that the scale is numerical
assert np.issubdtype(scales.dtype, np.number), (
f"scales for {ch_type} must contain only numerical values, "
f"got {scales} instead."
)

this_alpha = sensor_alpha[ch_type]
if isinstance(sens_loc[0], dict): # meg coil
if len(colors) == 1:
Expand All @@ -1567,13 +1595,13 @@ def _plot_sensors_3d(
else:
sens_loc = np.array(sens_loc, float)
mask = ~np.isnan(sens_loc).any(axis=1)
if len(colors) == 1:
if len(colors) == 1 and len(scales) == 1:
# Single color mode (one actor)
actor, _ = _plot_glyphs(
renderer=renderer,
loc=sens_loc[mask] * unit_scalar,
color=colors[0, :3],
scale=scale,
scale=scales[0],
opacity=this_alpha * colors[0, 3],
orient_glyphs=orient_glyphs,
scale_by_distance=scale_by_distance,
Expand All @@ -1583,9 +1611,47 @@ def _plot_sensors_3d(
nearest=nearest,
)
actors[ch_type].append(actor)
else:
# Multi-color mode (multiple actors)
elif len(colors) == len(sens_loc) and len(scales) == 1:
# Multi-color single scale mode (multiple actors)
for loc, color, usable in zip(sens_loc, colors, mask):
if not usable:
continue
actor, _ = _plot_glyphs(
renderer=renderer,
loc=loc * unit_scalar,
color=color[:3],
scale=scales[0],
opacity=this_alpha * color[3],
orient_glyphs=orient_glyphs,
scale_by_distance=scale_by_distance,
project_points=project_points,
surf=surf,
check_inside=check_inside,
nearest=nearest,
)
actors[ch_type].append(actor)
elif len(colors) == 1 and len(scales) == len(sens_loc):
# Multi-scale single color mode (multiple actors)
for loc, scale, usable in zip(sens_loc, scales, mask):
if not usable:
continue
actor, _ = _plot_glyphs(
renderer=renderer,
loc=loc * unit_scalar,
color=colors[0, :3],
scale=scale,
opacity=this_alpha * colors[0, 3],
orient_glyphs=orient_glyphs,
scale_by_distance=scale_by_distance,
project_points=project_points,
surf=surf,
check_inside=check_inside,
nearest=nearest,
)
actors[ch_type].append(actor)
else:
# Multi-color multi-scale mode (multiple actors)
for loc, color, scale, usable in zip(sens_loc, colors, scales, mask):
if not usable:
continue
actor, _ = _plot_glyphs(
Expand Down
5 changes: 5 additions & 0 deletions mne/viz/_brain/_brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -2690,6 +2690,7 @@ def add_sensors(
max_dist=0.004,
*,
sensor_colors=None,
sensor_scales=None,
verbose=None,
):
"""Add mesh objects to represent sensor positions.
Expand All @@ -2708,6 +2709,9 @@ def add_sensors(
%(sensor_colors)s
.. versionadded:: 1.6
%(sensor_scales)s
.. versionadded:: 1.9
%(verbose)s
Notes
Expand Down Expand Up @@ -2764,6 +2768,7 @@ def add_sensors(
self._units,
sensor_alpha=sensor_alpha,
sensor_colors=sensor_colors,
sensor_scales=sensor_scales,
)
# sensors_actors can still be None
for item, actors in (sensors_actors or {}).items():
Expand Down
111 changes: 111 additions & 0 deletions mne/viz/_brain/tests/test_brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import os
import platform
from contextlib import nullcontext
from pathlib import Path
from shutil import copyfile

Expand Down Expand Up @@ -544,6 +545,116 @@ def __init__(self):
brain.close()


# TODO: Figure out why brain_gc is problematic here on PyQt5
@pytest.mark.allow_unclosed
@testing.requires_testing_data
@pytest.mark.parametrize(
"sensor_colors, sensor_scales, expectation",
[
(
{"seeg": ["k"] * 5},
{"seeg": [2] * 6},
pytest.raises(
ValueError,
match=r"Invalid value for the 'len\(sensor_colors\['seeg'\]\)' "
r"parameter. Allowed values are \d+ and \d+, but got \d+ instead",
),
),
(
{"seeg": ["k"] * 6},
{"seeg": [2] * 5},
pytest.raises(
ValueError,
match=r"Invalid value for the 'len\(sensor_scales\['seeg'\]\)' "
r"parameter. Allowed values are \d+ and \d+, but got \d+ instead",
),
),
(
"NotAColor",
2,
pytest.raises(
ValueError,
match=r".* is not a valid color value",
),
),
(
"k",
"k",
pytest.raises(
AssertionError,
match=r"scales for .* must contain only numerical values, got .* "
r"instead.",
),
),
(
"k",
2,
nullcontext(),
),
(
["k"] * 6,
[2] * 6,
nullcontext(),
),
(
{"seeg": ["k"] * 6},
{"seeg": [2] * 6},
nullcontext(),
),
],
)
def test_add_sensors_scales(
renderer_interactive_pyvistaqt,
sensor_colors,
sensor_scales,
expectation,
):
"""Test sensor_scales parameter."""
kwargs = dict(subject=subject, subjects_dir=subjects_dir)
hemi = "lh"
surf = "white"
cortex = "low_contrast"
title = "test"
size = (300, 300)

brain = Brain(
hemi=hemi,
surf=surf,
size=size,
title=title,
cortex=cortex,
units="m",
silhouette=dict(decimate=0.95),
**kwargs,
)

proj_info = create_info([f"Ch{i}" for i in range(1, 7)], 1000, "seeg")
pos = (
np.array(
[
[25.85, 9.04, -5.38],
[33.56, 9.04, -5.63],
[40.44, 9.04, -5.06],
[46.75, 9.04, -6.78],
[-30.08, 9.04, 28.23],
[-32.95, 9.04, 37.99],
]
)
/ 1000
)
proj_info.set_montage(
make_dig_montage(ch_pos=dict(zip(proj_info.ch_names, pos)), coord_frame="head")
)
with expectation:
brain.add_sensors(
proj_info,
trans=fname_trans,
sensor_colors=sensor_colors,
sensor_scales=sensor_scales,
)
brain.close()


def _assert_view_allclose(
brain,
roll,
Expand Down
Loading

0 comments on commit 0aae72d

Please sign in to comment.