Skip to content

Commit

Permalink
Normal random variable should cast arguments to floating point dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
marvinpfoertner committed Aug 24, 2020
1 parent d6f1911 commit 0c9ffb1
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 10 deletions.
25 changes: 25 additions & 0 deletions src/probnum/core/random_variables/_normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,31 @@ def __init__(
if np.isscalar(cov):
cov = _utils.as_numpy_scalar(cov)

# Data type normalization
is_mean_floating = mean.dtype is not None and np.issubdtype(
mean.dtype, np.floating
)
is_cov_floating = cov.dtype is not None and np.issubdtype(
cov.dtype, np.floating
)

if is_mean_floating and is_cov_floating:
dtype = np.promote_types(mean.dtype, cov.dtype)
elif is_mean_floating:
dtype = mean.dtype
elif is_cov_floating:
dtype = cov.dtype
else:
dtype = np.float_

# TODO: Implement casting for linear operators
if not isinstance(mean, linops.LinearOperator):
mean = mean.astype(dtype, order="C", casting="safe", subok=True, copy=False)

# TODO: Implement casting for linear operators
if not isinstance(cov, linops.LinearOperator):
cov = cov.astype(dtype, order="C", casting="safe", subok=True, copy=False)

# Shape checking
if len(mean.shape) not in [0, 1, 2]:
raise ValueError(
Expand Down
20 changes: 10 additions & 10 deletions src/probnum/core/random_variables/_random_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,18 +75,18 @@ def __init__(
parameters: Optional[Dict[str, Any]] = None,
sample: Optional[Callable[[ShapeArgType], _ValueType]] = None,
in_support: Optional[Callable[[_ValueType], bool]] = None,
pdf: Optional[Callable[[_ValueType], np.float64]] = None,
logpdf: Optional[Callable[[_ValueType], np.float64]] = None,
cdf: Optional[Callable[[_ValueType], np.float64]] = None,
logcdf: Optional[Callable[[_ValueType], np.float64]] = None,
pdf: Optional[Callable[[_ValueType], np.float_]] = None,
logpdf: Optional[Callable[[_ValueType], np.float_]] = None,
cdf: Optional[Callable[[_ValueType], np.float_]] = None,
logcdf: Optional[Callable[[_ValueType], np.float_]] = None,
quantile: Optional[Callable[[FloatArgType], _ValueType]] = None,
mode: Optional[Callable[[], _ValueType]] = None,
median: Optional[Callable[[], _ValueType]] = None,
mean: Optional[Callable[[], _ValueType]] = None,
cov: Optional[Callable[[], _ValueType]] = None,
var: Optional[Callable[[], _ValueType]] = None,
std: Optional[Callable[[], _ValueType]] = None,
entropy: Optional[Callable[[], np.float64]] = None,
entropy: Optional[Callable[[], np.float_]] = None,
):
"""Create a new random variable."""
self._shape = RandomVariable._check_shape(shape)
Expand Down Expand Up @@ -704,18 +704,18 @@ def __init__(
parameters: Optional[Dict[str, Any]] = None,
sample: Optional[Callable[[ShapeArgType], _ValueType]] = None,
in_support: Optional[Callable[[_ValueType], bool]] = None,
pmf: Optional[Callable[[_ValueType], np.float64]] = None,
logpmf: Optional[Callable[[_ValueType], np.float64]] = None,
cdf: Optional[Callable[[_ValueType], np.float64]] = None,
logcdf: Optional[Callable[[_ValueType], np.float64]] = None,
pmf: Optional[Callable[[_ValueType], np.float_]] = None,
logpmf: Optional[Callable[[_ValueType], np.float_]] = None,
cdf: Optional[Callable[[_ValueType], np.float_]] = None,
logcdf: Optional[Callable[[_ValueType], np.float_]] = None,
quantile: Optional[Callable[[FloatArgType], _ValueType]] = None,
mode: Optional[Callable[[], _ValueType]] = None,
median: Optional[Callable[[], _ValueType]] = None,
mean: Optional[Callable[[], _ValueType]] = None,
cov: Optional[Callable[[], _ValueType]] = None,
var: Optional[Callable[[], _ValueType]] = None,
std: Optional[Callable[[], _ValueType]] = None,
entropy: Optional[Callable[[], np.float64]] = None,
entropy: Optional[Callable[[], np.float_]] = None,
):
# Probability mass function
self.__pmf = pmf
Expand Down

0 comments on commit 0c9ffb1

Please sign in to comment.