From e999e853b263c8b48a6fcfeb60c82b445e20d88b Mon Sep 17 00:00:00 2001 From: "Thomas S. Binns" Date: Tue, 17 Sep 2024 01:29:01 +0200 Subject: [PATCH] [BUG] Allow `Epochs.compute_tfr()` for the multitaper method and complex/phase outputs (#12842) --- .github/workflows/tests.yml | 2 +- doc/changes/devel/12842.bugfix.rst | 1 + mne/decoding/receptive_field.py | 9 +++++---- mne/decoding/tests/test_search_light.py | 1 + mne/decoding/time_delaying_ridge.py | 22 +++++++++++++--------- mne/time_frequency/tests/test_tfr.py | 7 +++++++ mne/time_frequency/tfr.py | 3 ++- 7 files changed, 30 insertions(+), 15 deletions(-) create mode 100644 doc/changes/devel/12842.bugfix.rst diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 571c3943ae7..571d4329831 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -131,4 +131,4 @@ jobs: - uses: codecov/codecov-action@v4 with: token: ${{ secrets.CODECOV_TOKEN }} - if: success() + if: always() diff --git a/doc/changes/devel/12842.bugfix.rst b/doc/changes/devel/12842.bugfix.rst new file mode 100644 index 00000000000..75f83683b8f --- /dev/null +++ b/doc/changes/devel/12842.bugfix.rst @@ -0,0 +1 @@ +Fix bug where :meth:`mne.Epochs.compute_tfr` could not be used with the multitaper method and complex or phase outputs, by `Thomas Binns`_. \ No newline at end of file diff --git a/mne/decoding/receptive_field.py b/mne/decoding/receptive_field.py index 5fc985a81b3..a9cc72d18ce 100644 --- a/mne/decoding/receptive_field.py +++ b/mne/decoding/receptive_field.py @@ -117,7 +117,7 @@ def __init__( ): self.tmin = tmin self.tmax = tmax - self.sfreq = float(sfreq) + self.sfreq = sfreq self.feature_names = feature_names self.estimator = 0.0 if estimator is None else estimator self.fit_intercept = fit_intercept @@ -154,7 +154,7 @@ def _delay_and_reshape(self, X, y=None): X, self.tmin, self.tmax, - self.sfreq, + self.sfreq_, fill_mean=self.fit_intercept_, ) X = _reshape_for_est(X) @@ -182,12 +182,13 @@ def fit(self, X, y): raise ValueError( f"scoring must be one of {sorted(_SCORERS.keys())}, got {self.scoring} " ) + self.sfreq_ = float(self.sfreq) X, y, _, self._y_dim = self._check_dimensions(X, y) if self.tmin > self.tmax: raise ValueError(f"tmin ({self.tmin}) must be at most tmax ({self.tmax})") # Initialize delays - self.delays_ = _times_to_delays(self.tmin, self.tmax, self.sfreq) + self.delays_ = _times_to_delays(self.tmin, self.tmax, self.sfreq_) # Define the slice that we should use in the middle self.valid_samples_ = _delays_to_slice(self.delays_) @@ -200,7 +201,7 @@ def fit(self, X, y): estimator = TimeDelayingRidge( self.tmin, self.tmax, - self.sfreq, + self.sfreq_, alpha=self.estimator, fit_intercept=self.fit_intercept_, n_jobs=self.n_jobs, diff --git a/mne/decoding/tests/test_search_light.py b/mne/decoding/tests/test_search_light.py index fe605abca06..9e15a1df59b 100644 --- a/mne/decoding/tests/test_search_light.py +++ b/mne/decoding/tests/test_search_light.py @@ -358,6 +358,7 @@ def test_sklearn_compliance(estimator, check): "check_transformer_data_not_an_array", "check_n_features_in", "check_fit2d_predict1d", + "check_do_not_raise_errors_in_init_or_set_params", ) if any(ignore in str(check) for ignore in ignores): return diff --git a/mne/decoding/time_delaying_ridge.py b/mne/decoding/time_delaying_ridge.py index e824a15be75..b80b36d3922 100644 --- a/mne/decoding/time_delaying_ridge.py +++ b/mne/decoding/time_delaying_ridge.py @@ -287,12 +287,10 @@ def __init__( n_jobs=None, edge_correction=True, ): - if tmin > tmax: - raise ValueError(f"tmin must be <= tmax, got {tmin} and {tmax}") - self.tmin = float(tmin) - self.tmax = float(tmax) - self.sfreq = float(sfreq) - self.alpha = float(alpha) + self.tmin = tmin + self.tmax = tmax + self.sfreq = sfreq + self.alpha = alpha self.reg_type = reg_type self.fit_intercept = fit_intercept self.edge_correction = edge_correction @@ -300,11 +298,11 @@ def __init__( @property def _smin(self): - return int(round(self.tmin * self.sfreq)) + return int(round(self.tmin_ * self.sfreq_)) @property def _smax(self): - return int(round(self.tmax * self.sfreq)) + 1 + return int(round(self.tmax_ * self.sfreq_)) + 1 def fit(self, X, y): """Estimate the coefficients of the linear model. @@ -323,6 +321,12 @@ def fit(self, X, y): """ _validate_type(X, "array-like", "X") _validate_type(y, "array-like", "y") + self.tmin_ = float(self.tmin) + self.tmax_ = float(self.tmax) + self.sfreq_ = float(self.sfreq) + self.alpha_ = float(self.alpha) + if self.tmin_ > self.tmax_: + raise ValueError(f"tmin must be <= tmax, got {self.tmin_} and {self.tmax_}") X = np.asarray(X, dtype=float) y = np.asarray(y, dtype=float) if X.ndim == 3: @@ -349,7 +353,7 @@ def fit(self, X, y): self.edge_correction, ) self.coef_ = _fit_corrs( - self.cov_, x_y_, n_ch_x, self.reg_type, self.alpha, n_ch_x + self.cov_, x_y_, n_ch_x, self.reg_type, self.alpha_, n_ch_x ) # This is the sklearn formula from LinearModel (will be 0. for no fit) if self.fit_intercept: diff --git a/mne/time_frequency/tests/test_tfr.py b/mne/time_frequency/tests/test_tfr.py index 38099d8a3aa..cd3a97ab90a 100644 --- a/mne/time_frequency/tests/test_tfr.py +++ b/mne/time_frequency/tests/test_tfr.py @@ -1530,6 +1530,13 @@ def test_epochs_compute_tfr_stockwell(epochs, freqs, return_itc): assert tfr.comment == "1" +@pytest.mark.parametrize("output", ("complex", "phase")) +def test_epochs_compute_tfr_multitaper_complex_phase(epochs, output): + """Test Epochs.compute_tfr(output="complex"/"phase").""" + tfr = epochs.compute_tfr("multitaper", freqs_linspace, output=output) + assert len(tfr.shape) == 5 + + @pytest.mark.parametrize("copy", (False, True)) def test_epochstfr_iter_evoked(epochs_tfr, copy): """Test EpochsTFR.iter_evoked().""" diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index 98af67ff0a7..eaf173092bb 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -1533,7 +1533,8 @@ def _compute_tfr(self, data, n_jobs, verbose): ] # deal with the "taper" dimension if self._needs_taper_dim: - expected_shape.insert(1, self._data.shape[1]) + tapers_dim = 1 if _get_instance_type_string(self) != "Epochs" else 2 + expected_shape.insert(1, self._data.shape[tapers_dim]) self._shape = tuple(expected_shape) @verbose