From 5247cd45aaf1d3a676c3b22e0597ce05e5451348 Mon Sep 17 00:00:00 2001 From: Romain Hugonnet Date: Thu, 10 Aug 2023 16:17:38 -0800 Subject: [PATCH] Add tests on custom model definition and fitting --- skgstat/tests/test_models.py | 15 +++++++++++++++ skgstat/tests/test_variogram.py | 15 ++++++++++++++- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/skgstat/tests/test_models.py b/skgstat/tests/test_models.py index 6cac237..4110514 100644 --- a/skgstat/tests/test_models.py +++ b/skgstat/tests/test_models.py @@ -186,6 +186,21 @@ def adder(l, a): for r, c in zip(res, adder([1, 4, 8], 4)): self.assertEqual(r, c) + def test_sum_spherical(self): + @variogram + def sum_spherical(h, r1, c1, r2, c2, b1=0, b2=0): + return spherical(h, r1, c1, b1) + spherical(h, r2, c2, b2) + + # Parameters for the two spherical models + params = [1, 0.3, 10, 0.7] + + # Values at which we'll evaluate the function and its expected result + vals = [0, 1, 100] + res = [0, 0.3 + spherical(1, 10, 0.7), 1] + + for r, c in zip(res, sum_spherical(vals, *params)): + self.assertEqual(r, c) + if __name__=='__main__': unittest.main() diff --git a/skgstat/tests/test_variogram.py b/skgstat/tests/test_variogram.py index 52ce572..a5fdaeb 100644 --- a/skgstat/tests/test_variogram.py +++ b/skgstat/tests/test_variogram.py @@ -19,6 +19,7 @@ from skgstat import OrdinaryKriging from skgstat import estimators from skgstat import plotting +from skgstat.models import variogram, spherical, matern class TestSpatiallyCorrelatedData(unittest.TestCase): @@ -61,7 +62,7 @@ def test_sparse_maxlag_30(self): self.assertAlmostEqual(x, y, places=3) -class TestVariogramInstatiation(unittest.TestCase): +class TestVariogramInstantiation(unittest.TestCase): def setUp(self): # set up default values, whenever c and v are not important np.random.seed(42) @@ -949,6 +950,18 @@ def test_implicit_nugget(self): self.assertTrue(abs(V.parameters[-1] - 2.) < 1e-10) + def test_fit_custom_model(self): + + # Define a custom variogram and run the fit + @variogram + def sum_spherical(h, r1, c1, r2, c2, b1, b2): + return spherical(h, r1, c1, b1) + spherical(h, r2, c2, b2) + + V = Variogram(self.c, self.v, use_nugget=True, model=sum_spherical) + + # Check that 6 parameters were found + assert len(V.cof) == 6 + class TestVariogramQualityMeasures(unittest.TestCase): def setUp(self):