Skip to content

Commit

Permalink
Add _custom_model class attribute, exception in fit and p0 and bounds…
Browse files Browse the repository at this point in the history
… as parameters
  • Loading branch information
rhugonnet committed Aug 11, 2023
1 parent 01e13fe commit 7232c43
Showing 1 changed file with 30 additions and 11 deletions.
41 changes: 30 additions & 11 deletions skgstat/Variogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Variogram class
"""
import copy
import inspect
import warnings
from typing import Iterable, Callable, Union, Tuple

Expand Down Expand Up @@ -335,6 +336,7 @@ def __init__(self,

# model can be a function or a string
self._model = None
self._custom_model = False
self.set_model(model_name=model)

# specify if the lag should be given absolute or relative to the maxlag
Expand Down Expand Up @@ -973,6 +975,7 @@ def set_model(self, model_name):
) % model_name
)
else: # pragma: no cover
self._custom_model = True
self._model = model_name

def _build_harmonized_model(self):
Expand Down Expand Up @@ -1454,7 +1457,7 @@ def preprocessing(self, force=False):
self._calc_diff(force=force)
self._calc_groups(force=force)

def fit(self, force=False, method=None, sigma=None, **kwargs):
def fit(self, force=False, method=None, sigma=None, bounds=None, p0=None, **kwargs):
"""Fit the variogram
The fit function will fit the theoretical variogram function to the
Expand Down Expand Up @@ -1563,18 +1566,34 @@ def fit(self, force=False, method=None, sigma=None, **kwargs):
self.cof = [r, s, n]
return

# Switch the method
# wrap the model to include or exclude the nugget
if self.use_nugget:
def wrapped(*args):
return self._model(*args)
# For a supported model, wrap the function depending on nugget and get logical bounds
if not self._custom_model:
# Switch the method
# wrap the model to include or exclude the nugget
if self.use_nugget:
def wrapped(*args):
return self._model(*args)
else:
def wrapped(*args):
return self._model(*args, 0)

# get p0
if bounds is None:
bounds = (0, self.__get_fit_bounds(x, y))
if p0 is None:
p0 = np.asarray(bounds[1])
# Else, inspect the function for the number of arguments
else:
def wrapped(*args):
return self._model(*args, 0)
# The number of arguments of argspec minus one is what we initialized
argspec = inspect.getfullargspec(self._model)
nb_args = len(argspec) - 1
if bounds is None:
bounds = (0, [np.maximum(np.nanmax(x), np.nanmax(y))] * nb_args)
if p0 is None:
p0 = np.asarray(bounds[1])

# get p0
bounds = (0, self.__get_fit_bounds(x, y))
p0 = np.asarray(bounds[1])
def wrapped(*args):
return self._model(*args)

# Trust Region Reflective
if self.fit_method == 'trf':
Expand Down

0 comments on commit 7232c43

Please sign in to comment.