Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add event_ndims to Transform class to fix shape errors in non-autobatching mode #321

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
.PHONY: help venv conda docker docstyle format style types black test lint check notebooks
.PHONY: help venv conda docker docstyle format style types black test lint check notebooks
.DEFAULT_GOAL = help

PYTHON = python
Expand Down Expand Up @@ -62,8 +62,9 @@ types:
black: # Format code in-place using black.
black pymc4/ tests/


notebooks: notebooks/*
jupyter nbconvert --config nbconfig.py --execute --ExecutePreprocessor.kernel_name="pymc4-dev" --ExecutePreprocessor.timeout=1200
jupyter nbconvert --config nbconfig.py --execute --ExecutePreprocessor.kernel_name="pymc4-dev" --ExecutePreprocessor.timeout=1200 --to 'html'
rm notebooks/*.html

test: # Test code using pytest.
Expand Down
8 changes: 8 additions & 0 deletions pymc4/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

import tensorflow as tf
from tensorflow_probability import distributions as tfd
from tensorflow_probability.python.internal import prefer_static

from pymc4.coroutine_model import Model, unpack
from pymc4.distributions.batchstack import BatchStacker
from pymc4.distributions import transforms
Expand Down Expand Up @@ -78,6 +80,12 @@ def __init__(
if event_stack is not None:
self._distribution = tfd.Sample(self._distribution, sample_shape=self.event_stack)

if self.transform is not None and self.transform.event_ndims is None:
event_ndims = prefer_static.rank_from_shape(
self._distribution.event_shape_tensor, self._distribution.event_shape
)
self.transform.event_ndims = event_ndims

@property
def dtype(self):
    """Return the ``dtype`` attribute of the wrapped ``self._distribution`` object."""
    return self._distribution.dtype
Expand Down
8 changes: 6 additions & 2 deletions pymc4/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,11 @@ class MvNormal(ContinuousDistribution):
Define a multivariate normal variable for a given covariance
matrix.

>>> import numpy as np
>>> import pymc4 as pm
>>> covariance_matrix = np.array([[1., 0.5], [0.5, 2]])
>>> mu = np.zeros(2)
>>> vals = pm.MvNormal('vals', loc=loc, covariance_matrix=covariance_matrix, shape=(5, 2))
>>> vals = pm.MvNormal('vals', loc=mu, covariance_matrix=covariance_matrix, shape=(5, 2))
"""

def __init__(self, name, loc, covariance_matrix, **kwargs):
Expand Down Expand Up @@ -360,10 +362,12 @@ class MvNormalCholesky(ContinuousDistribution):
Define a multivariate normal variable for a given cholesky
factor of the full covariance matrix (scale_tril).

>>> import numpy as np
>>> import pymc4 as pm
>>> covariance_matrix = np.array([[1., 0.5], [0.5, 2]])
>>> chol_factor = np.linalg.cholesky(covariance_matrix)
>>> mu = np.zeros(2)
>>> vals = pm.MvNormalCholesky('vals', loc=loc, scale_tril=chol_factor)
>>> vals = pm.MvNormalCholesky('vals', loc=mu, scale_tril=chol_factor)
"""

def __init__(self, name, loc, scale_tril, **kwargs):
Expand Down
87 changes: 68 additions & 19 deletions pymc4/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from tensorflow_probability import bijectors as tfb


__all__ = ["Log", "Sigmoid", "LowerBound", "UpperBound", "Interval"]


Expand All @@ -12,9 +13,49 @@ class JacobianPreference(enum.Enum):


class Transform:
"""
Baseclass to define a bijective transformation of a distribution.

Parameters
----------
transform : tfp.bijectors.Bijector
The bijector that is called for the transformation.
event_ndims : int, optional
The number of event dimensions of the distribution that is to be transformed.
This is normally set automatically during the initialization of the
PyMC4 distribution.
"""

name: Optional[str] = None
jacobian_preference = JacobianPreference.Forward

def __init__(self, transform=None, event_ndims=None):
    """
    Store the wrapped transformation and the optional number of event dimensions.

    Parameters
    ----------
    transform : optional
        Object kept as ``self._transform``; the other members call its
        methods (for example ``inverse_event_ndims``).
    event_ndims : optional
        Kept as ``self.event_ndims``. ``None`` means "not set yet"; the
        event-rank properties then fall back to ``self._min_event_ndims``.
    """
    self._transform = transform
    self.event_ndims = event_ndims

@property
def _untransformed_event_ndims(self):
    """
    Length of the event_shape of the untransformed distribution.

    Uses ``self.event_ndims`` when it has been set and otherwise falls back
    to ``self._min_event_ndims``. Relying on the fallback can lead to errors
    in case of non-autobatched sampling.
    """
    explicit_ndims = self.event_ndims
    return self._min_event_ndims if explicit_ndims is None else explicit_ndims

@property
def _transformed_event_ndims(self):
    """
    Length of the event_shape of the transformed distribution.

    The untransformed rank (``self.event_ndims`` when set, otherwise
    ``self._min_event_ndims``) is mapped through the wrapped
    transformation's ``inverse_event_ndims``.
    """
    source_ndims = self._min_event_ndims if self.event_ndims is None else self.event_ndims
    return self._transform.inverse_event_ndims(source_ndims)

@property
def _min_event_ndims(self):
    """
    Minimal number of event dimensions of the untransformed values.

    The base class has no wrapped-transformation convention to read this
    from, so subclasses must override the property.

    Raises
    ------
    NotImplementedError
        Always, in the base class. Raising (instead of returning the
        exception class) keeps a bogus "rank" from being passed on to the
        wrapped transformation when ``event_ndims`` is unset.
    """
    raise NotImplementedError(
        "Subclasses of Transform must define `_min_event_ndims` "
        "or be given an explicit `event_ndims`."
    )

def forward(self, x):
"""
Forward of a bijector.
Expand Down Expand Up @@ -93,12 +134,12 @@ def inverse_log_det_jacobian(self, z):


class Invert(Transform):
    """
    Wrap another transform and swap its two directions.

    ``forward`` calls the wrapped object's ``inverse`` and vice versa, and
    the Jacobian preference is flipped relative to the wrapped object.

    Parameters
    ----------
    transform
        Object exposing ``jacobian_preference``, ``forward``, ``inverse``,
        ``forward_log_det_jacobian`` and ``inverse_log_det_jacobian``.
    **kwargs
        Forwarded to ``Transform.__init__`` (for example ``event_ndims``).
    """

    def __init__(self, transform, **kwargs):
        # The preferred Jacobian direction is the opposite of the wrapped one.
        if transform.jacobian_preference == JacobianPreference.Forward:
            self.jacobian_preference = JacobianPreference.Backward
        else:
            self.jacobian_preference = JacobianPreference.Forward
        super().__init__(transform, **kwargs)

    def forward(self, x):
        """Apply the wrapped object's ``inverse`` to ``x``."""
        return self._transform.inverse(x)

    def inverse(self, z):
        """Apply the wrapped object's ``forward`` to ``z``."""
        return self._transform.forward(z)

    def forward_log_det_jacobian(self, x):
        """Return the wrapped object's inverse log-det-Jacobian at ``x``."""
        # NOTE(review): the wrapped object must expose `jacobian_preference`,
        # so it is presumably a pymc4 Transform, whose log-det methods appear
        # to take a single argument. Confirm that passing the event rank as a
        # second positional argument is accepted by every wrapped type.
        return self._transform.inverse_log_det_jacobian(x, self._untransformed_event_ndims)

    def inverse_log_det_jacobian(self, z):
        """Return the wrapped object's forward log-det-Jacobian at ``z``."""
        # NOTE(review): same argument-count caveat as in forward_log_det_jacobian.
        return self._transform.forward_log_det_jacobian(z, self._transformed_event_ndims)


class BackwardTransform(Transform):
    """
    Base class for Transforms with Jacobian Preference as Backward.

    Backward means that the transformed values are in the domain of the
    specified function and the untransformed values in the codomain.
    Accordingly ``forward`` applies the wrapped bijector's ``inverse`` and
    ``inverse`` applies its ``forward``.

    Parameters
    ----------
    transform
        The wrapped bijector.
    **kwargs
        Forwarded to ``Transform.__init__`` (for example ``event_ndims``).
    """

    # Must use the same attribute name as the base class. Assigning to
    # `JacobianPreference` would only shadow the enum inside the class body
    # and leave `jacobian_preference` at the inherited Forward value.
    jacobian_preference = JacobianPreference.Backward

    def __init__(self, transform, **kwargs):
        super().__init__(transform, **kwargs)

    @property
    def _min_event_ndims(self):
        """Minimal event rank of the untransformed values (the bijector's output side)."""
        return self._transform.inverse_min_event_ndims

    def forward(self, x):
        """Map an untransformed value ``x`` through the bijector's ``inverse``."""
        return self._transform.inverse(x)

    def inverse(self, z):
        """Map a transformed value ``z`` through the bijector's ``forward``."""
        return self._transform.forward(z)

    def forward_log_det_jacobian(self, x):
        """Return the bijector's inverse log-det-Jacobian at ``x``, using the untransformed event rank."""
        return self._transform.inverse_log_det_jacobian(x, self._untransformed_event_ndims)

    def inverse_log_det_jacobian(self, z):
        """Return the bijector's forward log-det-Jacobian at ``z``, using the transformed event rank."""
        return self._transform.forward_log_det_jacobian(z, self._transformed_event_ndims)


class Log(BackwardTransform):
    """
    Log transformation built on ``tfb.Exp``.

    Parameters
    ----------
    **kwargs
        Forwarded to ``BackwardTransform.__init__`` (for example ``event_ndims``).
    """

    name = "log"

    def __init__(self, **kwargs):
        # NOTE: the wrapped bijector is Exp, i.e. the inverse of log, because
        # BackwardTransform.forward applies the bijector's inverse. Whether
        # this is what is needed to match PyMC3 was an open question in the
        # original note.
        transform = tfb.Exp()
        super().__init__(transform, **kwargs)


class Sigmoid(BackwardTransform):
    """
    Transformation built on ``tfb.Sigmoid`` with its default bounds.

    Parameters
    ----------
    **kwargs
        Forwarded to ``BackwardTransform.__init__`` (for example ``event_ndims``).
    """

    name = "sigmoid"

    def __init__(self, **kwargs):
        transform = tfb.Sigmoid()
        super().__init__(transform, **kwargs)


class LowerBound(BackwardTransform):
    """
    Transformation to the interval (lower_limit, inf).

    Parameters
    ----------
    lower_limit
        Lower bound; passed to ``tfb.Shift``.
    **kwargs
        Forwarded to ``BackwardTransform.__init__`` (for example ``event_ndims``).
    """

    name = "lowerbound"

    def __init__(self, lower_limit, **kwargs):
        # Chain applies right-to-left: z -> exp(z) -> lower_limit + exp(z).
        transform = tfb.Chain([tfb.Shift(lower_limit), tfb.Exp()])
        super().__init__(transform, **kwargs)


class UpperBound(BackwardTransform):
    """
    Transformation to the interval (-inf, upper_limit).

    Parameters
    ----------
    upper_limit
        Upper bound; passed to ``tfb.Shift``.
    **kwargs
        Forwarded to ``BackwardTransform.__init__`` (for example ``event_ndims``).
    """

    name = "upperbound"

    def __init__(self, upper_limit, **kwargs):
        # Chain applies right-to-left: z -> exp(z) -> -exp(z) -> upper_limit - exp(z).
        transform = tfb.Chain([tfb.Shift(upper_limit), tfb.Scale(-1), tfb.Exp()])
        super().__init__(transform, **kwargs)


class Interval(BackwardTransform):
    """
    Transformation to the interval (lower_limit, upper_limit).

    Parameters
    ----------
    lower_limit
        Lower bound; passed to ``tfb.Sigmoid`` as ``low``.
    upper_limit
        Upper bound; passed to ``tfb.Sigmoid`` as ``high``.
    **kwargs
        Forwarded to ``BackwardTransform.__init__`` (for example ``event_ndims``).
    """

    name = "interval"

    def __init__(self, lower_limit, upper_limit, **kwargs):
        transform = tfb.Sigmoid(low=lower_limit, high=upper_limit)
        super().__init__(transform, **kwargs)
9 changes: 5 additions & 4 deletions pymc4/inference/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,26 +113,27 @@ def sample(
Let's start with a simple model. We'll need some imports to experiment with it.
>>> import pymc4 as pm
>>> import numpy as np
This particular model has a latent variable `sd`
>>> # This particular model has a latent variable `sd`
>>> @pm.model
... def nested_model(cond):
... sd = yield pm.HalfNormal("sd", 1.)
... norm = yield pm.Normal("n", cond, sd, observed=np.random.randn(10))
... return norm
Now, we may want to perform sampling from this model. We already observed some variables and we
>>> # Now, we may want to perform sampling from this model. We already observed some variables and we
>>> # now need to fix the condition.
>>> conditioned = nested_model(cond=2.)
Passing ``cond=2.`` we condition our model for future evaluation. Now we go to sampling.
>>> # Passing ``cond=2.`` we condition our model for future evaluation. Now we go to sampling.
>>> # Nothing special is required beyond passing the model to ``pm.sample``; the rest of the
>>> # configuration is handled by PyMC4.
>>> trace = sample(conditioned)

Notes
-----
Things that are considered to be under discussion are overriding observed variables. The API
for that may look like
>>> new_observed = {"nested_model/n": np.random.randn(10) + 1}
>>> trace = sample(conditioned, observed=new_observed)
This will give a trace with new observed variables. This way is considered to be explicit.
>>> # This will give a trace with new observed variables. This way is considered to be explicit.
"""
# assign a sampler if no sampler_type is passed
sampler_assigned: str = auto_assign_sampler(model, sampler_type)
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
arviz>=0.5.1
xarray<=0.16.0 # because the version 0.16.1 throws an error when imported by arviz
gast>=0.3.2
tf-nightly
tfp-nightly
Expand Down