Skip to content

Commit

Permalink
feat: support numpy 2+
Browse files Browse the repository at this point in the history
  • Loading branch information
laszukdawid committed Aug 24, 2024
1 parent 90886e1 commit 0cd05da
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 16 deletions.
5 changes: 5 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
LINT_TARGET_DIRS := PyEMD doc example

init:
python -m venv .venv
.venv/bin/pip install -r requirements.txt
@echo "Run 'source .venv/bin/activate' to activate the virtual environment"

test:
python -m PyEMD.tests.test_all

Expand Down
4 changes: 2 additions & 2 deletions PyEMD/EMD.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from scipy.interpolate import interp1d

from PyEMD.splines import akima, cubic, cubic_hermite, cubic_spline_3pts, pchip
from PyEMD.utils import get_timeline
from PyEMD.utils import get_timeline, deduce_common_type

FindExtremaOutput = Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]

Expand Down Expand Up @@ -765,7 +765,7 @@ def check_imf(
@staticmethod
def _common_dtype(x: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""Casts inputs (x, y) into a common numpy DTYPE."""
dtype = np.find_common_type([x.dtype, y.dtype], [])
dtype = deduce_common_type(x.dtype, y.dtype)
if x.dtype != dtype:
x = x.astype(dtype)
if y.dtype != dtype:
Expand Down
8 changes: 4 additions & 4 deletions PyEMD/tests/test_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,13 @@ def test_whitenoise_check_rescaling_imf(self):

def test_whitenoise_check_nan_values(self):
"""whitenoise check with nan in IMF."""
S = np.array([np.full(100, np.NaN) for i in range(5, 0, -1)])
S = np.array([np.full(100, np.nan) for i in range(5, 0, -1)])
res = whitenoise_check(S)
self.assertEqual(res, None, "Input NaN returns None")
self.assertEqual(res, None, "Input nan returns None")

def test_invalid_alpha(self):
"""Test if invalid alpha return AssertionError."""
S = np.array([np.full(100, np.NaN) for i in range(5, 0, -1)])
S = np.array([np.full(100, np.nan) for i in range(5, 0, -1)])
self.assertRaises(AssertionError, whitenoise_check, S, alpha=1)
self.assertRaises(AssertionError, whitenoise_check, S, alpha=0)
self.assertRaises(AssertionError, whitenoise_check, S, alpha=-10)
Expand All @@ -99,7 +99,7 @@ def test_invalid_test_name(self):

def test_invalid_input_type(self):
"""Test if invalid input type return AssertionError."""
S = [np.full(100, np.NaN) for i in range(5, 0, -1)]
S = [np.full(100, np.nan) for i in range(5, 0, -1)]
self.assertRaises(AssertionError, whitenoise_check, S)
self.assertRaises(AssertionError, whitenoise_check, 1)
self.assertRaises(AssertionError, whitenoise_check, 1.2)
Expand Down
16 changes: 6 additions & 10 deletions PyEMD/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from PyEMD.utils import get_timeline
from PyEMD.utils import get_timeline, deduce_common_type


class MyTestCase(unittest.TestCase):
Expand Down Expand Up @@ -31,15 +31,11 @@ def test_get_timeline_does_not_overflow_int16(self):
self.assertEqual(T[-1], len(S) - 1, "Range is kept")
self.assertEqual(T.dtype, np.uint16, "UInt16 is the min type that matches requirements")

def test_get_timeline_does_not_overflow_float16(self):
S = np.random.random(int(np.finfo(np.float16).max) + 5).astype(dtype=np.float16)
T = get_timeline(len(S), dtype=S.dtype)

self.assertGreater(len(S), np.finfo(S.dtype).max, "Length of the signal is greater than its type max value")
self.assertEqual(len(T), len(S), "Lengths must be equal")
self.assertEqual(T[-1], len(S) - 1, "Range is kept")
self.assertEqual(T.dtype, np.float32, "Float32 is the min type that matches requirements")

def test_deduce_common_types(self):
self.assertEqual(deduce_common_type(np.int16, np.int32), np.int32)
self.assertEqual(deduce_common_type(np.int32, np.int16), np.int32)
self.assertEqual(deduce_common_type(np.int32, np.int32), np.int32)
self.assertEqual(deduce_common_type(np.float32, np.float64), np.float64)

if __name__ == "__main__":
unittest.main()
12 changes: 12 additions & 0 deletions PyEMD/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Optional
from functools import cache

import numpy as np

Expand Down Expand Up @@ -50,3 +51,14 @@ def smallest_inclusive_dtype(ref_dtype: np.dtype, ref_value) -> np.dtype:
raise ValueError("Requested too large integer range. Exceeds max( float64 ) == '{}.".format(max_val))

raise ValueError("Unsupported dtype '{}'. Only intX and floatX are supported.".format(ref_dtype))

@cache
def deduce_common_type(xtype: np.dtype, ytype: np.dtype) -> np.dtype:
if xtype == ytype:
return xtype
if np.version.version[0] == '1':
dtype = np.find_common_type([xtype, ytype], [])
else:
dtype = np.promote_types(xtype, ytype)
return dtype

1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
numpy>=1.12
scipy>=0.19
pathos>=0.2.1
tqdm>=4.64.0,<5.0

0 comments on commit 0cd05da

Please sign in to comment.