diff --git a/Makefile b/Makefile index e8c7849..e6c0a97 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,10 @@ LINT_TARGET_DIRS := PyEMD doc example +init: + python -m venv .venv + .venv/bin/pip install -r requirements.txt + @echo "Run 'source .venv/bin/activate' to activate the virtual environment" + test: python -m PyEMD.tests.test_all diff --git a/PyEMD/EMD.py b/PyEMD/EMD.py index a2137b6..d52305d 100644 --- a/PyEMD/EMD.py +++ b/PyEMD/EMD.py @@ -13,7 +13,7 @@ from scipy.interpolate import interp1d from PyEMD.splines import akima, cubic, cubic_hermite, cubic_spline_3pts, pchip -from PyEMD.utils import get_timeline +from PyEMD.utils import get_timeline, deduce_common_type FindExtremaOutput = Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray] @@ -765,7 +765,7 @@ def check_imf( @staticmethod def _common_dtype(x: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: """Casts inputs (x, y) into a common numpy DTYPE.""" - dtype = np.find_common_type([x.dtype, y.dtype], []) + dtype = deduce_common_type(x.dtype, y.dtype) if x.dtype != dtype: x = x.astype(dtype) if y.dtype != dtype: diff --git a/PyEMD/tests/test_checks.py b/PyEMD/tests/test_checks.py index 874f359..c4801f1 100644 --- a/PyEMD/tests/test_checks.py +++ b/PyEMD/tests/test_checks.py @@ -77,13 +77,13 @@ def test_whitenoise_check_rescaling_imf(self): def test_whitenoise_check_nan_values(self): """whitenoise check with nan in IMF.""" - S = np.array([np.full(100, np.NaN) for i in range(5, 0, -1)]) + S = np.array([np.full(100, np.nan) for i in range(5, 0, -1)]) res = whitenoise_check(S) - self.assertEqual(res, None, "Input NaN returns None") + self.assertEqual(res, None, "Input nan returns None") def test_invalid_alpha(self): """Test if invalid alpha return AssertionError.""" - S = np.array([np.full(100, np.NaN) for i in range(5, 0, -1)]) + S = np.array([np.full(100, np.nan) for i in range(5, 0, -1)]) self.assertRaises(AssertionError, whitenoise_check, S, alpha=1) self.assertRaises(AssertionError, whitenoise_check, S, alpha=0) self.assertRaises(AssertionError, whitenoise_check, S, alpha=-10) @@ -99,7 +99,7 @@ def test_invalid_test_name(self): def test_invalid_input_type(self): """Test if invalid input type return AssertionError.""" - S = [np.full(100, np.NaN) for i in range(5, 0, -1)] + S = [np.full(100, np.nan) for i in range(5, 0, -1)] self.assertRaises(AssertionError, whitenoise_check, S) self.assertRaises(AssertionError, whitenoise_check, 1) self.assertRaises(AssertionError, whitenoise_check, 1.2) diff --git a/PyEMD/tests/test_utils.py b/PyEMD/tests/test_utils.py index 63fbc6e..f90a56b 100644 --- a/PyEMD/tests/test_utils.py +++ b/PyEMD/tests/test_utils.py @@ -2,7 +2,7 @@ import numpy as np -from PyEMD.utils import get_timeline +from PyEMD.utils import get_timeline, deduce_common_type class MyTestCase(unittest.TestCase): @@ -31,15 +31,11 @@ def test_get_timeline_does_not_overflow_int16(self): self.assertEqual(T[-1], len(S) - 1, "Range is kept") self.assertEqual(T.dtype, np.uint16, "UInt16 is the min type that matches requirements") - def test_get_timeline_does_not_overflow_float16(self): - S = np.random.random(int(np.finfo(np.float16).max) + 5).astype(dtype=np.float16) - T = get_timeline(len(S), dtype=S.dtype) - - self.assertGreater(len(S), np.finfo(S.dtype).max, "Length of the signal is greater than its type max value") - self.assertEqual(len(T), len(S), "Lengths must be equal") - self.assertEqual(T[-1], len(S) - 1, "Range is kept") - self.assertEqual(T.dtype, np.float32, "Float32 is the min type that matches requirements") - + def test_deduce_common_types(self): + self.assertEqual(deduce_common_type(np.int16, np.int32), np.int32) + self.assertEqual(deduce_common_type(np.int32, np.int16), np.int32) + self.assertEqual(deduce_common_type(np.int32, np.int32), np.int32) + self.assertEqual(deduce_common_type(np.float32, np.float64), np.float64) if __name__ == "__main__": unittest.main() diff --git a/PyEMD/utils.py b/PyEMD/utils.py index 3330ad3..5dc9657 100644 --- a/PyEMD/utils.py +++ b/PyEMD/utils.py @@ -1,4 +1,5 @@ from typing import Optional +from functools import cache import numpy as np @@ -50,3 +51,14 @@ def smallest_inclusive_dtype(ref_dtype: np.dtype, ref_value) -> np.dtype: raise ValueError("Requested too large integer range. Exceeds max( float64 ) == '{}.".format(max_val)) raise ValueError("Unsupported dtype '{}'. Only intX and floatX are supported.".format(ref_dtype)) + +@cache +def deduce_common_type(xtype: np.dtype, ytype: np.dtype) -> np.dtype: + if xtype == ytype: + return xtype + if np.version.version[0] == '1': + dtype = np.find_common_type([xtype, ytype], []) + else: + dtype = np.promote_types(xtype, ytype) + return dtype + \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 094e97f..77c4133 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ numpy>=1.12 scipy>=0.19 pathos>=0.2.1 +tqdm>=4.64.0,<5.0