From 5055262868473b61f487f42d82542f9d0fc0e273 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski <48652735+jessegrabowski@users.noreply.github.com> Date: Wed, 4 Dec 2024 13:37:52 +0100 Subject: [PATCH] Add `find_MAP` with close JAX integration and fix bug with Laplace fit (#385) * Add JAX-based `find_MAP` * add `better_optimize` to CI envs * Fix relative import * Remove `find_MAP` import from module-level `__init__.py` * Update docstring * Allow calling `find_MAP` inside model context without model argument * Required patched better_optimize * in-progress refactor * More refactor * Generalize code to use any pytensor backend * Reconcile the two laplace approximation functions * Use absolute import in doctest * Fix imports * Fix unrelated statespace test * - Rename argument `use_jax_gradients` -> `gradient_backend` - Rename function `laplace` -> `sample_laplace_posterior` * Fix typo introduced by rename refactor * use `mode=FAST_COMPILE` to get `unobserved_value_vars` after MAP optimization * Rename `test_jax_find_map.py` -> `test_find_map.py` * Improve docstring for `fit_laplace` * Update tests to match new signature * Update docstring --- conda-envs/environment-test.yml | 1 + conda-envs/windows-environment-test.yml | 1 + pymc_experimental/inference/find_map.py | 431 +++++++++++++++++ pymc_experimental/inference/fit.py | 4 +- pymc_experimental/inference/laplace.py | 588 +++++++++++++++++++----- tests/statespace/test_ETS.py | 2 +- tests/test_find_map.py | 98 ++++ tests/test_laplace.py | 216 ++++++--- 8 files changed, 1178 insertions(+), 163 deletions(-) create mode 100644 pymc_experimental/inference/find_map.py create mode 100644 tests/test_find_map.py diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml index 4deda063..2c84fb6e 100644 --- a/conda-envs/environment-test.yml +++ b/conda-envs/environment-test.yml @@ -13,3 +13,4 @@ dependencies: - pymc>=5.17.0 # CI was failing to resolve - blackjax - scikit-learn + - better_optimize>=0.0.10 diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml index 4deda063..2c84fb6e 100644 --- a/conda-envs/windows-environment-test.yml +++ b/conda-envs/windows-environment-test.yml @@ -13,3 +13,4 @@ dependencies: - pymc>=5.17.0 # CI was failing to resolve - blackjax - scikit-learn + - better_optimize>=0.0.10 diff --git a/pymc_experimental/inference/find_map.py b/pymc_experimental/inference/find_map.py new file mode 100644 index 00000000..72ce3b19 --- /dev/null +++ b/pymc_experimental/inference/find_map.py @@ -0,0 +1,431 @@ +import logging + +from collections.abc import Callable +from typing import Literal, cast, get_args + +import jax +import numpy as np +import pymc as pm +import pytensor +import pytensor.tensor as pt + +from better_optimize import minimize +from better_optimize.constants import MINIMIZE_MODE_KWARGS, minimize_method +from pymc.blocking import DictToArrayBijection, RaveledVars +from pymc.initial_point import make_initial_point_fn +from pymc.model.transform.optimization import freeze_dims_and_data +from pymc.pytensorf import join_nonshared_inputs +from pymc.util import get_default_varnames +from pytensor.compile import Function +from pytensor.compile.mode import Mode +from pytensor.tensor import TensorVariable +from scipy.optimize import OptimizeResult + +_log = logging.getLogger(__name__) + +GradientBackend = Literal["pytensor", "jax"] +VALID_BACKENDS = get_args(GradientBackend) + + +def set_optimizer_function_defaults(method, use_grad, use_hess, use_hessp): + method_info = 
MINIMIZE_MODE_KWARGS[method].copy() + + use_grad = use_grad if use_grad is not None else method_info["uses_grad"] + use_hess = use_hess if use_hess is not None else method_info["uses_hess"] + use_hessp = use_hessp if use_hessp is not None else method_info["uses_hessp"] + + if use_hess and use_hessp: + use_hess = False + + return use_grad, use_hess, use_hessp + + +def get_nearest_psd(A: np.ndarray) -> np.ndarray: + """ + Compute the nearest positive semi-definite matrix to a given matrix. + + This function takes a square matrix and returns the nearest positive semi-definite matrix using + eigenvalue decomposition. It ensures all eigenvalues are non-negative. The "nearest" matrix is defined in terms + of the Frobenius norm. + + Parameters + ---------- + A : np.ndarray + Input square matrix. + + Returns + ------- + np.ndarray + The nearest positive semi-definite matrix to the input matrix. + """ + C = (A + A.T) / 2 + eigval, eigvec = np.linalg.eig(C) + eigval[eigval < 0] = 0 + + return eigvec @ np.diag(eigval) @ eigvec.T + + +def _unconstrained_vector_to_constrained_rvs(model): + constrained_rvs, unconstrained_vector = join_nonshared_inputs( + model.initial_point(), + inputs=model.value_vars, + outputs=get_default_varnames(model.unobserved_value_vars, include_transformed=False), + ) + + unconstrained_vector.name = "unconstrained_vector" + return constrained_rvs, unconstrained_vector + + +def _create_transformed_draws(H_inv, slices, out_shapes, posterior_draws, model, chains, draws): + X = pt.tensor("transformed_draws", shape=(chains, draws, H_inv.shape[0])) + out = [] + for rv, idx in slices.items(): + f = model.rvs_to_transforms[rv] + untransformed_X = f.backward(X[..., idx]) if f is not None else X[..., idx] + + if rv in out_shapes: + new_shape = (chains, draws) + out_shapes[rv] + untransformed_X = untransformed_X.reshape(new_shape) + + out.append(untransformed_X) + + f_untransform = pytensor.function( + inputs=[pytensor.In(X, borrow=True)], + outputs=pytensor.Out(out, borrow=True), + mode=Mode(linker="py", optimizer="FAST_COMPILE"), + ) + return f_untransform(posterior_draws) + + +def _compile_jax_gradients( + f_loss: Function, use_hess: bool, use_hessp: bool +) -> tuple[Callable | None, Callable | None]: + """ + Compile loss function gradients using JAX. + + Parameters + ---------- + f_loss: Function + The loss function to compile gradients for. Expected to be a pytensor function that returns a scalar loss, + compiled with mode="JAX". + use_hess: bool + Whether to compile a function to compute the hessian of the loss function. + use_hessp: bool + Whether to compile a function to compute the hessian-vector product of the loss function. + + Returns + ------- + f_loss_and_grad: Callable + The compiled loss function and gradient function. + f_hess: Callable | None + The compiled hessian function, or None if use_hess is False. + f_hessp: Callable | None + The compiled hessian-vector product function, or None if use_hessp is False. 
+ """ + f_hess = None + f_hessp = None + + orig_loss_fn = f_loss.vm.jit_fn + + @jax.jit + def loss_fn_jax_grad(x, *shared): + return jax.value_and_grad(lambda x: orig_loss_fn(x)[0])(x) + + f_loss_and_grad = loss_fn_jax_grad + + if use_hessp: + + def f_hessp_jax(x, p): + y, u = jax.jvp(lambda x: f_loss_and_grad(x)[1], (x,), (p,)) + return jax.numpy.stack(u) + + f_hessp = jax.jit(f_hessp_jax) + + if use_hess: + _f_hess_jax = jax.jacfwd(lambda x: f_loss_and_grad(x)[1]) + + def f_hess_jax(x): + return jax.numpy.stack(_f_hess_jax(x)) + + f_hess = jax.jit(f_hess_jax) + + return f_loss_and_grad, f_hess, f_hessp + + +def _compile_functions( + loss: TensorVariable, + inputs: list[TensorVariable], + compute_grad: bool, + compute_hess: bool, + compute_hessp: bool, + compile_kwargs: dict | None = None, +) -> list[Function] | list[Function, Function | None, Function | None]: + """ + Compile loss functions for use with scipy.optimize.minimize. + + Parameters + ---------- + loss: TensorVariable + The loss function to compile. + inputs: list[TensorVariable] + A single flat vector input variable, collecting all inputs to the loss function. Scipy optimize routines + expect the function signature to be f(x, *args), where x is a 1D array of parameters. + compute_grad: bool + Whether to compile a function that computes the gradients of the loss function. + compute_hess: bool + Whether to compile a function that computes the Hessian of the loss function. + compute_hessp: bool + Whether to compile a function that computes the Hessian-vector product of the loss function. + compile_kwargs: dict, optional + Additional keyword arguments to pass to the ``pm.compile_pymc`` function. + + Returns + ------- + f_loss: Function + + f_hess: Function | None + f_hessp: Function | None + """ + loss = pm.pytensorf.rewrite_pregrad(loss) + f_hess = None + f_hessp = None + + if compute_grad: + grads = pytensor.gradient.grad(loss, inputs) + grad = pt.concatenate([grad.ravel() for grad in grads]) + f_loss_and_grad = pm.compile_pymc(inputs, [loss, grad], **compile_kwargs) + else: + f_loss = pm.compile_pymc(inputs, loss, **compile_kwargs) + return [f_loss] + + if compute_hess: + hess = pytensor.gradient.jacobian(grad, inputs)[0] + f_hess = pm.compile_pymc(inputs, hess, **compile_kwargs) + + if compute_hessp: + p = pt.tensor("p", shape=inputs[0].type.shape) + hessp = pytensor.gradient.hessian_vector_product(loss, inputs, p) + f_hessp = pm.compile_pymc([*inputs, p], hessp[0], **compile_kwargs) + + return [f_loss_and_grad, f_hess, f_hessp] + + +def scipy_optimize_funcs_from_loss( + loss: TensorVariable, + inputs: list[TensorVariable], + initial_point_dict: dict[str, np.ndarray | float | int], + use_grad: bool, + use_hess: bool, + use_hessp: bool, + gradient_backend: GradientBackend = "pytensor", + compile_kwargs: dict | None = None, +) -> tuple[Callable, ...]: + """ + Compile loss functions for use with scipy.optimize.minimize. + + Parameters + ---------- + loss: TensorVariable + The loss function to compile. + inputs: list[TensorVariable] + The input variables to the loss function. + initial_point_dict: dict[str, np.ndarray | float | int] + Dictionary mapping variable names to initial values. Used to determine the shapes of the input variables. + use_grad: bool + Whether to compile a function that computes the gradients of the loss function. + use_hess: bool + Whether to compile a function that computes the Hessian of the loss function. 
+ use_hessp: bool + Whether to compile a function that computes the Hessian-vector product of the loss function. + gradient_backend: str, default "pytensor" + Which backend to use to compute gradients. Must be one of "jax" or "pytensor" + compile_kwargs: + Additional keyword arguments to pass to the ``pm.compile_pymc`` function. + + Returns + ------- + f_loss: Callable + The compiled loss function. + f_hess: Callable | None + The compiled hessian function, or None if use_hess is False. + f_hessp: Callable | None + The compiled hessian-vector product function, or None if use_hessp is False. + """ + + compile_kwargs = {} if compile_kwargs is None else compile_kwargs + + if (use_hess or use_hessp) and not use_grad: + raise ValueError( + "Cannot compute hessian or hessian-vector product without also computing the gradient" + ) + + if gradient_backend not in VALID_BACKENDS: + raise ValueError( + f"Invalid gradient backend: {gradient_backend}. Must be one of {VALID_BACKENDS}" + ) + + use_jax_gradients = (gradient_backend == "jax") and use_grad + + mode = compile_kwargs.get("mode", None) + if mode is None and use_jax_gradients: + compile_kwargs["mode"] = "JAX" + elif mode != "JAX" and use_jax_gradients: + raise ValueError( + 'jax gradients can only be used when ``compile_kwargs["mode"]`` is set to "JAX"' + ) + + if not isinstance(inputs, list): + inputs = [inputs] + + [loss], flat_input = join_nonshared_inputs( + point=initial_point_dict, outputs=[loss], inputs=inputs + ) + + compute_grad = use_grad and not use_jax_gradients + compute_hess = use_hess and not use_jax_gradients + compute_hessp = use_hessp and not use_jax_gradients + + funcs = _compile_functions( + loss=loss, + inputs=[flat_input], + compute_grad=compute_grad, + compute_hess=compute_hess, + compute_hessp=compute_hessp, + compile_kwargs=compile_kwargs, + ) + + # f_loss here is f_loss_and_grad if compute_grad = True. The name is unchanged to simplify the return values + f_loss = funcs.pop(0) + f_hess = funcs.pop(0) if compute_grad else None + f_hessp = funcs.pop(0) if compute_grad else None + + if use_jax_gradients: + # f_loss here is f_loss_and_grad; the name is unchanged to simplify the return values + f_loss, f_hess, f_hessp = _compile_jax_gradients(f_loss, use_hess, use_hessp) + + return f_loss, f_hess, f_hessp + + +def find_MAP( + method: minimize_method, + *, + model: pm.Model | None = None, + use_grad: bool | None = None, + use_hessp: bool | None = None, + use_hess: bool | None = None, + initvals: dict | None = None, + random_seed: int | np.random.Generator | None = None, + return_raw: bool = False, + jitter_rvs: list[TensorVariable] | None = None, + progressbar: bool = True, + include_transformed: bool = True, + gradient_backend: GradientBackend = "pytensor", + compile_kwargs: dict | None = None, + **optimizer_kwargs, +) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], OptimizeResult]: + """ + Fit a PyMC model via maximum a posteriori (MAP) estimation using JAX and scipy.minimize. + + Parameters + ---------- + model : pm.Model + The PyMC model to be fit. If None, the current model context is used. + method : str + The optimization method to use. See scipy.optimize.minimize documentation for details. + use_grad : bool | None, optional + Whether to use gradients in the optimization. Defaults to None, which determines this automatically based on + the ``method``. + use_hessp : bool | None, optional + Whether to use Hessian-vector products in the optimization. 
+        Defaults to None, which determines this automatically based on
+        the ``method``.
+    use_hess : bool | None, optional
+        Whether to use the Hessian matrix in the optimization. Defaults to None, which determines this automatically based on
+        the ``method``.
+    initvals : None | dict, optional
+        Initial values for the model parameters, as str:ndarray key-value pairs. Partial initialization is permitted.
+        If None, the model's default initial values are used.
+    random_seed : None | int | np.random.Generator, optional
+        Seed for the random number generator or a numpy Generator for reproducibility.
+    return_raw : bool, default False
+        Whether to also return the full output of ``scipy.optimize.minimize``.
+    jitter_rvs : list of TensorVariables, optional
+        Variables whose initial values should be jittered. If None, all variables are jittered.
+    progressbar : bool, optional
+        Whether to display a progress bar during optimization. Defaults to True.
+    include_transformed : bool, optional
+        Whether to include transformed variable values in the returned dictionary. Defaults to True.
+    gradient_backend : str, default "pytensor"
+        Which backend to use to compute gradients. Must be one of "pytensor" or "jax".
+    compile_kwargs : dict, optional
+        Additional options to pass to the ``pytensor.function`` function when compiling loss functions.
+    **optimizer_kwargs
+        Additional keyword arguments to pass to the ``scipy.optimize.minimize`` function.
+
+    Returns
+    -------
+    optimizer_result : dict[str, np.ndarray] or tuple[dict[str, np.ndarray], OptimizeResult]
+        Dictionary with names of random variables as keys, and optimization results as values. If return_raw is True,
+        also returns the object returned by ``scipy.optimize.minimize``.
+    """
+    model = pm.modelcontext(model)
+    frozen_model = freeze_dims_and_data(model)
+
+    jitter_rvs = [] if jitter_rvs is None else jitter_rvs
+    compile_kwargs = {} if compile_kwargs is None else compile_kwargs
+
+    ipfn = make_initial_point_fn(
+        model=frozen_model,
+        jitter_rvs=set(jitter_rvs),
+        return_transformed=True,
+        overrides=initvals,
+    )
+
+    start_dict = ipfn(random_seed)
+    vars_dict = {var.name: var for var in frozen_model.continuous_value_vars}
+    initial_params = DictToArrayBijection.map(
+        {var_name: value for var_name, value in start_dict.items() if var_name in vars_dict}
+    )
+    use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
+        method, use_grad, use_hess, use_hessp
+    )
+
+    f_logp, f_hess, f_hessp = scipy_optimize_funcs_from_loss(
+        loss=-frozen_model.logp(jacobian=False),
+        inputs=frozen_model.continuous_value_vars + frozen_model.discrete_value_vars,
+        initial_point_dict=start_dict,
+        use_grad=use_grad,
+        use_hess=use_hess,
+        use_hessp=use_hessp,
+        gradient_backend=gradient_backend,
+        compile_kwargs=compile_kwargs,
+    )
+
+    args = optimizer_kwargs.pop("args", None)
+
+    # better_optimize.minimize will check if f_logp is a fused loss+grad Op, and automatically assign the jac argument
+    # if so. That is why it is not set here, regardless of user settings.
+ optimizer_result = minimize( + f=f_logp, + x0=cast(np.ndarray[float], initial_params.data), + args=args, + hess=f_hess, + hessp=f_hessp, + progressbar=progressbar, + method=method, + **optimizer_kwargs, + ) + + raveled_optimized = RaveledVars(optimizer_result.x, initial_params.point_map_info) + unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed) + unobserved_vars_values = model.compile_fn(unobserved_vars, mode="FAST_COMPILE")( + DictToArrayBijection.rmap(raveled_optimized) + ) + + optimized_point = { + var.name: value for var, value in zip(unobserved_vars, unobserved_vars_values) + } + + if return_raw: + return optimized_point, optimizer_result + + return optimized_point diff --git a/pymc_experimental/inference/fit.py b/pymc_experimental/inference/fit.py index f6c87d90..7897aeed 100644 --- a/pymc_experimental/inference/fit.py +++ b/pymc_experimental/inference/fit.py @@ -39,6 +39,6 @@ def fit(method, **kwargs): return fit_pathfinder(**kwargs) if method == "laplace": - from pymc_experimental.inference.laplace import laplace + from pymc_experimental.inference.laplace import fit_laplace - return laplace(**kwargs) + return fit_laplace(**kwargs) diff --git a/pymc_experimental/inference/laplace.py b/pymc_experimental/inference/laplace.py index 7d7beb59..24a72c0f 100644 --- a/pymc_experimental/inference/laplace.py +++ b/pymc_experimental/inference/laplace.py @@ -12,156 +12,197 @@ # See the License for the specific language governing permissions and # limitations under the License. -import warnings -from collections.abc import Sequence +import logging + +from functools import reduce +from itertools import product +from typing import Literal import arviz as az import numpy as np import pymc as pm +import pytensor +import pytensor.tensor as pt import xarray as xr from arviz import dict_to_dataset +from better_optimize.constants import minimize_method +from pymc import DictToArrayBijection from pymc.backends.arviz import ( coords_and_dims_for_inferencedata, find_constants, find_observations, ) +from pymc.blocking import RaveledVars from pymc.model.transform.conditioning import remove_value_transforms -from pymc.util import RandomSeed -from pytensor import Variable +from pymc.model.transform.optimization import freeze_dims_and_data +from pymc.util import get_default_varnames +from scipy import stats + +from pymc_experimental.inference.find_map import ( + GradientBackend, + _unconstrained_vector_to_constrained_rvs, + find_MAP, + get_nearest_psd, + scipy_optimize_funcs_from_loss, +) +_log = logging.getLogger(__name__) -def laplace( - vars: Sequence[Variable], - draws: int | None = 1000, - model=None, - random_seed: RandomSeed | None = None, - progressbar=True, -): + +def laplace_draws_to_inferencedata( + posterior_draws: list[np.ndarray[float | int]], model: pm.Model | None = None +) -> az.InferenceData: """ - Create a Laplace (quadratic) approximation for a posterior distribution. + Convert draws from a posterior estimated with the Laplace approximation to an InferenceData object. - This function generates a Laplace approximation for a given posterior distribution using a specified - number of draws. This is useful for obtaining a parametric approximation to the posterior distribution - that can be used for further analysis. Parameters ---------- - vars : Sequence[Variable] - A sequence of variables for which the Laplace approximation of the posterior distribution - is to be created. 
- draws : Optional[int] with default=1_000 - The number of draws to sample from the posterior distribution for creating the approximation. - For draws=None only the fit of the Laplace approximation is returned - model : object, optional, default=None - The model object that defines the posterior distribution. If None, the default model will be used. - random_seed : Optional[RandomSeed], optional, default=None - An optional random seed to ensure reproducibility of the draws. If None, the draws will be - generated using the current random state. - progressbar: bool, optional defaults to True - Whether to display a progress bar in the command line. + posterior_draws: list of np.ndarray + A list of arrays containing the posterior draws. Each array should have shape (chains, draws, *shape), where + shape is the shape of the variable in the posterior. + model: Model, optional + A PyMC model. If None, the model is taken from the current model context. Returns ------- - arviz.InferenceData - An `InferenceData` object from the `arviz` library containing the Laplace - approximation of the posterior distribution. The inferenceData object also - contains constant and observed data as well as deterministic variables. - InferenceData also contains a group 'fit' with the mean and covariance - for the Laplace approximation. - - Examples - -------- - >>> import numpy as np - >>> import pymc as pm - >>> import arviz as az - >>> from pymc_experimental.inference.laplace import laplace - >>> y = np.array([2642, 3503, 4358]*10) - >>> with pm.Model() as m: - >>> logsigma = pm.Uniform("logsigma", 1, 100) - >>> mu = pm.Uniform("mu", -10000, 10000) - >>> yobs = pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y) - >>> idata = laplace([mu, logsigma], model=m) - - Notes - ----- - This method of approximation may not be suitable for all types of posterior distributions, - especially those with significant skewness or multimodality. - - See Also - -------- - fit : Calling the inference function 'fit' like pmx.fit(method="laplace", vars=[mu, logsigma], model=m) - will forward the call to 'laplace'. 
- + idata: az.InferenceData + An InferenceData object containing the approximated posterior samples """ - - rng = np.random.default_rng(seed=random_seed) - - transformed_m = pm.modelcontext(model) - - if len(vars) != len(transformed_m.free_RVs): - warnings.warn( - "Number of variables in vars does not eqaul the number of variables in the model.", - UserWarning, + model = pm.modelcontext(model) + chains, draws, *_ = posterior_draws[0].shape + + def make_rv_coords(name): + coords = {"chain": range(chains), "draw": range(draws)} + extra_dims = model.named_vars_to_dims.get(name) + if extra_dims is None: + return coords + return coords | {dim: list(model.coords[dim]) for dim in extra_dims} + + def make_rv_dims(name): + dims = ["chain", "draw"] + extra_dims = model.named_vars_to_dims.get(name) + if extra_dims is None: + return dims + return dims + list(extra_dims) + + names = [ + x.name for x in get_default_varnames(model.unobserved_value_vars, include_transformed=False) + ] + idata = { + name: xr.DataArray( + data=draws, + coords=make_rv_coords(name), + dims=make_rv_dims(name), + name=name, ) + for name, draws in zip(names, posterior_draws) + } - map = pm.find_MAP(vars=vars, progressbar=progressbar, model=transformed_m) + coords, dims = coords_and_dims_for_inferencedata(model) + idata = az.convert_to_inference_data(idata, coords=coords, dims=dims) - # See https://www.pymc.io/projects/docs/en/stable/api/model/generated/pymc.model.transform.conditioning.remove_value_transforms.html - untransformed_m = remove_value_transforms(transformed_m) - untransformed_vars = [untransformed_m[v.name] for v in vars] - hessian = pm.find_hessian(point=map, vars=untransformed_vars, model=untransformed_m) + return idata - if np.linalg.det(hessian) == 0: - raise np.linalg.LinAlgError("Hessian is singular.") - cov = np.linalg.inv(hessian) - mean = np.concatenate([np.atleast_1d(map[v.name]) for v in vars]) +def add_fit_to_inferencedata( + idata: az.InferenceData, mu: RaveledVars, H_inv: np.ndarray, model: pm.Model | None = None +) -> az.InferenceData: + """ + Add the mean vector and covariance matrix of the Laplace approximation to an InferenceData object. - chains = 1 - if draws is not None: - samples = rng.multivariate_normal(mean, cov, size=(chains, draws)) + Parameters + ---------- + idata: az.InfereceData + An InferenceData object containing the approximated posterior samples. + mu: RaveledVars + The MAP estimate of the model parameters. + H_inv: np.ndarray + The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate. + model: Model, optional + A PyMC model. If None, the model is taken from the current model context. - data_vars = {} - for i, var in enumerate(vars): - data_vars[str(var)] = xr.DataArray(samples[:, :, i], dims=("chain", "draw")) + Returns + ------- + idata: az.InferenceData + The provided InferenceData, with the mean vector and covariance matrix added to the "fit" group. 
+ """ + model = pm.modelcontext(model) + coords = model.coords - coords = {"chain": np.arange(chains), "draw": np.arange(draws)} - ds = xr.Dataset(data_vars, coords=coords) + variable_names, *_ = zip(*mu.point_map_info) - idata = az.convert_to_inference_data(ds) - idata = addDataToInferenceData(model, idata, progressbar) - else: - idata = az.InferenceData() + def make_unpacked_variable_names(name): + value_to_dim = { + x.name: model.named_vars_to_dims.get(model.values_to_rvs[x].name, None) + for x in model.value_vars + } + value_to_dim = {k: v for k, v in value_to_dim.items() if v is not None} - idata = addFitToInferenceData(vars, idata, mean, cov) + rv_to_dim = model.named_vars_to_dims + dims_dict = rv_to_dim | value_to_dim - return idata + dims = dims_dict.get(name) + if dims is None: + return [name] + labels = product(*(coords[dim] for dim in dims)) + return [f"{name}[{','.join(map(str, label))}]" for label in labels] + unpacked_variable_names = reduce( + lambda lst, name: lst + make_unpacked_variable_names(name), variable_names, [] + ) -def addFitToInferenceData(vars, idata, mean, covariance): - coord_names = [v.name for v in vars] - # Convert to xarray DataArray - mean_dataarray = xr.DataArray(mean, dims=["rows"], coords={"rows": coord_names}) + mean_dataarray = xr.DataArray(mu.data, dims=["rows"], coords={"rows": unpacked_variable_names}) cov_dataarray = xr.DataArray( - covariance, dims=["rows", "columns"], coords={"rows": coord_names, "columns": coord_names} + H_inv, + dims=["rows", "columns"], + coords={"rows": unpacked_variable_names, "columns": unpacked_variable_names}, ) - # Create xarray dataset dataset = xr.Dataset({"mean_vector": mean_dataarray, "covariance_matrix": cov_dataarray}) - idata.add_groups(fit=dataset) return idata -def addDataToInferenceData(model, trace, progressbar): - # Add deterministic variables to inference data - trace.posterior = pm.compute_deterministics( - trace.posterior, model=model, merge_dataset=True, progressbar=progressbar - ) +def add_data_to_inferencedata( + idata: az.InferenceData, + progressbar: bool = True, + model: pm.Model | None = None, + compile_kwargs: dict | None = None, +) -> az.InferenceData: + """ + Add observed and constant data to an InferenceData object. + + Parameters + ---------- + idata: az.InferenceData + An InferenceData object containing the approximated posterior samples. + progressbar: bool + Whether to display a progress bar during computations. Default is True. + model: Model, optional + A PyMC model. If None, the model is taken from the current model context. + compile_kwargs: dict, optional + Additional keyword arguments to pass to pytensor.function. + + Returns + ------- + idata: az.InferenceData + The provided InferenceData, with observed and constant data added. 
+ """ + model = pm.modelcontext(model) + + if model.deterministics: + idata.posterior = pm.compute_deterministics( + idata.posterior, + model=model, + merge_dataset=True, + progressbar=progressbar, + compile_kwargs=compile_kwargs, + ) coords, dims = coords_and_dims_for_inferencedata(model) @@ -181,10 +222,349 @@ def addDataToInferenceData(model, trace, progressbar): default_dims=[], ) - trace.add_groups( + idata.add_groups( {"observed_data": observed_data, "constant_data": constant_data}, coords=coords, dims=dims, ) - return trace + return idata + + +def fit_mvn_to_MAP( + optimized_point: dict[str, np.ndarray], + model: pm.Model | None = None, + on_bad_cov: Literal["warn", "error", "ignore"] = "ignore", + transform_samples: bool = False, + gradient_backend: GradientBackend = "pytensor", + zero_tol: float = 1e-8, + diag_jitter: float | None = 1e-8, + compile_kwargs: dict | None = None, +) -> tuple[RaveledVars, np.ndarray]: + """ + Create a multivariate normal distribution using the inverse of the negative Hessian matrix of the log-posterior + evaluated at the MAP estimate. This is the basis of the Laplace approximation. + + Parameters + ---------- + optimized_point : dict[str, np.ndarray] + Local maximum a posteriori (MAP) point returned from pymc.find_MAP or jax_tools.fit_map + model : Model, optional + A PyMC model. If None, the model is taken from the current model context. + on_bad_cov : str, one of 'ignore', 'warn', or 'error', default: 'ignore' + What to do when ``H_inv`` (inverse Hessian) is not positive semi-definite. + If 'ignore' or 'warn', the closest positive-semi-definite matrix to ``H_inv`` (in L1 norm) will be returned. + If 'error', an error will be raised. + transform_samples : bool + Whether to transform the samples back to the original parameter space. Default is True. + gradient_backend: str, default "pytensor" + The backend to use for gradient computations. Must be one of "pytensor" or "jax". + zero_tol: float + Value below which an element of the Hessian matrix is counted as 0. + This is used to stabilize the computation of the inverse Hessian matrix. Default is 1e-8. + diag_jitter: float | None + A small value added to the diagonal of the inverse Hessian matrix to ensure it is positive semi-definite. + If None, no jitter is added. Default is 1e-8. + compile_kwargs: dict, optional + Additional keyword arguments to pass to pytensor.function when compiling loss functions + + Returns + ------- + map_estimate: RaveledVars + The MAP estimate of the model parameters, raveled into a 1D array. + + inverse_hessian: np.ndarray + The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate. 
+ """ + model = pm.modelcontext(model) + compile_kwargs = {} if compile_kwargs is None else compile_kwargs + frozen_model = freeze_dims_and_data(model) + + if not transform_samples: + untransformed_model = remove_value_transforms(frozen_model) + logp = untransformed_model.logp(jacobian=False) + variables = untransformed_model.continuous_value_vars + else: + logp = frozen_model.logp(jacobian=True) + variables = frozen_model.continuous_value_vars + + variable_names = {var.name for var in variables} + optimized_free_params = {k: v for k, v in optimized_point.items() if k in variable_names} + mu = DictToArrayBijection.map(optimized_free_params) + + _, f_hess, _ = scipy_optimize_funcs_from_loss( + loss=-logp, + inputs=variables, + initial_point_dict=optimized_free_params, + use_grad=True, + use_hess=True, + use_hessp=False, + gradient_backend=gradient_backend, + compile_kwargs=compile_kwargs, + ) + + H = -f_hess(mu.data) + H_inv = np.linalg.pinv(np.where(np.abs(H) < zero_tol, 0, -H)) + + def stabilize(x, jitter): + return x + np.eye(x.shape[0]) * jitter + + H_inv = H_inv if diag_jitter is None else stabilize(H_inv, diag_jitter) + + try: + np.linalg.cholesky(H_inv) + except np.linalg.LinAlgError: + if on_bad_cov == "error": + raise np.linalg.LinAlgError( + "Inverse Hessian not positive-semi definite at the provided point" + ) + H_inv = get_nearest_psd(H_inv) + if on_bad_cov == "warn": + _log.warning( + "Inverse Hessian is not positive semi-definite at the provided point, using the closest PSD " + "matrix in L1-norm instead" + ) + + return mu, H_inv + + +def sample_laplace_posterior( + mu: RaveledVars, + H_inv: np.ndarray, + model: pm.Model | None = None, + chains: int = 2, + draws: int = 500, + transform_samples: bool = False, + progressbar: bool = True, + random_seed: int | np.random.Generator | None = None, + compile_kwargs: dict | None = None, +) -> az.InferenceData: + """ + Generate samples from a multivariate normal distribution with mean `mu` and inverse covariance matrix `H_inv`. + + Parameters + ---------- + mu + H_inv + model : Model + A PyMC model + chains : int + The number of sampling chains running in parallel. Default is 2. + draws : int + The number of samples to draw from the approximated posterior. Default is 500. + transform_samples : bool + Whether to transform the samples back to the original parameter space. Default is True. + progressbar : bool + Whether to display a progress bar during computations. Default is True. + random_seed: int | np.random.Generator | None + Seed for the random number generator or a numpy Generator for reproducibility + + Returns + ------- + idata: az.InferenceData + An InferenceData object containing the approximated posterior samples. 
+ """ + model = pm.modelcontext(model) + compile_kwargs = {} if compile_kwargs is None else compile_kwargs + rng = np.random.default_rng(random_seed) + + posterior_dist = stats.multivariate_normal( + mean=mu.data, cov=H_inv, allow_singular=True, seed=rng + ) + posterior_draws = posterior_dist.rvs(size=(chains, draws)) + + if transform_samples: + constrained_rvs, unconstrained_vector = _unconstrained_vector_to_constrained_rvs(model) + batched_values = pt.tensor( + "batched_values", + shape=(chains, draws, *unconstrained_vector.type.shape), + dtype=unconstrained_vector.type.dtype, + ) + batched_rvs = pytensor.graph.vectorize_graph( + constrained_rvs, replace={unconstrained_vector: batched_values} + ) + + f_constrain = pm.compile_pymc( + inputs=[batched_values], outputs=batched_rvs, **compile_kwargs + ) + posterior_draws = f_constrain(posterior_draws) + + else: + info = mu.point_map_info + flat_shapes = [np.prod(shape).astype(int) for _, shape, _ in info] + slices = [ + slice(sum(flat_shapes[:i]), sum(flat_shapes[: i + 1])) for i in range(len(flat_shapes)) + ] + + posterior_draws = [ + posterior_draws[..., idx].reshape((chains, draws, *shape)).astype(dtype) + for idx, (name, shape, dtype) in zip(slices, info) + ] + + idata = laplace_draws_to_inferencedata(posterior_draws, model) + idata = add_fit_to_inferencedata(idata, mu, H_inv) + idata = add_data_to_inferencedata(idata, progressbar, model, compile_kwargs) + + return idata + + +def fit_laplace( + optimize_method: minimize_method = "BFGS", + *, + model: pm.Model | None = None, + use_grad: bool | None = None, + use_hessp: bool | None = None, + use_hess: bool | None = None, + initvals: dict | None = None, + random_seed: int | np.random.Generator | None = None, + return_raw: bool = False, + jitter_rvs: list[pt.TensorVariable] | None = None, + progressbar: bool = True, + include_transformed: bool = True, + gradient_backend: GradientBackend = "pytensor", + chains: int = 2, + draws: int = 500, + on_bad_cov: Literal["warn", "error", "ignore"] = "ignore", + fit_in_unconstrained_space: bool = False, + zero_tol: float = 1e-8, + diag_jitter: float | None = 1e-8, + optimizer_kwargs: dict | None = None, + compile_kwargs: dict | None = None, +) -> az.InferenceData: + """ + Create a Laplace (quadratic) approximation for a posterior distribution. + + This function generates a Laplace approximation for a given posterior distribution using a specified + number of draws. This is useful for obtaining a parametric approximation to the posterior distribution + that can be used for further analysis. + + Parameters + ---------- + model : pm.Model + The PyMC model to be fit. If None, the current model context is used. + optimize_method : str + The optimization method to use. See scipy.optimize.minimize documentation for details. + use_grad : bool | None, optional + Whether to use gradients in the optimization. Defaults to None, which determines this automatically based on + the ``method``. + use_hessp : bool | None, optional + Whether to use Hessian-vector products in the optimization. Defaults to None, which determines this automatically based on + the ``method``. + use_hess : bool | None, optional + Whether to use the Hessian matrix in the optimization. Defaults to None, which determines this automatically based on + the ``method``. + initvals : None | dict, optional + Initial values for the model parameters, as str:ndarray key-value pairs. Paritial initialization is permitted. + If None, the model's default initial values are used. 
+    random_seed : None | int | np.random.Generator, optional
+        Seed for the random number generator or a numpy Generator for reproducibility.
+    return_raw : bool, default False
+        Whether to also return the full output of ``scipy.optimize.minimize``.
+    jitter_rvs : list of TensorVariables, optional
+        Variables whose initial values should be jittered. If None, all variables are jittered.
+    progressbar : bool, optional
+        Whether to display a progress bar during optimization. Defaults to True.
+    fit_in_unconstrained_space : bool, default False
+        Whether to fit the Laplace approximation in the unconstrained parameter space. If True, samples will be drawn
+        from a mean and covariance matrix computed at a point in the **unconstrained** parameter space. Samples will
+        then be transformed back to the original parameter space. This will guarantee that the samples will respect
+        the domain of prior distributions (for example, samples from a Beta distribution will be strictly between 0
+        and 1).
+
+        .. warning::
+            This argument should be considered highly experimental. It has not been verified whether this method
+            produces valid draws from the posterior. **Use at your own risk**.
+
+    gradient_backend : str, default "pytensor"
+        The backend to use for gradient computations. Must be one of "pytensor" or "jax".
+    chains : int, default 2
+        The number of sampling chains running in parallel.
+    draws : int, default 500
+        The number of samples to draw from the approximated posterior.
+    on_bad_cov : str, one of 'ignore', 'warn', or 'error', default: 'ignore'
+        What to do when ``H_inv`` (inverse Hessian) is not positive semi-definite.
+        If 'ignore' or 'warn', the closest positive-semi-definite matrix to ``H_inv`` (in L1 norm) will be returned.
+        If 'error', an error will be raised.
+    zero_tol : float
+        Value below which an element of the Hessian matrix is counted as 0.
+        This is used to stabilize the computation of the inverse Hessian matrix. Default is 1e-8.
+    diag_jitter : float | None
+        A small value added to the diagonal of the inverse Hessian matrix to ensure it is positive semi-definite.
+        If None, no jitter is added. Default is 1e-8.
+    optimizer_kwargs : dict, optional
+        Additional keyword arguments to pass to scipy.optimize.minimize. See the documentation for scipy.optimize.minimize
+        for details. Arguments that are typically passed via ``options`` will be automatically extracted without the need
+        to use a nested dictionary.
+    compile_kwargs : dict, optional
+        Additional keyword arguments to pass to pytensor.function.
+
+    Returns
+    -------
+    idata : az.InferenceData
+        An InferenceData object containing the approximated posterior samples.
+
+    Examples
+    --------
+    >>> from pymc_experimental.inference.laplace import fit_laplace
+    >>> import numpy as np
+    >>> import pymc as pm
+    >>> import arviz as az
+    >>> y = np.array([2642, 3503, 4358]*10)
+    >>> with pm.Model() as m:
+    ...     logsigma = pm.Uniform("logsigma", 1, 100)
+    ...     mu = pm.Uniform("mu", -10000, 10000)
+    ...     yobs = pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y)
+    ...     idata = fit_laplace()
+
+    Notes
+    -----
+    This method of approximation may not be suitable for all types of posterior distributions,
+    especially those with significant skewness or multimodality.
+
+    See Also
+    --------
+    fit : Calling the inference function 'fit' like pmx.fit(method="laplace", model=m)
+        will forward the call to 'fit_laplace'.
+ + """ + compile_kwargs = {} if compile_kwargs is None else compile_kwargs + optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs + + optimized_point = find_MAP( + method=optimize_method, + model=model, + use_grad=use_grad, + use_hessp=use_hessp, + use_hess=use_hess, + initvals=initvals, + random_seed=random_seed, + return_raw=return_raw, + jitter_rvs=jitter_rvs, + progressbar=progressbar, + include_transformed=include_transformed, + gradient_backend=gradient_backend, + compile_kwargs=compile_kwargs, + **optimizer_kwargs, + ) + + mu, H_inv = fit_mvn_to_MAP( + optimized_point=optimized_point, + model=model, + on_bad_cov=on_bad_cov, + transform_samples=fit_in_unconstrained_space, + zero_tol=zero_tol, + diag_jitter=diag_jitter, + compile_kwargs=compile_kwargs, + ) + + return sample_laplace_posterior( + mu=mu, + H_inv=H_inv, + model=model, + chains=chains, + draws=draws, + transform_samples=fit_in_unconstrained_space, + progressbar=progressbar, + random_seed=random_seed, + compile_kwargs=compile_kwargs, + ) diff --git a/tests/statespace/test_ETS.py b/tests/statespace/test_ETS.py index b56a3581..b9c15e5e 100644 --- a/tests/statespace/test_ETS.py +++ b/tests/statespace/test_ETS.py @@ -408,4 +408,4 @@ def test_ETS_stationary_initialization(): R, Q = outputs["selection"], outputs["state_cov"] P0_expected = linalg.solve_discrete_lyapunov(T_stationary, R @ Q @ R.T) - assert_allclose(outputs["initial_state_cov"], P0_expected) + assert_allclose(outputs["initial_state_cov"], P0_expected, rtol=1e-8, atol=1e-8) diff --git a/tests/test_find_map.py b/tests/test_find_map.py new file mode 100644 index 00000000..6b2c029a --- /dev/null +++ b/tests/test_find_map.py @@ -0,0 +1,98 @@ +import numpy as np +import pymc as pm +import pytensor.tensor as pt +import pytest + +from pymc_experimental.inference.find_map import ( + GradientBackend, + find_MAP, + scipy_optimize_funcs_from_loss, +) + +pytest.importorskip("jax") + + +@pytest.fixture(scope="session") +def rng(): + seed = sum(map(ord, "test_fit_map")) + return np.random.default_rng(seed) + + +@pytest.mark.parametrize("gradient_backend", ["jax", "pytensor"], ids=str) +def test_jax_functions_from_graph(gradient_backend: GradientBackend): + x = pt.tensor("x", shape=(2,)) + + def compute_z(x): + z1 = x[0] ** 2 + 2 + z2 = x[0] * x[1] + 3 + return z1, z2 + + z = pt.stack(compute_z(x)) + f_loss, f_hess, f_hessp = scipy_optimize_funcs_from_loss( + loss=z.sum(), + inputs=[x], + initial_point_dict={"x": np.array([1.0, 2.0])}, + use_grad=True, + use_hess=True, + use_hessp=True, + gradient_backend=gradient_backend, + compile_kwargs=dict(mode="JAX"), + ) + + x_val = np.array([1.0, 2.0]) + expected_z = sum(compute_z(x_val)) + + z_jax, grad_val = f_loss(x_val) + np.testing.assert_allclose(z_jax, expected_z) + np.testing.assert_allclose(grad_val.squeeze(), np.array([2 * x_val[0] + x_val[1], x_val[0]])) + + hess_val = np.array(f_hess(x_val)) + np.testing.assert_allclose(hess_val.squeeze(), np.array([[2, 1], [1, 0]])) + + hessp_val = np.array(f_hessp(x_val, np.array([1.0, 0.0]))) + np.testing.assert_allclose(hessp_val.squeeze(), np.array([2, 1])) + + +@pytest.mark.parametrize( + "method, use_grad, use_hess", + [ + ("nelder-mead", False, False), + ("powell", False, False), + ("CG", True, False), + ("BFGS", True, False), + ("L-BFGS-B", True, False), + ("TNC", True, False), + ("SLSQP", True, False), + ("dogleg", True, True), + ("trust-ncg", True, True), + ("trust-exact", True, True), + ("trust-krylov", True, True), + ("trust-constr", True, True), + ], +) 
+@pytest.mark.parametrize("gradient_backend", ["jax", "pytensor"], ids=str) +def test_JAX_map(method, use_grad, use_hess, gradient_backend: GradientBackend, rng): + extra_kwargs = {} + if method == "dogleg": + # HACK -- dogleg requires that the hessian of the objective function is PSD, so we have to pick a point + # where this is true + extra_kwargs = {"initvals": {"mu": 2, "sigma_log__": 1}} + + with pm.Model() as m: + mu = pm.Normal("mu") + sigma = pm.Exponential("sigma", 1) + pm.Normal("y_hat", mu=mu, sigma=sigma, observed=rng.normal(loc=3, scale=1.5, size=100)) + + optimized_point = find_MAP( + method=method, + **extra_kwargs, + use_grad=use_grad, + use_hess=use_hess, + progressbar=False, + gradient_backend=gradient_backend, + compile_kwargs={"mode": "JAX"}, + ) + mu_hat, log_sigma_hat = optimized_point["mu"], optimized_point["sigma_log__"] + + assert np.isclose(mu_hat, 3, atol=0.5) + assert np.isclose(np.exp(log_sigma_hat), 1.5, atol=0.5) diff --git a/tests/test_laplace.py b/tests/test_laplace.py index 3fefe3f7..a11ee59e 100644 --- a/tests/test_laplace.py +++ b/tests/test_laplace.py @@ -19,6 +19,19 @@ import pymc_experimental as pmx +from pymc_experimental.inference.find_map import find_MAP +from pymc_experimental.inference.laplace import ( + fit_laplace, + fit_mvn_to_MAP, + sample_laplace_posterior, +) + + +@pytest.fixture(scope="session") +def rng(): + seed = sum(map(ord, "test_laplace")) + return np.random.default_rng(seed) + @pytest.mark.filterwarnings( "ignore:hessian will stop negating the output in a future version of PyMC.\n" @@ -35,18 +48,15 @@ def test_laplace(): draws = 100000 with pm.Model() as m: - logsigma = pm.Uniform("logsigma", 1, 100) mu = pm.Uniform("mu", -10000, 10000) + logsigma = pm.Uniform("logsigma", 1, 100) + yobs = pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y) vars = [mu, logsigma] - idata = pmx.fit( - method="laplace", - vars=vars, - model=m, - draws=draws, - random_seed=173300, - ) + idata = pmx.fit( + method="laplace", optimize_method="trust-ncg", draws=draws, random_seed=173300, chains=1 + ) assert idata.posterior["mu"].shape == (1, draws) assert idata.posterior["logsigma"].shape == (1, draws) @@ -57,14 +67,10 @@ def test_laplace(): bda_map = [y.mean(), np.log(y.std())] bda_cov = np.array([[y.var() / n, 0], [0, 1 / (2 * n)]]) - assert np.allclose(idata.fit["mean_vector"].values, bda_map) - assert np.allclose(idata.fit["covariance_matrix"].values, bda_cov, atol=1e-4) + np.testing.assert_allclose(idata.fit["mean_vector"].values, bda_map) + np.testing.assert_allclose(idata.fit["covariance_matrix"].values, bda_cov, atol=1e-4) -@pytest.mark.filterwarnings( - "ignore:hessian will stop negating the output in a future version of PyMC.\n" - + "To suppress this warning set `negate_output=False`:FutureWarning", -) def test_laplace_only_fit(): # Example originates from Bayesian Data Analyses, 3rd Edition # By Andrew Gelman, John Carlin, Hal Stern, David Dunson, @@ -80,55 +86,153 @@ def test_laplace_only_fit(): yobs = pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y) vars = [mu, logsigma] - idata = pmx.fit( - method="laplace", - vars=vars, - draws=None, - model=m, - random_seed=173300, - ) + idata = pmx.fit( + method="laplace", + optimize_method="BFGS", + progressbar=True, + gradient_backend="jax", + compile_kwargs={"mode": "JAX"}, + optimizer_kwargs=dict(maxiter=100_000, gtol=1e-100), + random_seed=173300, + ) assert idata.fit["mean_vector"].shape == (len(vars),) assert idata.fit["covariance_matrix"].shape == (len(vars), len(vars)) - 
bda_map = [y.mean(), np.log(y.std())] - bda_cov = np.array([[y.var() / n, 0], [0, 1 / (2 * n)]]) + bda_map = [np.log(y.std()), y.mean()] + bda_cov = np.array([[1 / (2 * n), 0], [0, y.var() / n]]) - assert np.allclose(idata.fit["mean_vector"].values, bda_map) - assert np.allclose(idata.fit["covariance_matrix"].values, bda_cov, atol=1e-4) + np.testing.assert_allclose(idata.fit["mean_vector"].values, bda_map) + np.testing.assert_allclose(idata.fit["covariance_matrix"].values, bda_cov, atol=1e-4) -@pytest.mark.filterwarnings( - "ignore:hessian will stop negating the output in a future version of PyMC.\n" - + "To suppress this warning set `negate_output=False`:FutureWarning", +@pytest.mark.parametrize( + "transform_samples", + [True, False], + ids=["transformed", "untransformed"], ) -def test_laplace_subset_of_rv(recwarn): - # Example originates from Bayesian Data Analyses, 3rd Edition - # By Andrew Gelman, John Carlin, Hal Stern, David Dunson, - # Aki Vehtari, and Donald Rubin. - # See section. 4.1 - - y = np.array([2642, 3503, 4358], dtype=np.float64) - n = y.size - - with pm.Model() as m: - logsigma = pm.Uniform("logsigma", 1, 100) - mu = pm.Uniform("mu", -10000, 10000) - yobs = pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y) - vars = [mu] - - idata = pmx.fit( - method="laplace", - vars=vars, - draws=None, - model=m, - random_seed=173300, +@pytest.mark.parametrize("mode", ["JAX", None], ids=["jax", "pytensor"]) +def test_fit_laplace_coords(rng, transform_samples, mode): + coords = {"city": ["A", "B", "C"], "obs_idx": np.arange(100)} + with pm.Model(coords=coords) as model: + mu = pm.Normal("mu", mu=3, sigma=0.5, dims=["city"]) + sigma = pm.Exponential("sigma", 1, dims=["city"]) + obs = pm.Normal( + "obs", + mu=mu, + sigma=sigma, + observed=rng.normal(loc=3, scale=1.5, size=(100, 3)), + dims=["obs_idx", "city"], + ) + + optimized_point = find_MAP( + method="trust-ncg", + use_grad=True, + use_hessp=True, + progressbar=False, + compile_kwargs=dict(mode=mode), + gradient_backend="jax" if mode == "JAX" else "pytensor", + ) + + for value in optimized_point.values(): + assert value.shape == (3,) + + mu, H_inv = fit_mvn_to_MAP( + optimized_point=optimized_point, + model=model, + transform_samples=transform_samples, + ) + + idata = sample_laplace_posterior( + mu=mu, H_inv=H_inv, model=model, transform_samples=transform_samples + ) + + np.testing.assert_allclose(np.mean(idata.posterior.mu, axis=1), np.full((2, 3), 3), atol=0.5) + np.testing.assert_allclose( + np.mean(idata.posterior.sigma, axis=1), np.full((2, 3), 1.5), atol=0.3 ) - assert len(recwarn) == 3 - w = recwarn.pop(UserWarning) - assert issubclass(w.category, UserWarning) - assert ( - str(w.message) - == "Number of variables in vars does not eqaul the number of variables in the model." 
- ) + suffix = "_log__" if transform_samples else "" + assert idata.fit.rows.values.tolist() == [ + "mu[A]", + "mu[B]", + "mu[C]", + f"sigma{suffix}[A]", + f"sigma{suffix}[B]", + f"sigma{suffix}[C]", + ] + + +def test_fit_laplace_ragged_coords(rng): + coords = {"city": ["A", "B", "C"], "feature": [0, 1], "obs_idx": np.arange(100)} + with pm.Model(coords=coords) as ragged_dim_model: + X = pm.Data("X", np.ones((100, 2)), dims=["obs_idx", "feature"]) + beta = pm.Normal( + "beta", mu=[[-100.0, 100.0], [-100.0, 100.0], [-100.0, 100.0]], dims=["city", "feature"] + ) + mu = pm.Deterministic( + "mu", (X[:, None, :] * beta[None]).sum(axis=-1), dims=["obs_idx", "city"] + ) + sigma = pm.Normal("sigma", mu=1.5, sigma=0.5, dims=["city"]) + + obs = pm.Normal( + "obs", + mu=mu, + sigma=sigma, + observed=rng.normal(loc=3, scale=1.5, size=(100, 3)), + dims=["obs_idx", "city"], + ) + + idata = fit_laplace( + optimize_method="Newton-CG", + progressbar=False, + use_grad=True, + use_hessp=True, + gradient_backend="jax", + compile_kwargs={"mode": "JAX"}, + ) + + assert idata["posterior"].beta.shape[-2:] == (3, 2) + assert idata["posterior"].sigma.shape[-1:] == (3,) + + # Check that everything got unraveled correctly -- feature 0 should be strictly negative, feature 1 + # strictly positive + assert (idata["posterior"].beta.sel(feature=0).to_numpy() < 0).all() + assert (idata["posterior"].beta.sel(feature=1).to_numpy() > 0).all() + + +@pytest.mark.parametrize( + "fit_in_unconstrained_space", + [True, False], + ids=["transformed", "untransformed"], +) +def test_fit_laplace(fit_in_unconstrained_space): + with pm.Model() as simp_model: + mu = pm.Normal("mu", mu=3, sigma=0.5) + sigma = pm.Exponential("sigma", 1) + obs = pm.Normal( + "obs", + mu=mu, + sigma=sigma, + observed=np.random.default_rng().normal(loc=3, scale=1.5, size=(10000,)), + ) + + idata = fit_laplace( + optimize_method="trust-ncg", + use_grad=True, + use_hessp=True, + fit_in_unconstrained_space=fit_in_unconstrained_space, + optimizer_kwargs=dict(maxiter=100_000, tol=1e-100), + ) + + np.testing.assert_allclose(np.mean(idata.posterior.mu, axis=1), np.full((2,), 3), atol=0.1) + np.testing.assert_allclose( + np.mean(idata.posterior.sigma, axis=1), np.full((2,), 1.5), atol=0.1 + ) + + if fit_in_unconstrained_space: + assert idata.fit.rows.values.tolist() == ["mu", "sigma_log__"] + np.testing.assert_allclose(idata.fit.mean_vector.values, np.array([3.0, 0.4]), atol=0.1) + else: + assert idata.fit.rows.values.tolist() == ["mu", "sigma"] + np.testing.assert_allclose(idata.fit.mean_vector.values, np.array([3.0, 1.5]), atol=0.1)
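
A minimal usage sketch of the two entry points added by this patch, assembled from the docstrings and tests above; the toy model, data, and optimizer choices are illustrative assumptions rather than fixed parts of the API.

    import numpy as np
    import pymc as pm

    from pymc_experimental.inference.find_map import find_MAP
    from pymc_experimental.inference.laplace import fit_laplace

    rng = np.random.default_rng(0)
    y_obs = rng.normal(loc=3, scale=1.5, size=100)

    with pm.Model() as m:
        mu = pm.Normal("mu")
        sigma = pm.Exponential("sigma", 1)
        pm.Normal("y_hat", mu=mu, sigma=sigma, observed=y_obs)

        # MAP point via scipy.optimize.minimize; gradients compiled with pytensor (or "jax")
        map_point = find_MAP(method="L-BFGS-B", use_grad=True, gradient_backend="pytensor", progressbar=False)

        # Laplace (quadratic) approximation of the posterior, returned as an arviz.InferenceData
        idata = fit_laplace(optimize_method="trust-ncg", use_grad=True, use_hessp=True, chains=2, draws=500)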