Skip to content

Commit

Permalink
CLN: fix linting
Browse files Browse the repository at this point in the history
  • Loading branch information
himkwtn committed Aug 26, 2024
1 parent d971225 commit de58118
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 25 deletions.
5 changes: 3 additions & 2 deletions pysindy/optimizers/constrained_sr3.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,9 @@ class ConstrainedSR3(SR3):
thresholder : string, optional (default 'l0')
Regularization function to use. Currently implemented options
are 'l0' (l0 norm), 'l1' (l1 norm), 'l2' (l2 norm), 'weighted_l0' (weighted l0 norm),
'weighted_l1' (weighted l1 norm), and 'weighted_l2' (weighted l2 norm).
are 'l0' (l0 norm), 'l1' (l1 norm), 'l2' (l2 norm),
'weighted_l0' (weighted l0 norm), 'weighted_l1' (weighted l1 norm),
and 'weighted_l2' (weighted l2 norm).
max_iter : int, optional (default 30)
Maximum iterations of the optimization algorithm.
Expand Down
31 changes: 16 additions & 15 deletions pysindy/utils/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import warnings
from functools import wraps
from itertools import repeat
from typing import Callable
from typing import Sequence
Expand Down Expand Up @@ -156,18 +155,22 @@ def reorder_constraints(arr, n_features, output_order="feature"):

def validate_prox_and_reg_inputs(func, regularization):
def wrapper(x, regularization_weight):
if regularization[:8] == 'weighted':
if regularization[:8] == "weighted":
if not isinstance(regularization_weight, np.ndarray):
raise ValueError(
f"'regularization_weight' must be an array of shape {x.shape}.")
if regularization_weight.shape != x.shape:
f"'regularization_weight' must be an array of shape {x.shape}."
)
weight_shape = regularization_weight.shape
if weight_shape != x.shape:
raise ValueError(
f"Invalid shape for 'regularization_weight': {
regularization_weight.shape}. Must be the same shape as x: {x.shape}."
f"Invalid shape for 'regularization_weight': \
{weight_shape}. Must be the same shape as x: {x.shape}."
)
else:
if not isinstance(regularization_weight, (int, float)) \
and (isinstance(regularization_weight, np.ndarray) and regularization_weight.shape not in [(1, 1), (1,)]):
if not isinstance(regularization_weight, (int, float)) and (
isinstance(regularization_weight, np.ndarray)
and regularization_weight.shape not in [(1, 1), (1,)]
):
raise ValueError("'regularization_weight' must be a scalar")
return func(x, regularization_weight)

Expand Down Expand Up @@ -260,19 +263,15 @@ def regualization_weighted_l0(
):
return np.sum(regularization_weight[np.nonzero(x)])

def regularization_l1(
x: NDArray[np.float64], regularization_weight: np.float64
):
def regularization_l1(x: NDArray[np.float64], regularization_weight: np.float64):
return np.sum(regularization_weight * np.abs(x))

def regualization_weighted_l1(
x: NDArray[np.float64], regularization_weight: NDArray[np.float64]
):
return np.sum(regularization_weight * np.abs(x))

def regularization_l2(
x: NDArray[np.float64], regularization_weight: np.float64
):
def regularization_l2(x: NDArray[np.float64], regularization_weight: np.float64):
return np.sum(regularization_weight * x**2)

def regualization_weighted_l2(
Expand All @@ -289,7 +288,9 @@ def regualization_weighted_l2(
"weighted_l2": regualization_weighted_l2,
}
regularization = regularization.lower()
return validate_prox_and_reg_inputs(regularization_fn[regularization], regularization)
return validate_prox_and_reg_inputs(
regularization_fn[regularization], regularization
)


def capped_simplex_projection(trimming_array, trimming_fraction):
Expand Down
28 changes: 20 additions & 8 deletions test/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,33 +85,42 @@ def test_get_regularization(regularization, lam, expected):

@pytest.mark.parametrize("regularization", ["l0", "l1", "l2"])
@pytest.mark.parametrize("lam", [1, np.array([1]), np.array([[1]])])
def test_get_regularization_shape(regularization, lam):
def test_get_prox_and_regularization_shape(regularization, lam):
data = np.array([[-2, 5]]).T
reg = get_regularization(regularization)
result = reg(data, lam)
assert result != None
reg_result = reg(data, lam)
prox = get_prox(regularization)
prox_result = prox(data, lam)
assert reg_result is not None
assert prox_result is not None


@pytest.mark.parametrize(
"regularization", ["weighted_l0", "weighted_l1", "weighted_l2"]
)
@pytest.mark.parametrize("lam", [np.array([[1, 2]]).T])
def test_get_weighted_regularization_shape(regularization, lam):
def test_get_weighted_prox_and_regularization_shape(regularization, lam):
data = np.array([[-2, 5]]).T
reg = get_regularization(regularization)
result = reg(data, lam)
assert result != None
reg_result = reg(data, lam)
prox = get_prox(regularization)
prox_result = prox(data, lam)
assert reg_result is not None
assert prox_result is not None


@pytest.mark.parametrize("regularization", ["l0", "l1", "l2"])
@pytest.mark.parametrize(
"lam", [np.array([[1, 2]]), np.array([1, 2]), np.array([[1, 2]]).T]
)
def test_get_regularization_bad_shape(regularization, lam):
def test_get_prox_and_regularization_bad_shape(regularization, lam):
data = np.array([[-2, 5]]).T
reg = get_regularization(regularization)
with pytest.raises(ValueError):
reg(data, lam)
prox = get_prox(regularization)
with pytest.raises(ValueError):
prox(data, lam)


@pytest.mark.parametrize(
Expand All @@ -120,11 +129,14 @@ def test_get_regularization_bad_shape(regularization, lam):
@pytest.mark.parametrize(
"lam", [np.array([[1, 2]]), np.array([1, 2, 3]), np.array([[1, 2, 3]]).T, 1]
)
def test_get_weighted_regularization_bad_shape(regularization, lam):
def test_get_weighted_prox_and_regularization_bad_shape(regularization, lam):
data = np.array([[-2, 5]]).T
reg = get_regularization(regularization)
with pytest.raises(ValueError):
reg(data, lam)
prox = get_prox(regularization)
with pytest.raises(ValueError):
prox(data, lam)


@pytest.mark.parametrize(
Expand Down

0 comments on commit de58118

Please sign in to comment.