Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ssr #559

Merged
merged 5 commits into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 26 additions & 3 deletions pysindy/optimizers/ssr.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,10 @@ def _reduce(self, x, y):
" ... {: >10} ... {: >10}".format(*row)
)

self.err_history_ = []
self.history_ = [coef]
self.err_history_ = [
np.sum((y - x @ coef.T) ** 2) + l0_penalty * np.count_nonzero(coef)
]
for k in range(self.max_iter):
for i in range(n_targets):
if self.criteria == "coefficient_value":
Expand Down Expand Up @@ -226,5 +229,25 @@ def _reduce(self, x, y):
if np.all(np.sum(np.asarray(inds, dtype=int), axis=1) <= 1):
# each equation has one last term
break
err_min = np.argmin(self.err_history_)
self.coef_ = np.asarray(self.history_)[err_min, :, :]

if self.kappa is not None:
ind_best = np.argmin(self.err_history_)
else:
# err history is reverse of ordering in paper
ind_best = (
len(self.err_history_) - 1 - _ind_inflection(self.err_history_[::-1])
)
self.coef_ = np.asarray(self.history_)[ind_best, :, :]


def _ind_inflection(err_descending: list[float]) -> int:
"Calculate the index of the inflection point in error"
if len(err_descending) == 1:
raise ValueError("Cannot find the inflection point of a single point")
err_descending = np.array(err_descending)
if np.any(err_descending < 0):
raise ValueError("SSR inflection point method requires nonnegative losses")
if np.any(err_descending == 0):
return np.argmin(err_descending)
err_ratio = err_descending[:-1] / err_descending[1:]
return np.argmax(err_ratio) + 1
42 changes: 42 additions & 0 deletions test/test_optimizers/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from pysindy.optimizers import STLSQ
from pysindy.optimizers import TrappingSR3
from pysindy.optimizers import WrappedOptimizer
from pysindy.optimizers.ssr import _ind_inflection
from pysindy.optimizers.stlsq import _remove_and_decrement
from pysindy.utils import supports_multiple_targets
from pysindy.utils.odes import enzyme
Expand Down Expand Up @@ -78,6 +79,8 @@ def _align_optimizer_and_1dfeatures(
if isinstance(opt, TrappingSR3):
opt = TrappingSR3(_n_tgts=1, _include_bias=False)
features = np.hstack([features, features])
elif isinstance(opt, SSR):
features = np.hstack([features, features])
else:
features = features
return opt, features
Expand Down Expand Up @@ -1194,3 +1197,42 @@ def test_pickle(data_lorenz, opt_cls, opt_args):
new_opt = pickle.loads(pickle.dumps(opt))
result = new_opt.coef_
np.testing.assert_array_equal(result, expected)


@pytest.mark.parametrize("kappa", (None, 0.1), ids=["inflection", "L0"])
def test_ssr_history_selection(kappa):
rng = np.random.default_rng(1)
x = rng.normal(size=(30, 8))
expected = np.array([[1, 1, 1, 0, 0, 0, 0, 0]])
y = x @ expected.T

x += np.random.normal(size=(30, 8), scale=1e-2)
opt = SSR(kappa=kappa)
result = opt.fit(x, y).coef_

assert len(opt.history_) == len(opt.err_history_)
np.testing.assert_allclose(result, expected, atol=1e-2)
np.testing.assert_array_equal(result == 0, expected == 0)


@pytest.mark.parametrize(
["errs", "expected"],
(([3, 1, 0.9], 1), ([1, 0, 0], 1)),
ids=["basic", "zero-error"],
)
def test_ssr_inflection(errs, expected):
result = _ind_inflection(errs)
assert result == expected


@pytest.mark.parametrize(
["errs", "expected", "message"],
(
([1], ValueError, "single point"),
([-1, 1, 1], ValueError, ""),
),
ids=["length-1", "negative"],
)
def test_ssr_inflection_bad_args(errs, expected, message):
with pytest.raises(expected, match=message):
_ind_inflection(errs)
Loading