diff --git a/skgstat/Variogram.py b/skgstat/Variogram.py index 1858c70..eb6c88b 100644 --- a/skgstat/Variogram.py +++ b/skgstat/Variogram.py @@ -2,6 +2,7 @@ Variogram class """ import copy +import inspect import warnings from typing import Iterable, Callable, Union, Tuple @@ -335,6 +336,7 @@ def __init__(self, # model can be a function or a string self._model = None + self._custom_model = False self.set_model(model_name=model) # specify if the lag should be given absolute or relative to the maxlag @@ -973,6 +975,7 @@ def set_model(self, model_name): ) % model_name ) else: # pragma: no cover + self._custom_model = True self._model = model_name def _build_harmonized_model(self): @@ -1454,7 +1457,7 @@ def preprocessing(self, force=False): self._calc_diff(force=force) self._calc_groups(force=force) - def fit(self, force=False, method=None, sigma=None, **kwargs): + def fit(self, force=False, method=None, sigma=None, bounds=None, p0=None, **kwargs): """Fit the variogram The fit function will fit the theoretical variogram function to the @@ -1563,18 +1566,34 @@ def fit(self, force=False, method=None, sigma=None, **kwargs): self.cof = [r, s, n] return - # Switch the method - # wrap the model to include or exclude the nugget - if self.use_nugget: - def wrapped(*args): - return self._model(*args) + # For a supported model, wrap the function depending on nugget and get logical bounds + if not self._custom_model: + # Switch the method + # wrap the model to include or exclude the nugget + if self.use_nugget: + def wrapped(*args): + return self._model(*args) + else: + def wrapped(*args): + return self._model(*args, 0) + + # get p0 + if bounds is None: + bounds = (0, self.__get_fit_bounds(x, y)) + if p0 is None: + p0 = np.asarray(bounds[1]) + # Else, inspect the function for the number of arguments else: - def wrapped(*args): - return self._model(*args, 0) + # The number of arguments of argspec minus one is what we initialized + argspec = inspect.getfullargspec(self._model) + nb_args = len(argspec) - 1 + if bounds is None: + bounds = (0, [np.maximum(np.nanmax(x), np.nanmax(y))] * nb_args) + if p0 is None: + p0 = np.asarray(bounds[1]) - # get p0 - bounds = (0, self.__get_fit_bounds(x, y)) - p0 = np.asarray(bounds[1]) + def wrapped(*args): + return self._model(*args) # Trust Region Reflective if self.fit_method == 'trf':