Skip to content

Commit

Permalink
test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
pravirkr committed May 26, 2024
1 parent 6558152 commit 9d6601f
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 6 deletions.
2 changes: 1 addition & 1 deletion kalman_detector/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def kalman_filter(
-----
Number of changes is sqrt(nchan)*sig_eta/mean(spec_std).
Frequency scale is 1/sig_eta**2.
For details, see Eq. 10--12 in Kumar, Zackay & Law (2023).
For details, see Eq. 10--12 in Kumar, Zackay & Law (2024).
"""
if v0 is None:
v0 = np.median(spec_std) ** 2
Expand Down
14 changes: 9 additions & 5 deletions kalman_detector/svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

logger = logging.getLogger(__name__)


def collect2(expr: sp.Expr, v1: sp.Symbol, p1: int, v2: sp.Symbol, p2: int) -> sp.Expr:
"""Collect the coefficient of v1**p1 and v2**p2 in expr.
Expand Down Expand Up @@ -482,8 +483,8 @@ def kalman_binary_compress(
spec: np.ndarray,
spec_std: np.ndarray,
sig_t: float,
e0: float,
v0: float,
e0: float = 0,
v0: float | None = None,
) -> State:
"""Kalman binary compression for a spectrum.
Expand All @@ -496,15 +497,18 @@ def kalman_binary_compress(
sig_t : float
Standard deviation of the tranition between states.
e0 : float
Expected value of the first hidden state A0.
v0 : float
Variance of the first hidden state A0.
Initial guess of the expected value of the first hidden state A0, by default 0.
v0 : float, optional
Initial guess of the variance of the first hidden state A0, by default None.
Returns
-------
State
Final state for the whole spectrum.
"""
if v0 is None:
v0 = np.median(spec_std) ** 2

var_d = spec_std**2
var_t = sig_t**2
states = [
Expand Down
16 changes: 16 additions & 0 deletions tests/test_kalman.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,28 @@
from numpy.polynomial import Polynomial

from kalman_detector import utils
from kalman_detector.core import kalman_filter
from kalman_detector.main import (
KalmanDetector,
KalmanDistribution,
secondary_spectrum_cumulative_chi2_score,
)


class TestKalmanFilter:
def test_kalman_filter(self) -> None:
nchans = 128
target = 5
rng = np.random.default_rng()
spec_std = rng.normal(1, 0.1, size=nchans)
spec = rng.normal(target, spec_std)
score = kalman_filter(spec, spec_std, 0.1)
mask = np.zeros(nchans, dtype=bool)
mask[rng.choice(nchans, 2, replace=False)] = True
score_masked = kalman_filter(spec, spec_std, 0.1, chan_mask=mask)
np.testing.assert_allclose(score, score_masked, rtol=1e-1)


class TestKalmanDetector:
def test_q_par_float(self) -> None:
q_par = 0.1
Expand Down Expand Up @@ -129,6 +144,7 @@ def test_str(self) -> None:
sigma_arr = np.arange(0.1, 1, 0.01)
kdist = KalmanDistribution(sigma_arr, 0.01, ntrials=1000)
assert str(kdist).startswith("KalmanDistribution")
assert repr(kdist) == str(kdist)


class TestSecondarySpectrym:
Expand Down
9 changes: 9 additions & 0 deletions tests/test_kalman_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ def test_consistency_v0(self, v0: float, nchans: int) -> None:
kalman1d = kalman_filter(spec, spec_std, sig_t, e0=e0, v0=v0)
kalman2d = kalman_binary_hypothesis(spec, spec_std, sig_t, e0=e0, v0=v0)
np.testing.assert_almost_equal(kalman1d, kalman2d, decimal=10)
kalman1d_py = kalman_filter.py_func(spec, spec_std, sig_t, e0=e0, v0=v0)
np.testing.assert_almost_equal(kalman1d, kalman1d_py, decimal=10)

@pytest.mark.parametrize("e0", [-10, -1, -0.1, -0.01, 0, 0.01, 0.1, 1, 10])
@pytest.mark.parametrize("nchans", [2, 128, 4096])
Expand All @@ -31,6 +33,9 @@ def test_consistency_e0(self, e0: float, nchans: int) -> None:
kalman1d = kalman_filter(spec, spec_std, sig_t, e0=e0, v0=v0)
kalman2d = kalman_binary_hypothesis(spec, spec_std, sig_t, e0=e0, v0=v0)
np.testing.assert_almost_equal(kalman2d, kalman1d, decimal=10)
# Test without v0
kalman1d_py = kalman_filter.py_func(spec, spec_std, sig_t, e0=e0)
np.testing.assert_almost_equal(kalman1d, kalman1d_py, decimal=10)

@pytest.mark.parametrize("sig_t", [0.01, 0.1, 1, 10, 100, 1000])
@pytest.mark.parametrize("nchans", [2, 128, 1024])
Expand All @@ -44,6 +49,8 @@ def test_consistency_eta(self, sig_t: float, nchans: int) -> None:
kalman1d = kalman_filter(spec, spec_std, sig_t, e0=e0, v0=v0)
kalman2d = kalman_binary_hypothesis(spec, spec_std, sig_t, e0=e0, v0=v0)
np.testing.assert_almost_equal(kalman2d, kalman1d, decimal=10)
kalman1d_py = kalman_filter.py_func(spec, spec_std, sig_t, e0=e0, v0=v0)
np.testing.assert_almost_equal(kalman1d, kalman1d_py, decimal=10)

@pytest.mark.parametrize("nchans", [2, 4, 16, 64, 256, 1024, 4096, 8192])
def test_consistency_nchans(self, nchans: int) -> None:
Expand All @@ -57,3 +64,5 @@ def test_consistency_nchans(self, nchans: int) -> None:
kalman1d = kalman_filter(spec, spec_std, sig_t, e0=e0, v0=v0)
kalman2d = kalman_binary_hypothesis(spec, spec_std, sig_t, e0=e0, v0=v0)
np.testing.assert_almost_equal(kalman2d, kalman1d, decimal=10)
kalman1d_py = kalman_filter.py_func(spec, spec_std, sig_t, e0=e0, v0=v0)
np.testing.assert_almost_equal(kalman1d, kalman1d_py, decimal=10)

0 comments on commit 9d6601f

Please sign in to comment.