Skip to content

Commit

Permalink
feat: support numpy 2+ (#162)
Browse files Browse the repository at this point in the history
* feat: support numpy 2+

* fix: lints

* fix: support for Python3.8
  • Loading branch information
laszukdawid authored Aug 24, 2024
1 parent 90886e1 commit ec9715e
Show file tree
Hide file tree
Showing 7 changed files with 46 additions and 24 deletions.
7 changes: 7 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
LINT_TARGET_DIRS := PyEMD doc example

init:
python -m venv .venv
.venv/bin/pip install -r requirements.txt
.venv/bin/pip install -e .[dev]
@echo "Run 'source .venv/bin/activate' to activate the virtual environment"

test:
python -m PyEMD.tests.test_all

Expand All @@ -12,6 +18,7 @@ doc:

format:
python -m black $(LINT_TARGET_DIRS)
python -m isort PyEMD

lint-check:
python -m isort --check PyEMD
Expand Down
12 changes: 6 additions & 6 deletions PyEMD/EMD.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from scipy.interpolate import interp1d

from PyEMD.splines import akima, cubic, cubic_hermite, cubic_spline_3pts, pchip
from PyEMD.utils import get_timeline
from PyEMD.utils import deduce_common_type, get_timeline

FindExtremaOutput = Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]

