Skip to content

Commit

Permalink
fix: lints
Browse files Browse the repository at this point in the history
  • Loading branch information
laszukdawid committed Aug 24, 2024
1 parent 0cd05da commit 5e16252
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 14 deletions.
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ LINT_TARGET_DIRS := PyEMD doc example
init:
python -m venv .venv
.venv/bin/pip install -r requirements.txt
.venv/bin/pip install -e .[dev]
@echo "Run 'source .venv/bin/activate' to activate the virtual environment"

test:
Expand All @@ -17,6 +18,7 @@ doc:

format:
python -m black $(LINT_TARGET_DIRS)
python -m isort PyEMD

lint-check:
python -m isort --check PyEMD
Expand Down
10 changes: 5 additions & 5 deletions PyEMD/EMD.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from scipy.interpolate import interp1d

from PyEMD.splines import akima, cubic, cubic_hermite, cubic_spline_3pts, pchip
from PyEMD.utils import get_timeline, deduce_common_type
from PyEMD.utils import deduce_common_type, get_timeline

FindExtremaOutput = Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]

Expand Down Expand Up @@ -199,14 +199,14 @@ def prepare_points(
Position (1st row) and values (2nd row) of maxima.
"""
if self.extrema_detection == "parabol":
return self._prepare_points_parabol(T, S, max_pos, max_val, min_pos, min_val)
return self.prepare_points_parabol(T, S, max_pos, max_val, min_pos, min_val)
elif self.extrema_detection == "simple":
return self._prepare_points_simple(T, S, max_pos, max_val, min_pos, min_val)
return self.prepare_points_simple(T, S, max_pos, max_val, min_pos, min_val)
else:
msg = "Incorrect extrema detection type. Please try: 'simple' or 'parabol'."
raise ValueError(msg)

def _prepare_points_parabol(self, T, S, max_pos, max_val, min_pos, min_val) -> Tuple[np.ndarray, np.ndarray]:
def prepare_points_parabol(self, T, S, max_pos, max_val, min_pos, min_val) -> Tuple[np.ndarray, np.ndarray]:
"""
Performs mirroring on signal which extrema do not necessarily
belong on the position array.
Expand Down Expand Up @@ -324,7 +324,7 @@ def _prepare_points_parabol(self, T, S, max_pos, max_val, min_pos, min_val) -> T

return max_extrema, min_extrema

def _prepare_points_simple(
def prepare_points_simple(
self,
T: np.ndarray,
S: np.ndarray,
Expand Down
3 changes: 2 additions & 1 deletion PyEMD/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from PyEMD.utils import get_timeline, deduce_common_type
from PyEMD.utils import deduce_common_type, get_timeline


class MyTestCase(unittest.TestCase):
Expand Down Expand Up @@ -37,5 +37,6 @@ def test_deduce_common_types(self):
self.assertEqual(deduce_common_type(np.int32, np.int32), np.int32)
self.assertEqual(deduce_common_type(np.float32, np.float64), np.float64)


if __name__ == "__main__":
unittest.main()
10 changes: 5 additions & 5 deletions PyEMD/tests/test_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ def test_instantiation2(self):
emd.emd(S, t)
imfs, res = emd.get_imfs_and_residue()
vis = Visualisation(emd)
self.assertTrue(np.alltrue(vis.imfs == imfs))
self.assertTrue(np.alltrue(vis.residue == res))
self.assertTrue(np.all(vis.imfs == imfs))
self.assertTrue(np.all(vis.residue == res))

def test_check_imfs(self):
vis = Visualisation()
Expand All @@ -40,7 +40,7 @@ def test_check_imfs3(self):

out_imfs, out_res = vis._check_imfs(imfs, None, False)

self.assertTrue(np.alltrue(imfs == out_imfs))
self.assertTrue(np.all(imfs == out_imfs))
self.assertIsNone(out_res)

def test_check_imfs4(self):
Expand All @@ -57,8 +57,8 @@ def test_check_imfs5(self):
imfs, res = emd.get_imfs_and_residue()
vis = Visualisation(emd)
imfs2, res2 = vis._check_imfs(imfs, res, False)
self.assertTrue(np.alltrue(imfs == imfs2))
self.assertTrue(np.alltrue(res == res2))
self.assertTrue(np.all(imfs == imfs2))
self.assertTrue(np.all(res == res2))

def test_plot_imfs(self):
vis = Visualisation()
Expand Down
6 changes: 3 additions & 3 deletions PyEMD/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Optional
from functools import cache
from typing import Optional

import numpy as np

Expand Down Expand Up @@ -52,13 +52,13 @@ def smallest_inclusive_dtype(ref_dtype: np.dtype, ref_value) -> np.dtype:

raise ValueError("Unsupported dtype '{}'. Only intX and floatX are supported.".format(ref_dtype))


@cache
def deduce_common_type(xtype: np.dtype, ytype: np.dtype) -> np.dtype:
if xtype == ytype:
return xtype
if np.version.version[0] == '1':
if np.version.version[0] == "1":
dtype = np.find_common_type([xtype, ytype], [])
else:
dtype = np.promote_types(xtype, ytype)
return dtype

0 comments on commit 5e16252

Please sign in to comment.