Skip to content

Commit

Permalink
Add tests on custom model definition and fitting
Browse files Browse the repository at this point in the history
  • Loading branch information
rhugonnet committed Aug 11, 2023
1 parent 7232c43 commit 5247cd4
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
15 changes: 15 additions & 0 deletions skgstat/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,21 @@ def adder(l, a):
for r, c in zip(res, adder([1, 4, 8], 4)):
self.assertEqual(r, c)

def test_sum_spherical(self):
@variogram
def sum_spherical(h, r1, c1, r2, c2, b1=0, b2=0):
return spherical(h, r1, c1, b1) + spherical(h, r2, c2, b2)

# Parameters for the two spherical models
params = [1, 0.3, 10, 0.7]

# Values at which we'll evaluate the function and its expected result
vals = [0, 1, 100]
res = [0, 0.3 + spherical(1, 10, 0.7), 1]

for r, c in zip(res, sum_spherical(vals, *params)):
self.assertEqual(r, c)


if __name__=='__main__':
unittest.main()
15 changes: 14 additions & 1 deletion skgstat/tests/test_variogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from skgstat import OrdinaryKriging
from skgstat import estimators
from skgstat import plotting
from skgstat.models import variogram, spherical, matern


class TestSpatiallyCorrelatedData(unittest.TestCase):
Expand Down Expand Up @@ -61,7 +62,7 @@ def test_sparse_maxlag_30(self):
self.assertAlmostEqual(x, y, places=3)


class TestVariogramInstatiation(unittest.TestCase):
class TestVariogramInstantiation(unittest.TestCase):
def setUp(self):
# set up default values, whenever c and v are not important
np.random.seed(42)
Expand Down Expand Up @@ -949,6 +950,18 @@ def test_implicit_nugget(self):

self.assertTrue(abs(V.parameters[-1] - 2.) < 1e-10)

def test_fit_custom_model(self):

# Define a custom variogram and run the fit
@variogram
def sum_spherical(h, r1, c1, r2, c2, b1, b2):
return spherical(h, r1, c1, b1) + spherical(h, r2, c2, b2)

V = Variogram(self.c, self.v, use_nugget=True, model=sum_spherical)

# Check that 6 parameters were found
assert len(V.cof) == 6


class TestVariogramQualityMeasures(unittest.TestCase):
def setUp(self):
Expand Down

0 comments on commit 5247cd4

Please sign in to comment.