diff --git a/test/test_pysindy.py b/test/test_pysindy.py index 9a4630166..17f1af3fe 100644 --- a/test/test_pysindy.py +++ b/test/test_pysindy.py @@ -22,6 +22,7 @@ from sklearn.model_selection import TimeSeriesSplit from sklearn.utils.validation import check_is_fitted +from pysindy import pysindy from pysindy import SINDy from pysindy.differentiation import SINDyDerivative from pysindy.differentiation import SmoothedFiniteDifference @@ -436,6 +437,14 @@ def test_score_discrete_time(data_discrete_time): assert model.score(x, x_dot=x) < 1 +def test_bad_multiple_trajectories(data_multiple_trajectories): + x, t = data_multiple_trajectories + with pytest.raises(TypeError): + pysindy._check_multiple_trajectories(x, x_dot=x[0], u=None) + with pytest.raises(ValueError): + pysindy._check_multiple_trajectories(x, x_dot=x[:-1], u=None) + + def test_fit_discrete_time_multiple_trajectories( data_discrete_time_multiple_trajectories, ):