Expand Down Expand Up @@ -199,14 +199,14 @@ def prepare_points(
Position (1st row) and values (2nd row) of maxima.
"""
if self.extrema_detection == "parabol":
return self._prepare_points_parabol(T, S, max_pos, max_val, min_pos, min_val)
return self.prepare_points_parabol(T, S, max_pos, max_val, min_pos, min_val)
elif self.extrema_detection == "simple":
return self._prepare_points_simple(T, S, max_pos, max_val, min_pos, min_val)
return self.prepare_points_simple(T, S, max_pos, max_val, min_pos, min_val)
else:
msg = "Incorrect extrema detection type. Please try: 'simple' or 'parabol'."
raise ValueError(msg)

def _prepare_points_parabol(self, T, S, max_pos, max_val, min_pos, min_val) -> Tuple[np.ndarray, np.ndarray]:
def prepare_points_parabol(self, T, S, max_pos, max_val, min_pos, min_val) -> Tuple[np.ndarray, np.ndarray]:
"""
Performs mirroring on signal which extrema do not necessarily
belong on the position array.
Expand Down Expand Up @@ -324,7 +324,7 @@ def _prepare_points_parabol(self, T, S, max_pos, max_val, min_pos, min_val) -> T

return max_extrema, min_extrema

def _prepare_points_simple(
def prepare_points_simple(
self,
T: np.ndarray,
S: np.ndarray,
Expand Down Expand Up @@ -765,7 +765,7 @@ def check_imf(
@staticmethod
def _common_dtype(x: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""Casts inputs (x, y) into a common numpy DTYPE."""
dtype = np.find_common_type([x.dtype, y.dtype], [])
dtype = deduce_common_type(x.dtype, y.dtype)
if x.dtype != dtype:
x = x.astype(dtype)
if y.dtype != dtype:
Expand Down
8 changes: 4 additions & 4 deletions PyEMD/tests/test_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,13 @@ def test_whitenoise_check_rescaling_imf(self):

def test_whitenoise_check_nan_values(self):
"""whitenoise check with nan in IMF."""
S = np.array([np.full(100, np.NaN) for i in range(5, 0, -1)])
S = np.array([np.full(100, np.nan) for i in range(5, 0, -1)])
res = whitenoise_check(S)
self.assertEqual(res, None, "Input NaN returns None")
self.assertEqual(res, None, "Input nan returns None")

def test_invalid_alpha(self):
"""Test if invalid alpha return AssertionError."""
S = np.array([np.full(100, np.NaN) for i in range(5, 0, -1)])
S = np.array([np.full(100, np.nan) for i in range(5, 0, -1)])
self.assertRaises(AssertionError, whitenoise_check, S, alpha=1)
self.assertRaises(AssertionError, whitenoise_check, S, alpha=0)
self.assertRaises(AssertionError, whitenoise_check, S, alpha=-10)
Expand All @@ -99,7 +99,7 @@ def test_invalid_test_name(self):

def test_invalid_input_type(self):
"""Test if invalid input type return AssertionError."""
S = [np.full(100, np.NaN) for i in range(5, 0, -1)]
S = [np.full(100, np.nan) for i in range(5, 0, -1)]
self.assertRaises(AssertionError, whitenoise_check, S)
self.assertRaises(AssertionError, whitenoise_check, 1)
self.assertRaises(AssertionError, whitenoise_check, 1.2)
Expand Down
15 changes: 6 additions & 9 deletions PyEMD/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from PyEMD.utils import get_timeline
from PyEMD.utils import deduce_common_type, get_timeline


class MyTestCase(unittest.TestCase):
Expand Down Expand Up @@ -31,14 +31,11 @@ def test_get_timeline_does_not_overflow_int16(self):
self.assertEqual(T[-1], len(S) - 1, "Range is kept")
self.assertEqual(T.dtype, np.uint16, "UInt16 is the min type that matches requirements")

def test_get_timeline_does_not_overflow_float16(self):
S = np.random.random(int(np.finfo(np.float16).max) + 5).astype(dtype=np.float16)
T = get_timeline(len(S), dtype=S.dtype)

self.assertGreater(len(S), np.finfo(S.dtype).max, "Length of the signal is greater than its type max value")
self.assertEqual(len(T), len(S), "Lengths must be equal")
self.assertEqual(T[-1], len(S) - 1, "Range is kept")
self.assertEqual(T.dtype, np.float32, "Float32 is the min type that matches requirements")
def test_deduce_common_types(self):
self.assertEqual(deduce_common_type(np.int16, np.int32), np.int32)
self.assertEqual(deduce_common_type(np.int32, np.int16), np.int32)
self.assertEqual(deduce_common_type(np.int32, np.int32), np.int32)
self.assertEqual(deduce_common_type(np.float32, np.float64), np.float64)


if __name__ == "__main__":
Expand Down
10 changes: 5 additions & 5 deletions PyEMD/tests/test_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ def test_instantiation2(self):
emd.emd(S, t)
imfs, res = emd.get_imfs_and_residue()
vis = Visualisation(emd)
self.assertTrue(np.alltrue(vis.imfs == imfs))
self.assertTrue(np.alltrue(vis.residue == res))
self.assertTrue(np.all(vis.imfs == imfs))
self.assertTrue(np.all(vis.residue == res))

def test_check_imfs(self):
vis = Visualisation()
Expand All @@ -40,7 +40,7 @@ def test_check_imfs3(self):

out_imfs, out_res = vis._check_imfs(imfs, None, False)

self.assertTrue(np.alltrue(imfs == out_imfs))
self.assertTrue(np.all(imfs == out_imfs))
self.assertIsNone(out_res)

def test_check_imfs4(self):
Expand All @@ -57,8 +57,8 @@ def test_check_imfs5(self):
imfs, res = emd.get_imfs_and_residue()
vis = Visualisation(emd)
imfs2, res2 = vis._check_imfs(imfs, res, False)
self.assertTrue(np.alltrue(imfs == imfs2))
self.assertTrue(np.alltrue(res == res2))
self.assertTrue(np.all(imfs == imfs2))
self.assertTrue(np.all(res == res2))

def test_plot_imfs(self):
vis = Visualisation()
Expand Down
17 changes: 17 additions & 0 deletions PyEMD/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import sys
from typing import Optional

import numpy as np

if sys.version_info >= (3, 9):
from functools import cache
else:
from functools import lru_cache as cache


def get_timeline(range_max: int, dtype: Optional[np.dtype] = None) -> np.ndarray:
"""Returns timeline array for requirements.
Expand Down Expand Up @@ -50,3 +56,14 @@ def smallest_inclusive_dtype(ref_dtype: np.dtype, ref_value) -> np.dtype:
raise ValueError("Requested too large integer range. Exceeds max( float64 ) == '{}.".format(max_val))

raise ValueError("Unsupported dtype '{}'. Only intX and floatX are supported.".format(ref_dtype))


@cache
def deduce_common_type(xtype: np.dtype, ytype: np.dtype) -> np.dtype:
if xtype == ytype:
return xtype
if np.version.version[0] == "1":
dtype = np.find_common_type([xtype, ytype], [])
else:
dtype = np.promote_types(xtype, ytype)
return dtype
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
numpy>=1.12
scipy>=0.19
pathos>=0.2.1
tqdm>=4.64.0,<5.0

0 comments on commit ec9715e

Please sign in to comment.