Skip to content

Commit

Permalink
fix(foreground.poisson): fix cumulative_trapezoid call (#60)
Browse files Browse the repository at this point in the history
* fix(foreground.poisson): fix cumulative_trapezoid call

The first parameter is y (the values to intregrate), the
second, optional, parameter is x, the axis of integration.

Also use the `initial` parameter to prepend a zero the result,
removing the need to pre-initialise the output array.

* tests
  • Loading branch information
ketiltrout authored Aug 15, 2024
1 parent f0e5334 commit 488cf0b
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 5 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,7 @@ nosetests.xml
.mr.developer.cfg
.project
.pydevproject

#Editor rubble
*~
.*.sw*
3 changes: 1 addition & 2 deletions cora/foreground/poisson.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,7 @@ def inhomogeneous_process_approx(t, rate):
ts = np.linspace(0.0, t, 10000)
rs = rate(ts)

cumr = np.zeros_like(ts)
cumr[1:] = cumulative_trapezoid(ts, rs)
cumr = cumulative_trapezoid(rs, ts, initial=0)
cumr /= cumr[-1]

# Interpolate to generate the inverse CDF and use this to generate
Expand Down
6 changes: 3 additions & 3 deletions tests/test_maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,6 @@ def test_pointsource():
pstd = pol.std(axis=-1)
assert (pstd[:, 0] > 3.0).all()
assert (pstd[:, 0] < 15.0).all()
assert (pol.std(axis=-1)[:, 1:3] > 0.01).all()
assert (pol.std(axis=-1)[:, 1:3] < 0.05).all()
assert (pol.std(axis=-1)[:, 3] == 0.0).all()
assert (pstd[:, 1:3] > 0.005).all()
assert (pstd[:, 1:3] < 0.015).all()
assert (pstd[:, 3] == 0.0).all()
24 changes: 24 additions & 0 deletions tests/test_poisson.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""Test caput.foreground.poisson"""

import numpy as np

from cora.foreground.poisson import inhomogeneous_process_approx


def test_inhomogeneous_process_approx():
"""Test inhomogeneous_process_approx"""

# the most boringest rate function
rate = lambda s: 1000 * (5 - s)

result = inhomogeneous_process_approx(5, rate)

# Mean should be approximately 1.666
mean = result.mean()
assert mean > 1.6
assert mean < 1.75

# stdev should be approximately 1.2
stdev = result.std()
assert stdev > 1.1
assert stdev < 1.3

0 comments on commit 488cf0b

Please sign in to comment.