Skip to content

Commit

Permalink
Allow opting out of model nesting
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Jun 10, 2024
1 parent bc52c86 commit 9baf76b
Show file tree
Hide file tree
Showing 8 changed files with 53 additions and 41 deletions.
18 changes: 12 additions & 6 deletions pymc/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
Literal,
Optional,
TypeVar,
Union,
cast,
overload,
)
Expand Down Expand Up @@ -468,6 +469,11 @@ class Model(WithMemoization, metaclass=ContextMeta):
# Variable will belong to root and second
z = pm.Normal("z", mu=y) # Variable wil be named "root::second::z"
# Set None for standalone model
with pm.Model(name="third", model=None) as third:
# Variable will belong to third only
w = pm.Normal("w") # Variable wil be named "third::w"
Set `check_bounds` to False for models with only continuous variables and default transformers
PyMC will remove the bounds check from the model logp which can speed up sampling
Expand All @@ -488,13 +494,13 @@ def __enter__(self: Self) -> Self: ...

def __exit__(self, exc_type: None, exc_val: None, exc_tb: None) -> None: ...

def __new__(cls, *args, **kwargs):
def __new__(cls, *args, model: Union[UNSET, None, "Model"] = UNSET, **kwargs):
# resolves the parent instance
instance = super().__new__(cls)
if kwargs.get("model") is not None:
instance._parent = kwargs.get("model")
else:
if model is UNSET:
instance._parent = cls.get_context(error_if_none=False)
else:
instance._parent = model
return instance

@staticmethod
Expand All @@ -510,9 +516,9 @@ def __init__(
check_bounds=True,
*,
coords_mutable=None,
model=None,
model: Union[UNSET, None, "Model"] = UNSET,
):
del model # used in __new__
del model # used in __new__ to define the parent of this model
self.name = self._validate_name(name)
self.check_bounds = check_bounds

Expand Down
4 changes: 1 addition & 3 deletions pymc/model/fgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,9 +299,7 @@ def first_non_model_var(var):
else:
return var

model = Model()
if model.parent is not None:
raise RuntimeError("model_to_fgraph cannot be called inside a PyMC model context")
model = Model(model=None) # Do not inherit from any model in the context manager

_coords = getattr(fgraph, "_coords", {})
_dim_lengths = getattr(fgraph, "_dim_lengths", {})
Expand Down
2 changes: 1 addition & 1 deletion pymc/model/transform/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from pytensor import Variable
from pytensor.graph import ancestors

