diff --git a/src/probnum/core/random_variables/_dirac.py b/src/probnum/core/random_variables/_dirac.py index 9d82d13329..1b54627a01 100644 --- a/src/probnum/core/random_variables/_dirac.py +++ b/src/probnum/core/random_variables/_dirac.py @@ -10,7 +10,7 @@ _ValueType = TypeVar("ValueType") -class Dirac(_random_variable.RandomVariable[_ValueType]): +class Dirac(_random_variable.DiscreteRandomVariable[_ValueType]): """ The Dirac delta distribution. @@ -53,12 +53,13 @@ def __init__( self._support = support super().__init__( - shape=support.shape, - dtype=support.dtype, + shape=self._support.shape, + dtype=self._support.dtype, random_state=random_state, - parameters={"support": support}, + parameters={"support": self._support}, sample=self._sample, - in_support=lambda x: np.all(x == support), + in_support=lambda x: np.all(x == self._support), + pmf=lambda x: 1.0 if np.all(x == self._support) else 0.0, + cdf=lambda x: 0.0 if np.any(x < self._support) else 1.0, mode=lambda: self._support, median=lambda: self._support, diff --git a/src/probnum/core/random_variables/_normal.py b/src/probnum/core/random_variables/_normal.py index 6a2110ae96..5a696696d9 100644 --- a/src/probnum/core/random_variables/_normal.py +++ b/src/probnum/core/random_variables/_normal.py @@ -12,7 +12,7 @@ _ValueType = Union[np.floating, np.ndarray, linops.LinearOperator] -class Normal(_random_variable.RandomVariable[np.ndarray]): +class Normal(_random_variable.ContinuousRandomVariable[np.ndarray]): """ The normal distribution. 
diff --git a/src/probnum/core/random_variables/_random_variable.py b/src/probnum/core/random_variables/_random_variable.py index 29b0d53c0c..46219467df 100644 --- a/src/probnum/core/random_variables/_random_variable.py +++ b/src/probnum/core/random_variables/_random_variable.py @@ -695,6 +695,74 @@ def __rpow__(self, other) -> "RandomVariable": return NotImplemented +class DiscreteRandomVariable(RandomVariable[_ValueType]): + def __init__( + self, + shape: ShapeArgType, + dtype: np.dtype, + random_state: Optional[RandomStateType] = None, + parameters: Optional[Dict[str, Any]] = None, + sample: Optional[Callable[[ShapeArgType], _ValueType]] = None, + in_support: Optional[Callable[[_ValueType], bool]] = None, + pmf: Optional[Callable[[_ValueType], np.float64]] = None, + logpmf: Optional[Callable[[_ValueType], np.float64]] = None, + cdf: Optional[Callable[[_ValueType], np.float64]] = None, + logcdf: Optional[Callable[[_ValueType], np.float64]] = None, + quantile: Optional[Callable[[FloatArgType], _ValueType]] = None, + mode: Optional[Callable[[], _ValueType]] = None, + median: Optional[Callable[[], _ValueType]] = None, + mean: Optional[Callable[[], _ValueType]] = None, + cov: Optional[Callable[[], _ValueType]] = None, + var: Optional[Callable[[], _ValueType]] = None, + std: Optional[Callable[[], _ValueType]] = None, + entropy: Optional[Callable[[], np.float64]] = None, + ): + # Probability mass function + self.__pmf = pmf + self.__logpmf = logpmf + + super().__init__( + shape=shape, + dtype=dtype, + random_state=random_state, + parameters=parameters, + sample=sample, + in_support=in_support, + pdf=pmf, + logpdf=logpmf, + cdf=cdf, + logcdf=logcdf, + quantile=quantile, + mode=mode, + median=median, + mean=mean, + cov=cov, + var=var, + std=std, + entropy=entropy, + ) + + def pmf(self, x): + if self.__pmf is not None: + return self.__pmf(x) + elif self.__logpmf is not None: + return np.exp(self.__logpmf(x)) + else: + raise NotImplementedError + + def logpmf(self, x): + if 
self.__logpmf is not None: + return self.__logpmf(x) + elif self.__pmf is not None: + return np.log(self.__pmf(x)) + else: + raise NotImplementedError + + +class ContinuousRandomVariable(RandomVariable[_ValueType]): + pass + + def asrandvar(obj) -> RandomVariable: """ Return ``obj`` as a :class:`RandomVariable`. @@ -796,16 +864,8 @@ def _scipystats_to_rv( return rvs.Normal( mean=scipyrv.mean, cov=scipyrv.cov, random_state=scipyrv.random_state, ) - # Generic distributions - if ( - hasattr(scipyrv, "dist") and isinstance(scipyrv.dist, scipy.stats.rv_discrete) - ) or hasattr(scipyrv, "pmf"): - pdf = getattr(scipyrv, "pmf", None) - logpdf = getattr(scipyrv, "logpmf", None) - else: - pdf = getattr(scipyrv, "pdf", None) - logpdf = getattr(scipyrv, "logpdf", None) + # Generic distributions def _wrap_np_scalar(fn): if fn is None: return None @@ -820,17 +880,40 @@ def _wrapper(*args, **kwargs): return _wrapper + if ( + hasattr(scipyrv, "dist") and isinstance(scipyrv.dist, scipy.stats.rv_discrete) + ) or hasattr(scipyrv, "pmf"): + rv_subclass = DiscreteRandomVariable + rv_subclass_kwargs = { + "pmf": _wrap_np_scalar(getattr(scipyrv, "pmf", None)), + "logpmf": _wrap_np_scalar(getattr(scipyrv, "logpmf", None)), + } + else: + rv_subclass = ContinuousRandomVariable + rv_subclass_kwargs = { + "pdf": _wrap_np_scalar(getattr(scipyrv, "pdf", None)), + "logpdf": _wrap_np_scalar(getattr(scipyrv, "logpdf", None)), + } + + if isinstance(scipyrv, scipy.stats._distn_infrastructure.rv_frozen): + + def in_support(x): + low, high = scipyrv.support() + + return bool(low <= x <= high) + + else: + in_support = None + # Infer shape and dtype sample = _wrap_np_scalar(scipyrv.rvs)() - return RandomVariable( + return rv_subclass( shape=sample.shape, dtype=sample.dtype, random_state=getattr(scipyrv, "random_state", None), sample=_wrap_np_scalar(getattr(scipyrv, "rvs", None)), - in_support=None, # TODO for univariate - pdf=_wrap_np_scalar(pdf), - logpdf=_wrap_np_scalar(logpdf), + 
in_support=in_support, cdf=_wrap_np_scalar(getattr(scipyrv, "cdf", None)), logcdf=_wrap_np_scalar(getattr(scipyrv, "logcdf", None)), quantile=_wrap_np_scalar(getattr(scipyrv, "ppf", None)), @@ -841,4 +924,5 @@ def _wrapper(*args, **kwargs): var=_wrap_np_scalar(getattr(scipyrv, "var", None)), std=_wrap_np_scalar(getattr(scipyrv, "std", None)), entropy=_wrap_np_scalar(getattr(scipyrv, "entropy", None)), + **rv_subclass_kwargs, )