Add subclasses for continuous/discrete random variables
marvinpfoertner committed Aug 24, 2020
1 parent b0c91f8 commit d6f1911
Showing 3 changed files with 104 additions and 19 deletions.
11 changes: 6 additions & 5 deletions src/probnum/core/random_variables/_dirac.py
@@ -10,7 +10,7 @@
_ValueType = TypeVar("ValueType")


class Dirac(_random_variable.RandomVariable[_ValueType]):
class Dirac(_random_variable.DiscreteRandomVariable[_ValueType]):
"""
The Dirac delta distribution.
@@ -53,12 +53,13 @@ def __init__(
self._support = support

super().__init__(
shape=support.shape,
dtype=support.dtype,
shape=self._support.shape,
dtype=self._support.dtype,
random_state=random_state,
parameters={"support": support},
parameters={"support": self._support},
sample=self._sample,
in_support=lambda x: np.all(x == support),
in_support=lambda x: np.all(x == self._support),
pmf=lambda x: 1.0 if np.all(x == self._support) else 0.0,
cdf=lambda x: 0.0 if np.any(x < self._support) else 1.0,
mode=lambda: self._support,
median=lambda: self._support,
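Usage sketch (not part of the commit): how the reworked Dirac could be exercised after this change, assuming Dirac is re-exported from probnum.core.random_variables:

    import numpy as np

    from probnum.core.random_variables import Dirac

    # Dirac random variable with all probability mass at the point (1.0, 2.0).
    rv = Dirac(support=np.array([1.0, 2.0]))

    rv.pmf(np.array([1.0, 2.0]))  # 1.0 -- x equals the support point
    rv.pmf(np.array([0.0, 2.0]))  # 0.0 -- x differs from the support point
    rv.cdf(np.array([5.0, 5.0]))  # 1.0 -- no coordinate lies below the support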
2 changes: 1 addition & 1 deletion src/probnum/core/random_variables/_normal.py
@@ -12,7 +12,7 @@
_ValueType = Union[np.floating, np.ndarray, linops.LinearOperator]


class Normal(_random_variable.RandomVariable[np.ndarray]):
class Normal(_random_variable.ContinuousRandomVariable[np.ndarray]):
"""
The normal distribution.
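The Normal change is only a base-class swap; a minimal sketch of its practical effect, assuming Normal accepts scalar mean and cov (as in the scipy conversion further down) and that ContinuousRandomVariable is importable from the private module shown in this diff:

    from probnum.core.random_variables import Normal
    from probnum.core.random_variables._random_variable import ContinuousRandomVariable

    rv = Normal(mean=0.0, cov=1.0)

    # After this commit, Normal instances participate in the new class hierarchy.
    assert isinstance(rv, ContinuousRandomVariable)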
110 changes: 97 additions & 13 deletions src/probnum/core/random_variables/_random_variable.py
@@ -695,6 +695,74 @@ def __rpow__(self, other) -> "RandomVariable":
return NotImplemented


class DiscreteRandomVariable(RandomVariable[_ValueType]):
def __init__(
self,
shape: ShapeArgType,
dtype: np.dtype,
random_state: Optional[RandomStateType] = None,
parameters: Optional[Dict[str, Any]] = None,
sample: Optional[Callable[[ShapeArgType], _ValueType]] = None,
in_support: Optional[Callable[[_ValueType], bool]] = None,
pmf: Optional[Callable[[_ValueType], np.float64]] = None,
logpmf: Optional[Callable[[_ValueType], np.float64]] = None,
cdf: Optional[Callable[[_ValueType], np.float64]] = None,
logcdf: Optional[Callable[[_ValueType], np.float64]] = None,
quantile: Optional[Callable[[FloatArgType], _ValueType]] = None,
mode: Optional[Callable[[], _ValueType]] = None,
median: Optional[Callable[[], _ValueType]] = None,
mean: Optional[Callable[[], _ValueType]] = None,
cov: Optional[Callable[[], _ValueType]] = None,
var: Optional[Callable[[], _ValueType]] = None,
std: Optional[Callable[[], _ValueType]] = None,
entropy: Optional[Callable[[], np.float64]] = None,
):
# Probability mass function
self.__pmf = pmf
self.__logpmf = logpmf

super().__init__(
shape=shape,
dtype=dtype,
random_state=random_state,
parameters=parameters,
sample=sample,
in_support=in_support,
pdf=pmf,
logpdf=logpmf,
cdf=cdf,
logcdf=logcdf,
quantile=quantile,
mode=mode,
median=median,
mean=mean,
cov=cov,
var=var,
std=std,
entropy=entropy,
)

def pmf(self, x):
if self.__pmf is not None:
return self.__pmf(x)
elif self.__logpmf is not None:
return np.exp(self.__logpmf(x))
else:
raise NotImplementedError

def logpmf(self, x):
if self.__logpmf is not None:
return self.__logpmf(x)
elif self.__pmf is not None:
return np.log(self.__pmf(x))
else:
raise NotImplementedError


class ContinuousRandomVariable(RandomVariable[_ValueType]):
pass


def asrandvar(obj) -> RandomVariable:
"""
Return ``obj`` as a :class:`RandomVariable`.
@@ -796,16 +864,8 @@ def _scipystats_to_rv(
return rvs.Normal(
mean=scipyrv.mean, cov=scipyrv.cov, random_state=scipyrv.random_state,
)
# Generic distributions
if (
hasattr(scipyrv, "dist") and isinstance(scipyrv.dist, scipy.stats.rv_discrete)
) or hasattr(scipyrv, "pmf"):
pdf = getattr(scipyrv, "pmf", None)
logpdf = getattr(scipyrv, "logpmf", None)
else:
pdf = getattr(scipyrv, "pdf", None)
logpdf = getattr(scipyrv, "logpdf", None)

# Generic distributions
def _wrap_np_scalar(fn):
if fn is None:
return None
@@ -820,17 +880,40 @@ def _wrapper(*args, **kwargs):

return _wrapper

if (
hasattr(scipyrv, "dist") and isinstance(scipyrv.dist, scipy.stats.rv_discrete)
) or hasattr(scipyrv, "pmf"):
rv_subclass = DiscreteRandomVariable
rv_subclass_kwargs = {
"pmf": _wrap_np_scalar(getattr(scipyrv, "pmf", None)),
"logpmf": _wrap_np_scalar(getattr(scipyrv, "logpmf", None)),
}
else:
rv_subclass = ContinuousRandomVariable
rv_subclass_kwargs = {
"pdf": _wrap_np_scalar(getattr(scipyrv, "pdf", None)),
"logpdf": _wrap_np_scalar(getattr(scipyrv, "logpdf", None)),
}

if isinstance(scipyrv, scipy.stats._distn_infrastructure.rv_frozen):

def in_support(x):
low, high = scipyrv.support()

return bool(low <= x <= high)

else:
in_support = None

# Infer shape and dtype
sample = _wrap_np_scalar(scipyrv.rvs)()

return RandomVariable(
return rv_subclass(
shape=sample.shape,
dtype=sample.dtype,
random_state=getattr(scipyrv, "random_state", None),
sample=_wrap_np_scalar(getattr(scipyrv, "rvs", None)),
in_support=None, # TODO for univariate
pdf=_wrap_np_scalar(pdf),
logpdf=_wrap_np_scalar(logpdf),
in_support=in_support,
cdf=_wrap_np_scalar(getattr(scipyrv, "cdf", None)),
logcdf=_wrap_np_scalar(getattr(scipyrv, "logcdf", None)),
quantile=_wrap_np_scalar(getattr(scipyrv, "ppf", None)),
Expand All @@ -841,4 +924,5 @@ def _wrapper(*args, **kwargs):
var=_wrap_np_scalar(getattr(scipyrv, "var", None)),
std=_wrap_np_scalar(getattr(scipyrv, "std", None)),
entropy=_wrap_np_scalar(getattr(scipyrv, "entropy", None)),
**rv_subclass_kwargs,
)
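Hypothetical round trip through the new dispatch in _scipystats_to_rv (the scipy.stats.bernoulli example and the asrandvar import location are assumptions, not part of the commit):

    import scipy.stats

    from probnum.core.random_variables import asrandvar

    frozen = scipy.stats.bernoulli(p=0.3)  # frozen discrete scipy distribution
    rv = asrandvar(frozen)

    # ``frozen`` has a ``pmf`` attribute, so the discrete branch is taken and a
    # DiscreteRandomVariable with wrapped ``pmf``/``logpmf`` is returned.
    rv.pmf(1)         # approx. 0.3
    rv.logpmf(0)      # approx. log(0.7)
    rv.in_support(2)  # False -- outside the [0, 1] support reported by scipy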
