Skip to content

Commit

Permalink
BUG: fix cvxpy regularization calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
himkwtn committed Aug 20, 2024
1 parent b35f24c commit 86fef20
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 23 deletions.
24 changes: 15 additions & 9 deletions pysindy/optimizers/constrained_sr3.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,20 +269,26 @@ def _update_full_coef_constraints(self, H, x_transpose_y, coef_sparse):
rhs = rhs.reshape(g.shape)
return inv1.dot(rhs)

@staticmethod
def _calculate_penalty(regularizer, lam, xi: cp.Variable) -> cp.Expression:
regularizer = regularizer.lower()
if regularizer == "l1":
return lam * cp.sum(cp.abs(xi))
elif regularizer == "weighted_l1":
return cp.sum(cp.multiply(np.ravel(lam), cp.abs(xi)))
elif regularizer == "l2":
return lam * cp.sum(xi**2)
elif regularizer == "weighted_l2":
return cp.sum(cp.multiply(np.ravel(lam), xi**2))

def _create_var_and_part_cost(
self, var_len: int, x_expanded: np.ndarray, y: np.ndarray
) -> Tuple[cp.Variable, cp.Expression]:
xi = cp.Variable(var_len)
cost = cp.sum_squares(x_expanded @ xi - y.flatten())
if self.thresholder.lower() == "l1":
cost = cost + self.threshold * cp.norm1(xi)
elif self.thresholder.lower() == "weighted_l1":
cost = cost + cp.norm1(np.ravel(self.thresholds) @ xi)
elif self.thresholder.lower() == "l2":
cost = cost + self.threshold * cp.norm2(xi) ** 2
elif self.thresholder.lower() == "weighted_l2":
cost = cost + cp.norm2(np.ravel(self.thresholds) @ xi) ** 2
return xi, cost
threshold = self.thresholds if self.thresholds is not None else self.threshold
penalty = self._calculate_penalty(self.thresholder, threshold, xi)
return xi, cost + penalty

def _update_coef_cvxpy(self, xi, cost, var_len, coef_prev, tol):
if self.use_constraints:
Expand Down
25 changes: 11 additions & 14 deletions test/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""
import pickle

import cvxpy as cp
import numpy as np
import pytest
import scipy.linalg
Expand Down Expand Up @@ -473,23 +474,19 @@ def test_constrained_sr3_quadratic_library(params):


@pytest.mark.parametrize(
"params",
["regularizer", "lam", "expected"],
[
dict(thresholder="l1", threshold=1, expected=2.5),
dict(thresholder="weighted_l1", thresholds=np.ones((4, 1)), expected=2.5),
dict(thresholder="l2", threshold=1, expected=1.5),
dict(thresholder="weighted_l2", thresholds=np.ones((4, 1)), expected=2.5),
("l1", np.array([[2]]), 20),
("weighted_l1", np.array([[3, 2, 0.5]]).T, 14.5),
("l2", np.array([[2]]), 76),
("weighted_l2", np.array([[3, 2, 0.5]]).T, 42.5),
],
ids=lambda d: d["thresholder"],
)
def test_stable_linear_sr3_cost_function(params):
expected = params.pop("expected")
opt = StableLinearSR3(**params)
x = np.eye(2)
y = np.ones(1)
xi, cost = opt._create_var_and_part_cost(x.flatten(), y, x, x)
xi.value = 0.5 * np.ones(4)
np.testing.assert_allclose(cost.value, expected)
def test_constrained_sr3_penalty_term(regularizer, lam, expected):
xi = cp.Variable(3)
penalty = ConstrainedSR3._calculate_penalty(regularizer, lam, xi)
xi.value = np.array([-2, 3, 5])
np.testing.assert_allclose(penalty.value, expected)


def test_stable_linear_sr3_linear_library():
Expand Down

0 comments on commit 86fef20

Please sign in to comment.