diff --git a/Makefile b/Makefile index 89e225ba..abe1c739 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: help venv conda docker docstyle format style types black test lint check notebooks + .PHONY: help venv conda docker docstyle format style types black test lint check notebooks .DEFAULT_GOAL = help PYTHON = python @@ -62,8 +62,9 @@ types: black: # Format code in-place using black. black pymc4/ tests/ + notebooks: notebooks/* - jupyter nbconvert --config nbconfig.py --execute --ExecutePreprocessor.kernel_name="pymc4-dev" --ExecutePreprocessor.timeout=1200 + jupyter nbconvert --config nbconfig.py --execute --ExecutePreprocessor.kernel_name="pymc4-dev" --ExecutePreprocessor.timeout=1200 --to 'html' rm notebooks/*.html test: # Test code using pytest. diff --git a/pymc4/distributions/distribution.py b/pymc4/distributions/distribution.py index bfffd1e6..ee42250d 100755 --- a/pymc4/distributions/distribution.py +++ b/pymc4/distributions/distribution.py @@ -5,6 +5,8 @@ import tensorflow as tf from tensorflow_probability import distributions as tfd +from tensorflow_probability.python.internal import prefer_static + from pymc4.coroutine_model import Model, unpack from pymc4.distributions.batchstack import BatchStacker from pymc4.distributions import transforms @@ -78,6 +80,12 @@ def __init__( if event_stack is not None: self._distribution = tfd.Sample(self._distribution, sample_shape=self.event_stack) + if self.transform is not None and self.transform.event_ndims is None: + event_ndims = prefer_static.rank_from_shape( + self._distribution.event_shape_tensor, self._distribution.event_shape + ) + self.transform.event_ndims = event_ndims + @property def dtype(self): return self._distribution.dtype diff --git a/pymc4/distributions/multivariate.py b/pymc4/distributions/multivariate.py index d36b235f..7ede09ed 100755 --- a/pymc4/distributions/multivariate.py +++ b/pymc4/distributions/multivariate.py @@ -183,9 +183,11 @@ class MvNormal(ContinuousDistribution): Define a multivariate normal variable for a given covariance matrix. + >>> import numpy as np + >>> import pymc4 as pm >>> covariance_matrix = np.array([[1., 0.5], [0.5, 2]]) >>> mu = np.zeros(2) - >>> vals = pm.MvNormal('vals', loc=loc, covariance_matrix=covariance_matrix, shape=(5, 2)) + >>> vals = pm.MvNormal('vals', loc=mu, covariance_matrix=covariance_matrix, shape=(5, 2)) """ def __init__(self, name, loc, covariance_matrix, **kwargs): @@ -360,10 +362,12 @@ class MvNormalCholesky(ContinuousDistribution): Define a multivariate normal variable for a given cholesky factor of the full covariance matrix (scale_tril). + >>> import numpy as np + >>> import pymc4 as pm >>> covariance_matrix = np.array([[1., 0.5], [0.5, 2]]) >>> chol_factor = np.linalg.cholesky(covariance_matrix) >>> mu = np.zeros(2) - >>> vals = pm.MvNormalCholesky('vals', loc=loc, scale_tril=chol_factor) + >>> vals = pm.MvNormalCholesky('vals', loc=mu, scale_tril=chol_factor) """ def __init__(self, name, loc, scale_tril, **kwargs): diff --git a/pymc4/distributions/transforms.py b/pymc4/distributions/transforms.py index 2721a651..bce48d42 100755 --- a/pymc4/distributions/transforms.py +++ b/pymc4/distributions/transforms.py @@ -3,6 +3,7 @@ from tensorflow_probability import bijectors as tfb + __all__ = ["Log", "Sigmoid", "LowerBound", "UpperBound", "Interval"] @@ -12,9 +13,49 @@ class JacobianPreference(enum.Enum): class Transform: + """ + Baseclass to define a bijective transformation of a distribution. + + Parameters + ---------- + transform : tfp.bijectors.bijector + The bijector that is called for the transformation. + event_ndims : int + The number of event dimensions of the distribution that has to be transformed. + This is normally automatically set during the initialization of the + PyMC4 distribution + """ + name: Optional[str] = None jacobian_preference = JacobianPreference.Forward + def __init__(self, transform=None, event_ndims=None): + self._transform = transform + self.event_ndims = event_ndims + + @property + def _untransformed_event_ndims(self): + """ + The length of the event_shape of the untransformed distribution. If set to None, + it can lead to errors in case of non autobatched sampling. + """ + if self.event_ndims is None: + return self._min_event_ndims + else: + return self.event_ndims + + @property + def _transformed_event_ndims(self): + """ The length of the event_shape of the transformed distribution""" + if self.event_ndims is None: + return self._transform.inverse_event_ndims(self._min_event_ndims) + else: + return self._transform.inverse_event_ndims(self.event_ndims) + + @property + def _min_event_ndims(self): + return NotImplementedError + def forward(self, x): """ Forward of a bijector. @@ -93,12 +134,12 @@ def inverse_log_det_jacobian(self, z): class Invert(Transform): - def __init__(self, transform): + def __init__(self, transform, **kwargs): if transform.jacobian_preference == JacobianPreference.Forward: self.jacobian_preference = JacobianPreference.Backward else: self.jacobian_preference = JacobianPreference.Forward - self._transform = transform + super().__init__(transform, **kwargs) def forward(self, x): return self._transform.inverse(x) @@ -107,19 +148,27 @@ def inverse(self, z): return self._transform.forward(z) def forward_log_det_jacobian(self, x): - return self._transform.inverse_log_det_jacobian(x) + return self._transform.inverse_log_det_jacobian(x, self._untransformed_event_ndims) def inverse_log_det_jacobian(self, z): - return self._transform.forward_log_det_jacobian(z) + return self._transform.forward_log_det_jacobian(z, self._transformed_event_ndims) class BackwardTransform(Transform): - """Base class for Transforms with Jacobian Preference as Backward""" + """ + Base class for Transforms with Jacobian Preference as Backward. + Backward means that the transformed values are in the domain of the specified function + and the untransformed values in the codomain. + """ JacobianPreference = JacobianPreference.Backward - def __init__(self, transform): - self._transform = transform + def __init__(self, transform, **kwargs): + super().__init__(transform, **kwargs) + + @property + def _min_event_ndims(self): + return self._transform._inverse_min_event_ndims def forward(self, x): return self._transform.inverse(x) @@ -128,27 +177,27 @@ def inverse(self, z): return self._transform.forward(z) def forward_log_det_jacobian(self, x): - return self._transform.inverse_log_det_jacobian(x, self._transform.inverse_min_event_ndims) + return self._transform.inverse_log_det_jacobian(x, self._untransformed_event_ndims) def inverse_log_det_jacobian(self, z): - return self._transform.forward_log_det_jacobian(z, self._transform.forward_min_event_ndims) + return self._transform.forward_log_det_jacobian(z, self._transformed_event_ndims) class Log(BackwardTransform): name = "log" - def __init__(self): + def __init__(self, **kwargs): # NOTE: We actually need the inverse to match PyMC3, do we? transform = tfb.Exp() - super().__init__(transform) + super().__init__(transform, **kwargs) class Sigmoid(BackwardTransform): name = "sigmoid" - def __init__(self): + def __init__(self, **kwargs): transform = tfb.Sigmoid() - super().__init__(transform) + super().__init__(transform, **kwargs) class LowerBound(BackwardTransform): @@ -156,9 +205,9 @@ class LowerBound(BackwardTransform): name = "lowerbound" - def __init__(self, lower_limit): + def __init__(self, lower_limit, **kwargs): transform = tfb.Chain([tfb.Shift(lower_limit), tfb.Exp()]) - super().__init__(transform) + super().__init__(transform, **kwargs) class UpperBound(BackwardTransform): @@ -166,9 +215,9 @@ class UpperBound(BackwardTransform): name = "upperbound" - def __init__(self, upper_limit): + def __init__(self, upper_limit, **kwargs): transform = tfb.Chain([tfb.Shift(upper_limit), tfb.Scale(-1), tfb.Exp()]) - super().__init__(transform) + super().__init__(transform, **kwargs) class Interval(BackwardTransform): @@ -176,6 +225,6 @@ class Interval(BackwardTransform): name = "interval" - def __init__(self, lower_limit, upper_limit): + def __init__(self, lower_limit, upper_limit, **kwargs): transform = tfb.Sigmoid(low=lower_limit, high=upper_limit) - super().__init__(transform) + super().__init__(transform, **kwargs) diff --git a/pymc4/inference/sampling.py b/pymc4/inference/sampling.py index f3947c41..0d4a4b8f 100755 --- a/pymc4/inference/sampling.py +++ b/pymc4/inference/sampling.py @@ -113,26 +113,27 @@ def sample( Let's start with a simple model. We'll need some imports to experiment with it. >>> import pymc4 as pm >>> import numpy as np - This particular model has a latent variable `sd` + >>> # This particular model has a latent variable `sd` >>> @pm.model ... def nested_model(cond): ... sd = yield pm.HalfNormal("sd", 1.) ... norm = yield pm.Normal("n", cond, sd, observed=np.random.randn(10)) ... return norm - Now, we may want to perform sampling from this model. We already observed some variables and we + >>> # Now, we may want to perform sampling from this model. We already observed some variables and we now need to fix the condition. >>> conditioned = nested_model(cond=2.) - Passing ``cond=2.`` we condition our model for future evaluation. Now we go to sampling. + >>> # Passing ``cond=2.`` we condition our model for future evaluation. Now we go to sampling. Nothing special is required but passing the model to ``pm.sample``, the rest configuration is held by PyMC4. >>> trace = sample(conditioned) + Notes ----- Things that are considered to be under discussion are overriding observed variables. The API for that may look like >>> new_observed = {"nested_model/n": np.random.randn(10) + 1} >>> trace = sample(conditioned, observed=new_observed) - This will give a trace with new observed variables. This way is considered to be explicit. + >>> # This will give a trace with new observed variables. This way is considered to be explicit. """ # assign sampler is no sampler_type is passed`` sampler_assigned: str = auto_assign_sampler(model, sampler_type) diff --git a/requirements.txt b/requirements.txt index 1c3cbcdc..a846b4fe 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ arviz>=0.5.1 +xarray<=0.16.0 # because the version 0.16.1 throws an error when imported by arviz gast>=0.3.2 tf-nightly tfp-nightly