Skip to content

Commit

Permalink
fix: typo in EMD_matlab (#166)
Browse files Browse the repository at this point in the history
* fix: typo in EMD_matlab

* fix lint
  • Loading branch information
laszukdawid authored Sep 11, 2024
1 parent 9ca0bae commit 4fc4001
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 3 deletions.
14 changes: 12 additions & 2 deletions PyEMD/EMD_matlab.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,8 +435,18 @@ def emd(self, S, T=None, maxImf=None):
The decomposition is limited to maxImf imf. No limitation as default.
Returns IMF functions in dic format. IMF = {0:imf0, 1:imf1...}.
*Note*: First argument `self` should be an instance of EMD class.
It should be resolved in future versions.
For example:
```
emd = EMD()
emd.emd(emd, S, T, maxImf)
```
Input:
---------
self: Instance of EMD class.
S: Signal.
T: Positions of signal. If none passed numpy arange is created.
maxImf: IMF number to which decomposition should be performed.
Expand All @@ -457,7 +467,7 @@ def emd(self, S, T=None, maxImf=None):
maxImf = -1

# Make sure same types are dealt
S, T = unify_type(S, T)
S, T = unify_types(S, T)
self.DTYPE = S.dtype

Res = S.astype(self.DTYPE)
Expand All @@ -479,7 +489,7 @@ def emd(self, S, T=None, maxImf=None):

if S.shape != T.shape:
info = "Time array should be the same size as signal."
raise Exception(info)
raise ValueError(info)

# Create arrays
IMF = {} # Dic for imfs signals
Expand Down
2 changes: 1 addition & 1 deletion PyEMD/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging

__version__ = "1.6.3"
__version__ = "1.6.4"
logger = logging.getLogger("pyemd")

from PyEMD.CEEMDAN import CEEMDAN # noqa
Expand Down
49 changes: 49 additions & 0 deletions PyEMD/tests/test_emd_matlab.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import unittest

import numpy as np

from PyEMD.EMD_matlab import EMD


class EMDMatlabTest(unittest.TestCase):
@staticmethod
def test_default_call_EMD():
T = np.arange(0, 1, 0.01)
S = np.cos(2 * T * 2 * np.pi)
max_imf = 2

emd = EMD()
emd.emd(emd, S, T, max_imf)

def test_different_length_input(self):
T = np.arange(20)
S = np.random.random(len(T) + 7)

emd = EMD()
with self.assertRaises(ValueError):
emd.emd(emd, S, T)

def test_trend(self):
"""
Input is trend. Expeting no shifting process.
"""
emd = EMD()

T = np.arange(0, 1, 0.01)
S = np.cos(2 * T * 2 * np.pi)

# Input - linear function f(t) = 2*t
output = emd.emd(emd, S, T)
self.assertEqual(len(output), 4, "Expecting 4 outputs - IMF, EXT, ITER, imfNo")

IMF, EXT, ITER, imfNo = output
self.assertEqual(len(IMF), 2, "Expecting single IMF + residue")
self.assertEqual(len(IMF[0]), len(S), "Expecting single IMF")
self.assertTrue(np.allclose(S, IMF[0]))
self.assertLessEqual(ITER[0], 5, "Expecting 5 iterations at most")
self.assertEqual(imfNo, 2, "Expecting 1 IMF")
self.assertEqual(EXT[0], 3, "Expecting single EXT")


if __name__ == "__main__":
unittest.main()

0 comments on commit 4fc4001

Please sign in to comment.