from pymc import Model
from pymc.model.core import Model
from pymc.model.fgraph import (
ModelObservedRV,
ModelVar,
Expand Down
2 changes: 1 addition & 1 deletion pymc/model/transform/conditioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
from pytensor.graph import ancestors
from pytensor.tensor import TensorVariable

from pymc import Model
from pymc.logprob.transforms import Transform
from pymc.logprob.utils import rvs_in_graph
from pymc.model.core import Model
from pymc.model.fgraph import (
ModelDeterministic,
ModelFreeRV,
Expand Down
2 changes: 1 addition & 1 deletion pymc/sampling/deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def compute_deterministics(
model = modelcontext(model)

if var_names is None:
deterministics = model.deterministics
deterministics = list(model.deterministics)
var_names = [det.name for det in deterministics]
else:
deterministics = [model[var_name] for var_name in var_names]
Expand Down
39 changes: 15 additions & 24 deletions pymc/stats/log_density.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@

__all__ = ("compute_log_likelihood", "compute_log_prior")

from pymc.model.transform.conditioning import remove_value_transforms


def compute_log_likelihood(
idata: InferenceData,
Expand Down Expand Up @@ -126,46 +128,35 @@ def compute_log_density(
if kind not in ("likelihood", "prior"):
raise ValueError("kind must be either 'likelihood' or 'prior'")

# We need to disable transforms, because the InferenceData only keeps the untransformed values
umodel = remove_value_transforms(model)

if kind == "likelihood":
target_rvs = model.observed_RVs
target_rvs = list(umodel.observed_RVs)
target_str = "observed_RVs"
else:
target_rvs = model.free_RVs
target_rvs = list(umodel.free_RVs)
target_str = "free_RVs"

if var_names is None:
vars = target_rvs
var_names = tuple(rv.name for rv in vars)
else:
vars = [model.named_vars[name] for name in var_names]
vars = [umodel.named_vars[name] for name in var_names]
if not set(vars).issubset(target_rvs):
raise ValueError(f"var_names must refer to {target_str} in the model. Got: {var_names}")

# We need to temporarily disable transforms, because the InferenceData only keeps the untransformed values
try:
original_rvs_to_values = model.rvs_to_values
original_rvs_to_transforms = model.rvs_to_transforms

model.rvs_to_values = {
rv: rv.clone() if rv not in model.observed_RVs else value
for rv, value in model.rvs_to_values.items()
}
model.rvs_to_transforms = {rv: None for rv in model.basic_RVs}

elemwise_logdens_fn = model.compile_fn(
inputs=model.value_vars,
outs=model.logp(vars=vars, sum=False),
on_unused_input="ignore",
)
finally:
model.rvs_to_values = original_rvs_to_values
model.rvs_to_transforms = original_rvs_to_transforms
elemwise_logdens_fn = umodel.compile_fn(
inputs=umodel.value_vars,
outs=umodel.logp(vars=vars, sum=False),
on_unused_input="ignore",
)

coords, dims = coords_and_dims_for_inferencedata(model)
coords, dims = coords_and_dims_for_inferencedata(umodel)

logdens_dataset = apply_function_over_dataset(
elemwise_logdens_fn,
posterior[[rv.name for rv in model.free_RVs]],
posterior[[rv.name for rv in umodel.free_RVs]],
output_var_names=var_names,
sample_dims=sample_dims,
dims=dims,
Expand Down
16 changes: 14 additions & 2 deletions tests/model/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,13 +143,20 @@ def test_docstring_example(self):
# Variable will belong to root and second
z = pm.Normal("z", mu=y) # Variable wil be named "root::second::z"

# Set None for standalone model
with pm.Model(name="third", model=None) as third:
# Variable will belong to third only
w = pm.Normal("w") # Variable wil be named "third::w"

assert x.name == "root::x"
assert y.name == "root::first::y"
assert z.name == "root::second::z"
assert w.name == "third::w"

assert set(root.basic_RVs) == {x, y, z}
assert set(first.basic_RVs) == {y}
assert set(second.basic_RVs) == {z}
assert set(third.basic_RVs) == {w}


class TestNested:
Expand Down Expand Up @@ -1106,11 +1113,16 @@ def test_model_parent_set_programmatically():
y = pm.Normal("y")

with model:
# Explict None opts out of model context
with pm.Model():
z_in = pm.Normal("z_in")

with pm.Model(model=None):
z = pm.Normal("z")
z_out = pm.Normal("z_out")

assert "y" in model.named_vars
assert "z" in model.named_vars
assert "z_in" not in model.named_vars
assert "z_out" not in model.named_vars


class TestModelContext:
Expand Down
11 changes: 8 additions & 3 deletions tests/model/test_fgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,10 +267,15 @@ def test_context_error():
with pm.Model() as m:
x = pm.Normal("x")

fg = fgraph_from_model(m)
fg, _ = fgraph_from_model(m)

with pytest.raises(RuntimeError, match="cannot be called inside a PyMC model context"):
model_from_fgraph(fg)
new_m = model_from_fgraph(fg)
new_x = new_m["x"]

assert new_m.parent is None
assert x != new_x
assert m.named_vars == {"x": x}
assert new_m.named_vars == {"x": new_x}


def test_sub_model_error():
Expand Down

0 comments on commit 9baf76b

Please sign in to comment.