Skip to content

Commit

Permalink
Check model coords for unknown shapes when building predictive models (
Browse files Browse the repository at this point in the history
  • Loading branch information
jessegrabowski authored Jan 12, 2025
1 parent dcc353c commit 9c7a6fb
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 1 deletion.
23 changes: 22 additions & 1 deletion pymc_extras/statespace/core/statespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,10 +983,31 @@ def _build_dummy_graph(self) -> None:
list[pm.Flat]
A list of pm.Flat variables representing all parameters estimated by the model.
"""

def infer_variable_shape(name):
shape = self._name_to_variable[name].type.shape
if not any(dim is None for dim in shape):
return shape

dim_names = self._fit_dims.get(name, None)
if dim_names is None:
raise ValueError(
f"Could not infer shape for {name}, because it was not given coords during model"
f"fitting"
)

shape_from_coords = tuple([len(self._fit_coords[dim]) for dim in dim_names])
return tuple(
[
shape[i] if shape[i] is not None else shape_from_coords[i]
for i in range(len(shape))
]
)

for name in self.param_names:
pm.Flat(
name,
shape=self._name_to_variable[name].type.shape,
shape=infer_variable_shape(name),
dims=self._fit_dims.get(name, None),
)

Expand Down
54 changes: 54 additions & 0 deletions tests/statespace/test_statespace.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections.abc import Sequence
from functools import partial

import numpy as np
Expand Down Expand Up @@ -349,6 +350,59 @@ def test_sampling_methods(group, kind, ss_mod, idata, rng):
assert not np.any(np.isnan(test_idata[f"{group}_{output}"].values))


@pytest.mark.filterwarnings("ignore:Provided data contains missing values")
def test_sample_conditional_with_time_varying():
class TVCovariance(PyMCStateSpace):
def __init__(self):
super().__init__(k_states=1, k_endog=1, k_posdef=1)

def make_symbolic_graph(self) -> None:
self.ssm["transition", 0, 0] = 1.0

self.ssm["design", 0, 0] = 1.0

sigma_cov = self.make_and_register_variable("sigma_cov", (None,))
self.ssm["state_cov"] = sigma_cov[:, None, None] ** 2

@property
def param_names(self) -> list[str]:
return ["sigma_cov"]

@property
def coords(self) -> dict[str, Sequence[str]]:
return make_default_coords(self)

@property
def state_names(self) -> list[str]:
return ["level"]

@property
def observed_states(self) -> list[str]:
return ["level"]

@property
def shock_names(self) -> list[str]:
return ["level"]

ss_mod = TVCovariance()
empty_data = pd.DataFrame(
np.nan, index=pd.date_range("2020-01-01", periods=100, freq="D"), columns=["data"]
)

coords = ss_mod.coords
coords["time"] = empty_data.index
with pm.Model(coords=coords) as mod:
log_sigma_cov = pm.Normal("log_sigma_cov", mu=0, sigma=0.1, dims=["time"])
pm.Deterministic("sigma_cov", pm.math.exp(log_sigma_cov.cumsum()), dims=["time"])

ss_mod.build_statespace_graph(data=empty_data)

prior = pm.sample_prior_predictive(10)

ss_mod.sample_unconditional_prior(prior)
ss_mod.sample_conditional_prior(prior)


def _make_time_idx(mod, use_datetime_index=True):
if use_datetime_index:
mod._fit_coords["time"] = nile.index
Expand Down

0 comments on commit 9c7a6fb

Please sign in to comment.