Skip to content

Commit

Permalink
[BUG] Allow Epochs.compute_tfr() for the multitaper method and comp…
Browse files Browse the repository at this point in the history
…lex/phase outputs (mne-tools#12842)
  • Loading branch information
tsbinns authored Sep 16, 2024
1 parent 670330a commit e999e85
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 15 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -131,4 +131,4 @@ jobs:
- uses: codecov/codecov-action@v4
with:
token: ${{ secrets.CODECOV_TOKEN }}
if: success()
if: always()
1 change: 1 addition & 0 deletions doc/changes/devel/12842.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix bug where :meth:`mne.Epochs.compute_tfr` could not be used with the multitaper method and complex or phase outputs, by `Thomas Binns`_.
9 changes: 5 additions & 4 deletions mne/decoding/receptive_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def __init__(
):
self.tmin = tmin
self.tmax = tmax
self.sfreq = float(sfreq)
self.sfreq = sfreq
self.feature_names = feature_names
self.estimator = 0.0 if estimator is None else estimator
self.fit_intercept = fit_intercept
Expand Down Expand Up @@ -154,7 +154,7 @@ def _delay_and_reshape(self, X, y=None):
X,
self.tmin,
self.tmax,
self.sfreq,
self.sfreq_,
fill_mean=self.fit_intercept_,
)
X = _reshape_for_est(X)
Expand Down Expand Up @@ -182,12 +182,13 @@ def fit(self, X, y):
raise ValueError(
f"scoring must be one of {sorted(_SCORERS.keys())}, got {self.scoring} "
)
self.sfreq_ = float(self.sfreq)
X, y, _, self._y_dim = self._check_dimensions(X, y)

if self.tmin > self.tmax:
raise ValueError(f"tmin ({self.tmin}) must be at most tmax ({self.tmax})")
# Initialize delays
self.delays_ = _times_to_delays(self.tmin, self.tmax, self.sfreq)
self.delays_ = _times_to_delays(self.tmin, self.tmax, self.sfreq_)

# Define the slice that we should use in the middle
self.valid_samples_ = _delays_to_slice(self.delays_)
Expand All @@ -200,7 +201,7 @@ def fit(self, X, y):
estimator = TimeDelayingRidge(
self.tmin,
self.tmax,
self.sfreq,
self.sfreq_,
alpha=self.estimator,
fit_intercept=self.fit_intercept_,
n_jobs=self.n_jobs,
Expand Down
1 change: 1 addition & 0 deletions mne/decoding/tests/test_search_light.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@ def test_sklearn_compliance(estimator, check):
"check_transformer_data_not_an_array",
"check_n_features_in",
"check_fit2d_predict1d",
"check_do_not_raise_errors_in_init_or_set_params",
)
if any(ignore in str(check) for ignore in ignores):
return
Expand Down
22 changes: 13 additions & 9 deletions mne/decoding/time_delaying_ridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,24 +287,22 @@ def __init__(
n_jobs=None,
edge_correction=True,
):
if tmin > tmax:
raise ValueError(f"tmin must be <= tmax, got {tmin} and {tmax}")
self.tmin = float(tmin)
self.tmax = float(tmax)
self.sfreq = float(sfreq)
self.alpha = float(alpha)
self.tmin = tmin
self.tmax = tmax
self.sfreq = sfreq
self.alpha = alpha
self.reg_type = reg_type
self.fit_intercept = fit_intercept
self.edge_correction = edge_correction
self.n_jobs = n_jobs

@property
def _smin(self):
return int(round(self.tmin * self.sfreq))
return int(round(self.tmin_ * self.sfreq_))

@property
def _smax(self):
return int(round(self.tmax * self.sfreq)) + 1
return int(round(self.tmax_ * self.sfreq_)) + 1

def fit(self, X, y):
"""Estimate the coefficients of the linear model.
Expand All @@ -323,6 +321,12 @@ def fit(self, X, y):
"""
_validate_type(X, "array-like", "X")
_validate_type(y, "array-like", "y")
self.tmin_ = float(self.tmin)
self.tmax_ = float(self.tmax)
self.sfreq_ = float(self.sfreq)
self.alpha_ = float(self.alpha)
if self.tmin_ > self.tmax_:
raise ValueError(f"tmin must be <= tmax, got {self.tmin_} and {self.tmax_}")
X = np.asarray(X, dtype=float)
y = np.asarray(y, dtype=float)
if X.ndim == 3:
Expand All @@ -349,7 +353,7 @@ def fit(self, X, y):
self.edge_correction,
)
self.coef_ = _fit_corrs(
self.cov_, x_y_, n_ch_x, self.reg_type, self.alpha, n_ch_x
self.cov_, x_y_, n_ch_x, self.reg_type, self.alpha_, n_ch_x
)
# This is the sklearn formula from LinearModel (will be 0. for no fit)
if self.fit_intercept:
Expand Down
7 changes: 7 additions & 0 deletions mne/time_frequency/tests/test_tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1530,6 +1530,13 @@ def test_epochs_compute_tfr_stockwell(epochs, freqs, return_itc):
assert tfr.comment == "1"


@pytest.mark.parametrize("output", ("complex", "phase"))
def test_epochs_compute_tfr_multitaper_complex_phase(epochs, output):
"""Test Epochs.compute_tfr(output="complex"/"phase")."""
tfr = epochs.compute_tfr("multitaper", freqs_linspace, output=output)
assert len(tfr.shape) == 5


@pytest.mark.parametrize("copy", (False, True))
def test_epochstfr_iter_evoked(epochs_tfr, copy):
"""Test EpochsTFR.iter_evoked()."""
Expand Down
3 changes: 2 additions & 1 deletion mne/time_frequency/tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1533,7 +1533,8 @@ def _compute_tfr(self, data, n_jobs, verbose):
]
# deal with the "taper" dimension
if self._needs_taper_dim:
expected_shape.insert(1, self._data.shape[1])
tapers_dim = 1 if _get_instance_type_string(self) != "Epochs" else 2
expected_shape.insert(1, self._data.shape[tapers_dim])
self._shape = tuple(expected_shape)

@verbose
Expand Down

0 comments on commit e999e85

Please sign in to comment.