From 6aa20f73e0299edf57af0a60176f35d2ca5a62dd Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 27 Oct 2024 14:17:37 +0800 Subject: [PATCH 01/21] Add JAX-based `find_MAP` --- pymc_experimental/__init__.py | 1 + pymc_experimental/inference/jax_find_map.py | 448 ++++++++++++++++++++ tests/test_jax_find_map.py | 186 ++++++++ 3 files changed, 635 insertions(+) create mode 100644 pymc_experimental/inference/jax_find_map.py create mode 100644 tests/test_jax_find_map.py diff --git a/pymc_experimental/__init__.py b/pymc_experimental/__init__.py index c19097e6..977e593c 100644 --- a/pymc_experimental/__init__.py +++ b/pymc_experimental/__init__.py @@ -16,6 +16,7 @@ from pymc_experimental import gp, statespace, utils from pymc_experimental.distributions import * from pymc_experimental.inference.fit import fit +from pymc_experimental.inference.jax_find_map import find_MAP from pymc_experimental.model.marginal.marginal_model import MarginalModel, marginalize from pymc_experimental.model.model_api import as_model from pymc_experimental.version import __version__ diff --git a/pymc_experimental/inference/jax_find_map.py b/pymc_experimental/inference/jax_find_map.py new file mode 100644 index 00000000..b1a41649 --- /dev/null +++ b/pymc_experimental/inference/jax_find_map.py @@ -0,0 +1,448 @@ +import logging + +from collections.abc import Callable +from typing import Literal, cast + +import arviz as az +import jax +import numpy as np +import pymc as pm +import pytensor +import pytensor.tensor as pt +import xarray as xr + +from arviz import dict_to_dataset +from better_optimize import minimize +from better_optimize.constants import minimize_method +from pymc.backends.arviz import ( + coords_and_dims_for_inferencedata, + find_constants, + find_observations, +) +from pymc.blocking import DictToArrayBijection, RaveledVars +from pymc.initial_point import make_initial_point_fn +from pymc.model.transform.conditioning import remove_value_transforms +from pymc.model.transform.optimization import freeze_dims_and_data +from pymc.sampling.jax import get_jaxified_graph +from pymc.util import get_default_varnames +from pytensor.tensor import TensorVariable +from scipy import stats +from scipy.optimize import OptimizeResult + +_log = logging.getLogger(__name__) + + +def get_near_psd(A: np.ndarray) -> np.ndarray: + """ + Compute the nearest positive semi-definite matrix to a given matrix. + + This function takes a square matrix and returns the nearest positive + semi-definite matrix using eigenvalue decomposition. It ensures all + eigenvalues are non-negative. The "nearest" matrix is defined in terms + of the Frobenius norm. + + Parameters + ---------- + A : np.ndarray + Input square matrix. + + Returns + ------- + np.ndarray + The nearest positive semi-definite matrix to the input matrix. 
+ """ + C = (A + A.T) / 2 + eigval, eigvec = np.linalg.eig(C) + eigval[eigval < 0] = 0 + + return eigvec @ np.diag(eigval) @ eigvec.T + + +def _get_unravel_rv_info(optimized_point, variables, model): + cursor = 0 + slices = {} + out_shapes = {} + + for i, var in enumerate(variables): + raveled_shape = np.prod(optimized_point[var.name].shape).astype(int) + rv = model.values_to_rvs.get(var, var) + + idx = slice(cursor, cursor + raveled_shape) + slices[rv] = idx + out_shapes[rv] = tuple( + [len(model.coords[dim]) for dim in model.named_vars_to_dims.get(rv.name, [])] + ) + cursor += raveled_shape + + return slices, out_shapes + + +def _create_transformed_draws(H_inv, slices, out_shapes, posterior_draws, model, chains, draws): + X = pt.tensor("transformed_draws", shape=(chains, draws, H_inv.shape[0])) + out = [] + for rv, idx in slices.items(): + f = model.rvs_to_transforms[rv] + untransformed_X = f.backward(X[..., idx]) if f is not None else X[..., idx] + + if rv in out_shapes: + new_shape = (chains, draws) + out_shapes[rv] + untransformed_X = untransformed_X.reshape(new_shape) + + out.append(untransformed_X) + + f_untransform = pytensor.function([X], out, mode="JAX") + return f_untransform(posterior_draws) + + +def fit_laplace( + optimized_point: dict[str, np.ndarray], + model: pm.Model, + chains: int = 2, + draws: int = 500, + on_bad_cov: Literal["warn", "error", "ignore"] = "ignore", + transform_samples: bool = True, + zero_tol: float = 1e-8, + diag_jitter: float | None = 1e-8, + progressbar: bool = True, + mode: str = "JAX", +) -> az.InferenceData: + """ + Compute the Laplace approximation of the posterior distribution. + + The posterior distribution will be approximated as a Gaussian + distribution centered at the posterior mode. + The covariance is the inverse of the negative Hessian matrix of + the log-posterior evaluated at the mode. + + Parameters + ---------- + optimized_point : dict[str, np.ndarray] + Local maximum a posteriori (MAP) point returned from pymc.find_MAP + or jax_tools.fit_map + model : Model + A PyMC model + chains : int + The number of sampling chains running in parallel. Default is 2. + draws : int + The number of samples to draw from the approximated posterior. Default is 500. + on_bad_cov : str, one of 'ignore', 'warn', or 'error', default: 'ignore' + What to do when ``H_inv`` (inverse Hessian) is not positive semi-definite. + If 'ignore' or 'warn', the closest positive-semi-definite matrix to ``H_inv`` (in L1 norm) will be returned. + If 'error', an error will be raised. + transform_samples : bool + Whether to transform the samples back to the original parameter space. Default is True. + zero_tol: float + Value below which an element of the Hessian matrix is counted as 0. + This is used to stabilize the computation of the inverse Hessian matrix. Default is 1e-8. + diag_jitter: float | None + A small value added to the diagonal of the inverse Hessian matrix to ensure it is positive semi-definite. + If None, no jitter is added. Default is 1e-8. + progressbar : bool + Whether or not to display progress bar. Default is True. + mode : str + Computation backend mode. Default is "JAX". + + Returns + ------- + InferenceData + arviz.InferenceData object storing posterior, observed_data, and constant_data groups. 
+ + """ + frozen_model = freeze_dims_and_data(model) + if not transform_samples: + untransformed_model = remove_value_transforms(frozen_model) + logp = untransformed_model.logp(jacobian=False) + variables = untransformed_model.continuous_value_vars + else: + logp = frozen_model.logp(jacobian=True) + variables = frozen_model.continuous_value_vars + + mu = np.concatenate( + [np.atleast_1d(optimized_point[var.name]).ravel() for var in variables], axis=0 + ) + + f_logp, f_grad, f_hess, f_hessp = make_jax_funcs_from_graph( + logp, + use_grad=True, + use_hess=True, + use_hessp=False, + inputs=variables, + ) + + H = f_hess(mu) + H_inv = np.linalg.pinv(np.where(np.abs(H) < zero_tol, 0, -H)) + + def stabilize(x, jitter): + return x + np.eye(x.shape[0]) * jitter + + H_inv = H_inv if diag_jitter is None else stabilize(H_inv, diag_jitter) + + try: + np.linalg.cholesky(H_inv) + except np.linalg.LinAlgError: + if on_bad_cov == "error": + raise np.linalg.LinAlgError( + "Inverse Hessian not positive-semi definite at the provided point" + ) + H_inv = get_near_psd(H_inv) + if on_bad_cov == "warn": + _log.warning( + "Inverse Hessian is not positive semi-definite at the provided point, using the closest PSD " + "matrix in L1-norm instead" + ) + + posterior_dist = stats.multivariate_normal(mean=mu, cov=H_inv, allow_singular=True) + posterior_draws = posterior_dist.rvs(size=(chains, draws)) + slices, out_shapes = _get_unravel_rv_info(optimized_point, variables, frozen_model) + + if transform_samples: + posterior_draws = _create_transformed_draws( + H_inv, slices, out_shapes, posterior_draws, frozen_model, chains, draws + ) + else: + posterior_draws = [ + posterior_draws[..., idx].reshape((chains, draws, *out_shapes.get(rv, ()))) + for rv, idx in slices.items() + ] + + def make_rv_coords(rv): + coords = {"chain": range(chains), "draw": range(draws)} + extra_dims = frozen_model.named_vars_to_dims.get(rv.name) + if extra_dims is None: + return coords + return coords | {dim: list(frozen_model.coords[dim]) for dim in extra_dims} + + def make_rv_dims(rv): + dims = ["chain", "draw"] + extra_dims = frozen_model.named_vars_to_dims.get(rv.name) + if extra_dims is None: + return dims + return dims + list(extra_dims) + + idata = { + rv.name: xr.DataArray( + data=draws.squeeze(), + coords=make_rv_coords(rv), + dims=make_rv_dims(rv), + name=rv.name, + ) + for rv, draws in zip(slices.keys(), posterior_draws) + } + + coords, dims = coords_and_dims_for_inferencedata(frozen_model) + idata = az.convert_to_inference_data(idata, coords=coords, dims=dims) + + if frozen_model.deterministics: + idata.posterior = pm.compute_deterministics( + idata.posterior, + model=frozen_model, + merge_dataset=True, + progressbar=progressbar, + compile_kwargs={"mode": mode}, + ) + + observed_data = dict_to_dataset( + find_observations(frozen_model), + library=pm, + coords=coords, + dims=dims, + default_dims=[], + ) + + constant_data = dict_to_dataset( + find_constants(frozen_model), + library=pm, + coords=coords, + dims=dims, + default_dims=[], + ) + + idata.add_groups( + {"observed_data": observed_data, "constant_data": constant_data}, + coords=coords, + dims=dims, + ) + + return idata + + +def make_jax_funcs_from_graph( + graph: TensorVariable, + use_grad: bool, + use_hess: bool, + use_hessp: bool, + inputs: list[TensorVariable] | None = None, +) -> tuple[Callable, ...]: + if inputs is None: + from pymc.pytensorf import inputvars + + inputs = inputvars(graph) + if not isinstance(inputs, list): + inputs = [inputs] + + f = cast(Callable, 
get_jaxified_graph(inputs=inputs, outputs=[graph])) + input_shapes = [x.type.shape for x in inputs] + + def at_least_tuple(x): + if isinstance(x, tuple | list): + return x + return (x,) + + assert all([xi is not None for x in input_shapes for xi in at_least_tuple(x)]) + + def f_jax(x): + args = [] + cursor = 0 + for shape in input_shapes: + n_elements = int(np.prod(shape)) + s = slice(cursor, cursor + n_elements) + args.append(x[s].reshape(shape)) + cursor += n_elements + return f(*args)[0] + + f_logp = jax.jit(f_jax) + + f_grad = None + f_hess = None + f_hessp = None + + if use_grad: + _f_grad_jax = jax.grad(f_jax) + + def f_grad_jax(x): + return jax.numpy.stack(_f_grad_jax(x)) + + f_grad = jax.jit(f_grad_jax) + + if use_hessp: + if not use_grad: + raise ValueError("Cannot ask for Hessian without asking for Gradients") + + def f_hessp_jax(x, p): + y, u = jax.jvp(f_grad_jax, (x,), (p,)) + return jax.numpy.stack(u) + + f_hessp = jax.jit(f_hessp_jax) + + if use_hess: + if not use_grad: + raise ValueError("Cannot ask for Hessian without asking for Gradients") + _f_hess_jax = jax.jacfwd(f_grad_jax) + + def f_hess_jax(x): + return jax.numpy.stack(_f_hess_jax(x)) + + f_hess = jax.jit(f_hess_jax) + + return f_logp, f_grad, f_hess, f_hessp + + +def find_MAP( + model: pm.Model, + method: minimize_method, + use_grad: bool | None = None, + use_hessp: bool | None = None, + use_hess: bool | None = None, + initvals: dict | None = None, + random_seed: int | np.random.Generator | None = None, + return_raw: bool = False, + jitter_rvs: list[TensorVariable] | None = None, + progressbar: bool = True, + include_transformed: bool = True, + **optimizer_kwargs, +) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], OptimizeResult]: + """ + Fit a PyMC model via maximum a posteriori (MAP) estimation using JAX and scipy.minimize. + + Parameters + ---------- + model : pm.Model + The PyMC model to be fitted. + method : str + The optimization method to use. See scipy.optimize.minimize documentation for details. + use_grad : bool | None, optional + Whether to use gradients in the optimization. Defaults to None, which determines this automatically based on + the ``method``. + use_hessp : bool | None, optional + Whether to use Hessian-vector products in the optimization. Defaults to None, which determines this automatically based on + the ``method``. + use_hess : bool | None, optional + Whether to use the Hessian matrix in the optimization. Defaults to None, which determines this automatically based on + the ``method``. + initvals : None | dict, optional + Initial values for the model parameters, as str:ndarray key-value pairs. Paritial initialization is permitted. + If None, the model's default initial values are used. + random_seed : None | int | np.random.Generator, optional + Seed for the random number generator or a numpy Generator for reproducibility + return_raw: bool | False, optinal + Whether to also return the full output of `scipy.optimize.minimize` + jitter : bool, optional + Whether to add jitter to the initial values. Defaults to False. + progressbar : bool, optional + Whether to display a progress bar during optimization. Defaults to True. + include_transformed: bool, optional + Whether to include transformed variable values in the returned dictionary. Defaults to True. + **optimizer_kwargs + Additional keyword arguments to pass to the ``scipy.optimize.minimize`` function. 
+ + Returns + ------- + optimizer_result: dict[str, np.ndarray] or tuple[dict[str, np.ndarray], OptimizerResult] + Dictionary with names of random variables as keys, and optimization results as values. If return_raw is True, + also returns the object returned by ``scipy.optimize.minimize``. + """ + frozen_model = freeze_dims_and_data(model) + + if jitter_rvs is None: + jitter_rvs = [] + + ipfn = make_initial_point_fn( + model=frozen_model, + jitter_rvs=set(jitter_rvs), + return_transformed=True, + overrides=initvals, + ) + + start_dict = ipfn(random_seed) + vars_dict = {var.name: var for var in frozen_model.continuous_value_vars} + initial_params = DictToArrayBijection.map( + {var_name: value for var_name, value in start_dict.items() if var_name in vars_dict} + ) + + inputs = [frozen_model.values_to_rvs[vars_dict[x]] for x in start_dict.keys()] + inputs = [frozen_model.rvs_to_values[x] for x in inputs] + + logp_factors = frozen_model.logp(sum=False, jacobian=False) + neg_logp = -pt.sum([pt.sum(factor) for factor in logp_factors]) + + f_logp, f_grad, f_hess, f_hessp = make_jax_funcs_from_graph( + neg_logp, use_grad, use_hess, use_hessp, inputs=inputs + ) + + args = optimizer_kwargs.pop("args", None) + + optimizer_result = minimize( + f=f_logp, + x0=cast(np.ndarray[float], initial_params.data), + args=args, + jac=f_grad, + hess=f_hess, + hessp=f_hessp, + progressbar=progressbar, + method=method, + **optimizer_kwargs, + ) + + initial_point = RaveledVars(optimizer_result.x, initial_params.point_map_info) + unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed) + unobserved_vars_values = model.compile_fn(unobserved_vars)( + DictToArrayBijection.rmap(initial_point, start_dict) + ) + optimized_point = { + var.name: value for var, value in zip(unobserved_vars, unobserved_vars_values) + } + + if return_raw: + return optimized_point, optimizer_result + + return optimized_point diff --git a/tests/test_jax_find_map.py b/tests/test_jax_find_map.py new file mode 100644 index 00000000..dd70501f --- /dev/null +++ b/tests/test_jax_find_map.py @@ -0,0 +1,186 @@ +import numpy as np +import pymc as pm +import pytensor.tensor as pt +import pytest + +from inference.jax_find_map import find_MAP, fit_laplace, make_jax_funcs_from_graph + +pytest.importorskip("jax") + + +@pytest.fixture(scope="session") +def rng(): + seed = sum(map(ord, "test_fit_map")) + return np.random.default_rng(seed) + + +def test_jax_functions_from_graph(): + x = pt.tensor("x", shape=(2,)) + + def compute_z(x): + z1 = x[0] ** 2 + 2 + z2 = x[0] * x[1] + 3 + return z1, z2 + + z = pt.stack(compute_z(x)) + f_z, f_grad, f_hess, f_hessp = make_jax_funcs_from_graph( + z.sum(), use_grad=True, use_hess=True, use_hessp=True + ) + + x_val = np.array([1.0, 2.0]) + expected_z = sum(compute_z(x_val)) + + z_jax = f_z(x_val) + np.testing.assert_allclose(z_jax, expected_z) + + grad_val = np.array(f_grad(x_val)) + np.testing.assert_allclose(grad_val.squeeze(), np.array([2 * x_val[0] + x_val[1], x_val[0]])) + + hess_val = np.array(f_hess(x_val)) + np.testing.assert_allclose(hess_val.squeeze(), np.array([[2, 1], [1, 0]])) + + hessp_val = np.array(f_hessp(x_val, np.array([1.0, 0.0]))) + np.testing.assert_allclose(hessp_val.squeeze(), np.array([2, 1])) + + +@pytest.mark.parametrize( + "method, use_grad, use_hess", + [ + ("nelder-mead", False, False), + ("powell", False, False), + ("CG", True, False), + ("BFGS", True, False), + ("L-BFGS-B", True, False), + ("TNC", True, False), + ("SLSQP", True, False), + ("dogleg", True, True), + 
("trust-ncg", True, True), + ("trust-exact", True, True), + ("trust-krylov", True, True), + ("trust-constr", True, True), + ], +) +def test_JAX_map(method, use_grad, use_hess, rng): + with pm.Model() as m: + mu = pm.Normal("mu") + sigma = pm.Exponential("sigma", 1) + pm.Normal("y_hat", mu=mu, sigma=sigma, observed=rng.normal(loc=3, scale=1.5, size=100)) + + extra_kwargs = {} + if method == "dogleg": + # HACK -- dogleg requires that the hessian of the objective function is PSD, so we have to pick a point + # where this is true + extra_kwargs = {"initvals": {"mu": 2, "sigma_log__": 1}} + + optimized_point = find_MAP( + m, method, **extra_kwargs, use_grad=use_grad, use_hess=use_hess, progressbar=False + ) + mu_hat, log_sigma_hat = optimized_point["mu"], optimized_point["sigma_log__"] + + assert np.isclose(mu_hat, 3, atol=0.5) + assert np.isclose(np.exp(log_sigma_hat), 1.5, atol=0.5) + + +@pytest.mark.parametrize( + "transform_samples", + [True, False], + ids=["transformed", "untransformed"], +) +def test_fit_laplace_coords(rng, transform_samples): + coords = {"city": ["A", "B", "C"], "obs_idx": np.arange(100)} + with pm.Model(coords=coords) as model: + mu = pm.Normal("mu", mu=3, sigma=0.5, dims=["city"]) + sigma = pm.Normal("sigma", mu=1.5, sigma=0.5, dims=["city"]) + obs = pm.Normal( + "obs", + mu=mu, + sigma=sigma, + observed=rng.normal(loc=3, scale=1.5, size=(100, 3)), + dims=["obs_idx", "city"], + ) + optimized_point = find_MAP( + model, + "Newton-CG", + use_grad=True, + progressbar=False, + ) + + for value in optimized_point.values(): + assert value.shape == (3,) + + idata = fit_laplace( + optimized_point, + model, + transform_samples=transform_samples, + progressbar=False, + ) + + np.testing.assert_allclose(np.mean(idata.posterior.mu, axis=1), np.full((2, 3), 3), atol=0.3) + np.testing.assert_allclose( + np.mean(idata.posterior.sigma, axis=1), np.full((2, 3), 1.5), atol=0.3 + ) + + +def test_fit_laplace_ragged_coords(rng): + coords = {"city": ["A", "B", "C"], "feature": [0, 1], "obs_idx": np.arange(100)} + with pm.Model(coords=coords) as ragged_dim_model: + X = pm.Data("X", np.ones((100, 2)), dims=["obs_idx", "feature"]) + beta = pm.Normal( + "beta", mu=[[-100.0, 100.0], [-100.0, 100.0], [-100.0, 100.0]], dims=["city", "feature"] + ) + mu = pm.Deterministic( + "mu", (X[:, None, :] * beta[None]).sum(axis=-1), dims=["obs_idx", "feature"] + ) + sigma = pm.Normal("sigma", mu=1.5, sigma=0.5, dims=["city"]) + + obs = pm.Normal( + "obs", + mu=mu, + sigma=sigma, + observed=rng.normal(loc=3, scale=1.5, size=(100, 3)), + dims=["obs_idx", "city"], + ) + + optimized_point, _ = find_MAP( + ragged_dim_model, "Newton-CG", use_grad=True, progressbar=False, return_raw=True + ) + + idata = fit_laplace(optimized_point, ragged_dim_model, progressbar=False) + + assert idata["posterior"].beta.shape[-2:] == (3, 2) + assert idata["posterior"].sigma.shape[-1:] == (3,) + + # Check that everything got unraveled correctly -- feature 0 should be strictly negative, feature 1 + # strictly positive + assert (idata["posterior"].beta.sel(feature=0).to_numpy() < 0).all() + assert (idata["posterior"].beta.sel(feature=1).to_numpy() > 0).all() + + +@pytest.mark.parametrize( + "transform_samples", + [True, False], + ids=["transformed", "untransformed"], +) +def test_fit_laplace(transform_samples): + with pm.Model() as simp_model: + mu = pm.Normal("mu", mu=3, sigma=0.5) + sigma = pm.Normal("sigma", mu=1.5, sigma=0.5) + obs = pm.Normal( + "obs", + mu=mu, + sigma=sigma, + observed=np.random.default_rng().normal(loc=3, 
scale=1.5, size=(10000,)), + ) + + optimized_point = find_MAP( + simp_model, + "Newton-CG", + use_grad=True, + progressbar=False, + ) + idata = fit_laplace( + optimized_point, simp_model, transform_samples=transform_samples, progressbar=False + ) + + np.testing.assert_allclose(np.mean(idata.posterior.mu, axis=1), np.full((2,), 3), atol=0.1) + np.testing.assert_allclose(np.mean(idata.posterior.sigma, axis=1), np.full((2,), 1.5), atol=0.1) From 7ed3b2f261dbd9ebff1774e3cfebed3bbdf4f238 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 27 Oct 2024 14:18:31 +0800 Subject: [PATCH 02/21] add `better_optimize` to CI envs --- conda-envs/environment-test.yml | 1 + conda-envs/windows-environment-test.yml | 1 + 2 files changed, 2 insertions(+) diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml index 4deda063..0723431c 100644 --- a/conda-envs/environment-test.yml +++ b/conda-envs/environment-test.yml @@ -13,3 +13,4 @@ dependencies: - pymc>=5.17.0 # CI was failing to resolve - blackjax - scikit-learn + - better_optimize diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml index 4deda063..0723431c 100644 --- a/conda-envs/windows-environment-test.yml +++ b/conda-envs/windows-environment-test.yml @@ -13,3 +13,4 @@ dependencies: - pymc>=5.17.0 # CI was failing to resolve - blackjax - scikit-learn + - better_optimize From e412f6f264f9a26e1674039dfb1cb42eb922b28f Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 27 Oct 2024 14:37:37 +0800 Subject: [PATCH 03/21] Fix relative import --- tests/test_jax_find_map.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_jax_find_map.py b/tests/test_jax_find_map.py index dd70501f..f3f9d0aa 100644 --- a/tests/test_jax_find_map.py +++ b/tests/test_jax_find_map.py @@ -3,7 +3,11 @@ import pytensor.tensor as pt import pytest -from inference.jax_find_map import find_MAP, fit_laplace, make_jax_funcs_from_graph +from pymc_experimental.inference.jax_find_map import ( + find_MAP, + fit_laplace, + make_jax_funcs_from_graph, +) pytest.importorskip("jax") From f9b625872166121b7a39579718d2840a3e5e3add Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 27 Oct 2024 14:41:35 +0800 Subject: [PATCH 04/21] Remove `find_MAP` import from module-level `__init__.py` --- pymc_experimental/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pymc_experimental/__init__.py b/pymc_experimental/__init__.py index 977e593c..c19097e6 100644 --- a/pymc_experimental/__init__.py +++ b/pymc_experimental/__init__.py @@ -16,7 +16,6 @@ from pymc_experimental import gp, statespace, utils from pymc_experimental.distributions import * from pymc_experimental.inference.fit import fit -from pymc_experimental.inference.jax_find_map import find_MAP from pymc_experimental.model.marginal.marginal_model import MarginalModel, marginalize from pymc_experimental.model.model_api import as_model from pymc_experimental.version import __version__ From ad3abd9e276e7c35a7257803d21643417eb91573 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Sun, 27 Oct 2024 16:57:19 +0800 Subject: [PATCH 05/21] Update docstring --- pymc_experimental/inference/jax_find_map.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pymc_experimental/inference/jax_find_map.py b/pymc_experimental/inference/jax_find_map.py index b1a41649..e9f44779 100644 --- a/pymc_experimental/inference/jax_find_map.py +++ b/pymc_experimental/inference/jax_find_map.py @@ -162,7 +162,7 @@ def fit_laplace( ) f_logp, 
f_grad, f_hess, f_hessp = make_jax_funcs_from_graph( - logp, + cast(TensorVariable, logp), use_grad=True, use_hess=True, use_hessp=False, @@ -376,8 +376,8 @@ def find_MAP( Seed for the random number generator or a numpy Generator for reproducibility return_raw: bool | False, optinal Whether to also return the full output of `scipy.optimize.minimize` - jitter : bool, optional - Whether to add jitter to the initial values. Defaults to False. + jitter_rvs : list of TensorVariables, optional + Variables whose initial values should be jittered. If None, all variables are jittered. progressbar : bool, optional Whether to display a progress bar during optimization. Defaults to True. include_transformed: bool, optional From be1d790fc4304996f1ebe1dcbb28c39e7a5cf425 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Sun, 27 Oct 2024 17:10:25 +0800 Subject: [PATCH 06/21] Allow calling `find_MAP` inside model context without model argument --- pymc_experimental/inference/jax_find_map.py | 6 ++- tests/test_jax_find_map.py | 46 ++++++++++----------- 2 files changed, 27 insertions(+), 25 deletions(-) diff --git a/pymc_experimental/inference/jax_find_map.py b/pymc_experimental/inference/jax_find_map.py index e9f44779..ac21234d 100644 --- a/pymc_experimental/inference/jax_find_map.py +++ b/pymc_experimental/inference/jax_find_map.py @@ -338,8 +338,9 @@ def f_hess_jax(x): def find_MAP( - model: pm.Model, method: minimize_method, + *, + model: pm.Model | None = None, use_grad: bool | None = None, use_hessp: bool | None = None, use_hess: bool | None = None, @@ -357,7 +358,7 @@ def find_MAP( Parameters ---------- model : pm.Model - The PyMC model to be fitted. + The PyMC model to be fit. If None, the current model context is used. method : str The optimization method to use. See scipy.optimize.minimize documentation for details. use_grad : bool | None, optional @@ -391,6 +392,7 @@ def find_MAP( Dictionary with names of random variables as keys, and optimization results as values. If return_raw is True, also returns the object returned by ``scipy.optimize.minimize``. 
""" + model = pm.modelcontext(model) frozen_model = freeze_dims_and_data(model) if jitter_rvs is None: diff --git a/tests/test_jax_find_map.py b/tests/test_jax_find_map.py index f3f9d0aa..568a6a07 100644 --- a/tests/test_jax_find_map.py +++ b/tests/test_jax_find_map.py @@ -65,20 +65,20 @@ def compute_z(x): ], ) def test_JAX_map(method, use_grad, use_hess, rng): - with pm.Model() as m: - mu = pm.Normal("mu") - sigma = pm.Exponential("sigma", 1) - pm.Normal("y_hat", mu=mu, sigma=sigma, observed=rng.normal(loc=3, scale=1.5, size=100)) - extra_kwargs = {} if method == "dogleg": # HACK -- dogleg requires that the hessian of the objective function is PSD, so we have to pick a point # where this is true extra_kwargs = {"initvals": {"mu": 2, "sigma_log__": 1}} - optimized_point = find_MAP( - m, method, **extra_kwargs, use_grad=use_grad, use_hess=use_hess, progressbar=False - ) + with pm.Model() as m: + mu = pm.Normal("mu") + sigma = pm.Exponential("sigma", 1) + pm.Normal("y_hat", mu=mu, sigma=sigma, observed=rng.normal(loc=3, scale=1.5, size=100)) + + optimized_point = find_MAP( + method=method, **extra_kwargs, use_grad=use_grad, use_hess=use_hess, progressbar=False + ) mu_hat, log_sigma_hat = optimized_point["mu"], optimized_point["sigma_log__"] assert np.isclose(mu_hat, 3, atol=0.5) @@ -102,12 +102,12 @@ def test_fit_laplace_coords(rng, transform_samples): observed=rng.normal(loc=3, scale=1.5, size=(100, 3)), dims=["obs_idx", "city"], ) - optimized_point = find_MAP( - model, - "Newton-CG", - use_grad=True, - progressbar=False, - ) + + optimized_point = find_MAP( + method="Newton-CG", + use_grad=True, + progressbar=False, + ) for value in optimized_point.values(): assert value.shape == (3,) @@ -145,9 +145,9 @@ def test_fit_laplace_ragged_coords(rng): dims=["obs_idx", "city"], ) - optimized_point, _ = find_MAP( - ragged_dim_model, "Newton-CG", use_grad=True, progressbar=False, return_raw=True - ) + optimized_point, _ = find_MAP( + method="Newton-CG", use_grad=True, progressbar=False, return_raw=True + ) idata = fit_laplace(optimized_point, ragged_dim_model, progressbar=False) @@ -176,12 +176,12 @@ def test_fit_laplace(transform_samples): observed=np.random.default_rng().normal(loc=3, scale=1.5, size=(10000,)), ) - optimized_point = find_MAP( - simp_model, - "Newton-CG", - use_grad=True, - progressbar=False, - ) + optimized_point = find_MAP( + method="Newton-CG", + use_grad=True, + progressbar=False, + ) + idata = fit_laplace( optimized_point, simp_model, transform_samples=transform_samples, progressbar=False ) From 923eb26f43c26c4a2589c071cd3d745d78cf0a61 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Sun, 27 Oct 2024 17:12:16 +0800 Subject: [PATCH 07/21] Required patched better_optimize --- conda-envs/environment-test.yml | 2 +- conda-envs/windows-environment-test.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml index 0723431c..2c84fb6e 100644 --- a/conda-envs/environment-test.yml +++ b/conda-envs/environment-test.yml @@ -13,4 +13,4 @@ dependencies: - pymc>=5.17.0 # CI was failing to resolve - blackjax - scikit-learn - - better_optimize + - better_optimize>=0.0.10 diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml index 0723431c..2c84fb6e 100644 --- a/conda-envs/windows-environment-test.yml +++ b/conda-envs/windows-environment-test.yml @@ -13,4 +13,4 @@ dependencies: - pymc>=5.17.0 # CI was failing to resolve - blackjax - scikit-learn - - better_optimize + - 
better_optimize>=0.0.10 From f705d43792c75df3bc518207f2f023c421b252ba Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Sun, 1 Dec 2024 02:21:51 +0800 Subject: [PATCH 08/21] in-progress refactor --- pymc_experimental/inference/jax_find_map.py | 227 +++++++++++--------- 1 file changed, 123 insertions(+), 104 deletions(-) diff --git a/pymc_experimental/inference/jax_find_map.py b/pymc_experimental/inference/jax_find_map.py index ac21234d..d9a80079 100644 --- a/pymc_experimental/inference/jax_find_map.py +++ b/pymc_experimental/inference/jax_find_map.py @@ -23,6 +23,7 @@ from pymc.initial_point import make_initial_point_fn from pymc.model.transform.conditioning import remove_value_transforms from pymc.model.transform.optimization import freeze_dims_and_data +from pymc.pytensorf import join_nonshared_inputs from pymc.sampling.jax import get_jaxified_graph from pymc.util import get_default_varnames from pytensor.tensor import TensorVariable @@ -32,13 +33,12 @@ _log = logging.getLogger(__name__) -def get_near_psd(A: np.ndarray) -> np.ndarray: +def get_nearest_psd(A: np.ndarray) -> np.ndarray: """ Compute the nearest positive semi-definite matrix to a given matrix. - This function takes a square matrix and returns the nearest positive - semi-definite matrix using eigenvalue decomposition. It ensures all - eigenvalues are non-negative. The "nearest" matrix is defined in terms + This function takes a square matrix and returns the nearest positive semi-definite matrix using + eigenvalue decomposition. It ensures all eigenvalues are non-negative. The "nearest" matrix is defined in terms of the Frobenius norm. Parameters @@ -58,23 +58,13 @@ def get_near_psd(A: np.ndarray) -> np.ndarray: return eigvec @ np.diag(eigval) @ eigvec.T -def _get_unravel_rv_info(optimized_point, variables, model): - cursor = 0 - slices = {} - out_shapes = {} - - for i, var in enumerate(variables): - raveled_shape = np.prod(optimized_point[var.name].shape).astype(int) - rv = model.values_to_rvs.get(var, var) - - idx = slice(cursor, cursor + raveled_shape) - slices[rv] = idx - out_shapes[rv] = tuple( - [len(model.coords[dim]) for dim in model.named_vars_to_dims.get(rv.name, [])] - ) - cursor += raveled_shape +def _unconstrained_vector_to_constrained_rvs(model): + constrained_rvs, unconstrained_vector = join_nonshared_inputs( + model.initial_point(), inputs=model.value_vars, outputs=model.unobserved_value_vars + ) - return slices, out_shapes + unconstrained_vector.name = "unconstrained_vector" + return constrained_rvs, unconstrained_vector def _create_transformed_draws(H_inv, slices, out_shapes, posterior_draws, model, chains, draws): @@ -94,37 +84,24 @@ def _create_transformed_draws(H_inv, slices, out_shapes, posterior_draws, model, return f_untransform(posterior_draws) -def fit_laplace( +def jax_fit_mvn_to_MAP( optimized_point: dict[str, np.ndarray], model: pm.Model, - chains: int = 2, - draws: int = 500, on_bad_cov: Literal["warn", "error", "ignore"] = "ignore", transform_samples: bool = True, zero_tol: float = 1e-8, diag_jitter: float | None = 1e-8, - progressbar: bool = True, - mode: str = "JAX", -) -> az.InferenceData: +) -> tuple[RaveledVars, np.ndarray]: """ - Compute the Laplace approximation of the posterior distribution. - - The posterior distribution will be approximated as a Gaussian - distribution centered at the posterior mode. - The covariance is the inverse of the negative Hessian matrix of - the log-posterior evaluated at the mode. 
+ Create a multivariate normal distribution using the inverse of the negative Hessian matrix of the log-posterior + evaluated at the MAP estimate. This is the basis of the Laplace approximation. Parameters ---------- optimized_point : dict[str, np.ndarray] - Local maximum a posteriori (MAP) point returned from pymc.find_MAP - or jax_tools.fit_map + Local maximum a posteriori (MAP) point returned from pymc.find_MAP or jax_tools.fit_map model : Model A PyMC model - chains : int - The number of sampling chains running in parallel. Default is 2. - draws : int - The number of samples to draw from the approximated posterior. Default is 500. on_bad_cov : str, one of 'ignore', 'warn', or 'error', default: 'ignore' What to do when ``H_inv`` (inverse Hessian) is not positive semi-definite. If 'ignore' or 'warn', the closest positive-semi-definite matrix to ``H_inv`` (in L1 norm) will be returned. @@ -137,18 +114,17 @@ def fit_laplace( diag_jitter: float | None A small value added to the diagonal of the inverse Hessian matrix to ensure it is positive semi-definite. If None, no jitter is added. Default is 1e-8. - progressbar : bool - Whether or not to display progress bar. Default is True. - mode : str - Computation backend mode. Default is "JAX". Returns ------- - InferenceData - arviz.InferenceData object storing posterior, observed_data, and constant_data groups. + map_estimate: RaveledVars + The MAP estimate of the model parameters, raveled into a 1D array. + inverse_hessian: np.ndarray + The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate. """ frozen_model = freeze_dims_and_data(model) + if not transform_samples: untransformed_model = remove_value_transforms(frozen_model) logp = untransformed_model.logp(jacobian=False) @@ -157,19 +133,17 @@ def fit_laplace( logp = frozen_model.logp(jacobian=True) variables = frozen_model.continuous_value_vars - mu = np.concatenate( - [np.atleast_1d(optimized_point[var.name]).ravel() for var in variables], axis=0 + mu = DictToArrayBijection.map(optimized_point) + + [neg_logp], flat_inputs = join_nonshared_inputs( + point=frozen_model.initial_point(), outputs=[-logp], inputs=variables ) f_logp, f_grad, f_hess, f_hessp = make_jax_funcs_from_graph( - cast(TensorVariable, logp), - use_grad=True, - use_hess=True, - use_hessp=False, - inputs=variables, + neg_logp, use_grad=True, use_hess=True, use_hessp=False, inputs=[flat_inputs] ) - H = f_hess(mu) + H = -f_hess(mu.data) H_inv = np.linalg.pinv(np.where(np.abs(H) < zero_tol, 0, -H)) def stabilize(x, jitter): @@ -184,65 +158,103 @@ def stabilize(x, jitter): raise np.linalg.LinAlgError( "Inverse Hessian not positive-semi definite at the provided point" ) - H_inv = get_near_psd(H_inv) + H_inv = get_nearest_psd(H_inv) if on_bad_cov == "warn": _log.warning( "Inverse Hessian is not positive semi-definite at the provided point, using the closest PSD " "matrix in L1-norm instead" ) - posterior_dist = stats.multivariate_normal(mean=mu, cov=H_inv, allow_singular=True) + return mu, H_inv + + +def jax_laplace( + mu: RaveledVars, + H_inv: np.ndarray, + model: pm.Model, + chains: int = 2, + draws: int = 500, + transform_samples: bool = True, + progressbar: bool = True, +) -> az.InferenceData: + """ + + Parameters + ---------- + mu + H_inv + model : Model + A PyMC model + chains : int + The number of sampling chains running in parallel. Default is 2. + draws : int + The number of samples to draw from the approximated posterior. Default is 500. 
+ transform_samples : bool + Whether to transform the samples back to the original parameter space. Default is True. + + Returns + ------- + idata: az.InferenceData + An InferenceData object containing the approximated posterior samples. + """ + posterior_dist = stats.multivariate_normal(mean=mu.data, cov=H_inv, allow_singular=True) posterior_draws = posterior_dist.rvs(size=(chains, draws)) - slices, out_shapes = _get_unravel_rv_info(optimized_point, variables, frozen_model) if transform_samples: - posterior_draws = _create_transformed_draws( - H_inv, slices, out_shapes, posterior_draws, frozen_model, chains, draws - ) + constrained_rvs, unconstrained_vector = _unconstrained_vector_to_constrained_rvs(model) + f_constrain = get_jaxified_graph(inputs=[unconstrained_vector], outputs=constrained_rvs) + + posterior_draws = jax.jit(jax.vmap(jax.vmap(f_constrain)))(posterior_draws) + else: + info = mu.point_map_info + flat_shapes = [np.prod(shape).astype(int) for _, shape, _ in info] + slices = [ + slice(sum(flat_shapes[:i]), sum(flat_shapes[: i + 1])) for i in range(len(flat_shapes)) + ] + posterior_draws = [ - posterior_draws[..., idx].reshape((chains, draws, *out_shapes.get(rv, ()))) - for rv, idx in slices.items() + posterior_draws[..., idx].reshape((chains, draws, *shape)).astype(dtype) + for idx, (name, shape, dtype) in zip(slices, info) ] - def make_rv_coords(rv): + def make_rv_coords(name): coords = {"chain": range(chains), "draw": range(draws)} - extra_dims = frozen_model.named_vars_to_dims.get(rv.name) + extra_dims = model.named_vars_to_dims.get(name) if extra_dims is None: return coords - return coords | {dim: list(frozen_model.coords[dim]) for dim in extra_dims} + return coords | {dim: list(model.coords[dim]) for dim in extra_dims} - def make_rv_dims(rv): + def make_rv_dims(name): dims = ["chain", "draw"] - extra_dims = frozen_model.named_vars_to_dims.get(rv.name) + extra_dims = model.named_vars_to_dims.get(name) if extra_dims is None: return dims return dims + list(extra_dims) idata = { - rv.name: xr.DataArray( + name: xr.DataArray( data=draws.squeeze(), - coords=make_rv_coords(rv), - dims=make_rv_dims(rv), - name=rv.name, + coords=make_rv_coords(name), + dims=make_rv_dims(name), + name=name, ) - for rv, draws in zip(slices.keys(), posterior_draws) + for (name, _, _), draws in zip(mu.point_map_info, posterior_draws) } - coords, dims = coords_and_dims_for_inferencedata(frozen_model) + coords, dims = coords_and_dims_for_inferencedata(model) idata = az.convert_to_inference_data(idata, coords=coords, dims=dims) - if frozen_model.deterministics: + if model.deterministics: idata.posterior = pm.compute_deterministics( idata.posterior, - model=frozen_model, + model=model, merge_dataset=True, progressbar=progressbar, - compile_kwargs={"mode": mode}, ) observed_data = dict_to_dataset( - find_observations(frozen_model), + find_observations(model), library=pm, coords=coords, dims=dims, @@ -250,7 +262,7 @@ def make_rv_dims(rv): ) constant_data = dict_to_dataset( - find_constants(frozen_model), + find_constants(model), library=pm, coords=coords, dims=dims, @@ -266,6 +278,29 @@ def make_rv_dims(rv): return idata +def fit_laplace( + optimized_point: dict[str, np.ndarray], + model: pm.Model, + chains: int = 2, + draws: int = 500, + on_bad_cov: Literal["warn", "error", "ignore"] = "ignore", + transform_samples: bool = True, + zero_tol: float = 1e-8, + diag_jitter: float | None = 1e-8, + progressbar: bool = True, +) -> az.InferenceData: + mu, H_inv = jax_fit_mvn_to_MAP( + optimized_point, + model, + 
on_bad_cov, + transform_samples, + zero_tol, + diag_jitter, + ) + + return jax_laplace(mu, H_inv, model, chains, draws, transform_samples, progressbar) + + def make_jax_funcs_from_graph( graph: TensorVariable, use_grad: bool, @@ -280,34 +315,19 @@ def make_jax_funcs_from_graph( if not isinstance(inputs, list): inputs = [inputs] - f = cast(Callable, get_jaxified_graph(inputs=inputs, outputs=[graph])) - input_shapes = [x.type.shape for x in inputs] - - def at_least_tuple(x): - if isinstance(x, tuple | list): - return x - return (x,) + f_tuple = cast(Callable, get_jaxified_graph(inputs=inputs, outputs=[graph])) - assert all([xi is not None for x in input_shapes for xi in at_least_tuple(x)]) + def f(*args, **kwargs): + return f_tuple(*args, **kwargs)[0] - def f_jax(x): - args = [] - cursor = 0 - for shape in input_shapes: - n_elements = int(np.prod(shape)) - s = slice(cursor, cursor + n_elements) - args.append(x[s].reshape(shape)) - cursor += n_elements - return f(*args)[0] - - f_logp = jax.jit(f_jax) + f_logp = jax.jit(f) f_grad = None f_hess = None f_hessp = None if use_grad: - _f_grad_jax = jax.grad(f_jax) + _f_grad_jax = jax.grad(f) def f_grad_jax(x): return jax.numpy.stack(_f_grad_jax(x)) @@ -411,14 +431,12 @@ def find_MAP( {var_name: value for var_name, value in start_dict.items() if var_name in vars_dict} ) - inputs = [frozen_model.values_to_rvs[vars_dict[x]] for x in start_dict.keys()] - inputs = [frozen_model.rvs_to_values[x] for x in inputs] - - logp_factors = frozen_model.logp(sum=False, jacobian=False) - neg_logp = -pt.sum([pt.sum(factor) for factor in logp_factors]) + [neg_logp], inputs = join_nonshared_inputs( + point=start_dict, outputs=[-frozen_model.logp()], inputs=frozen_model.continuous_value_vars + ) f_logp, f_grad, f_hess, f_hessp = make_jax_funcs_from_graph( - neg_logp, use_grad, use_hess, use_hessp, inputs=inputs + neg_logp, use_grad, use_hess, use_hessp, inputs=[inputs] ) args = optimizer_kwargs.pop("args", None) @@ -435,11 +453,12 @@ def find_MAP( **optimizer_kwargs, ) - initial_point = RaveledVars(optimizer_result.x, initial_params.point_map_info) + raveled_optimized = RaveledVars(optimizer_result.x, initial_params.point_map_info) unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed) unobserved_vars_values = model.compile_fn(unobserved_vars)( - DictToArrayBijection.rmap(initial_point, start_dict) + DictToArrayBijection.rmap(raveled_optimized) ) + optimized_point = { var.name: value for var, value in zip(unobserved_vars, unobserved_vars_values) } From a23762bcdf4ac5e1145777d139653ba3766f63dc Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Tue, 3 Dec 2024 13:25:18 +0800 Subject: [PATCH 09/21] More refactor --- pymc_experimental/inference/jax_find_map.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/pymc_experimental/inference/jax_find_map.py b/pymc_experimental/inference/jax_find_map.py index d9a80079..1554ad36 100644 --- a/pymc_experimental/inference/jax_find_map.py +++ b/pymc_experimental/inference/jax_find_map.py @@ -58,6 +58,22 @@ def get_nearest_psd(A: np.ndarray) -> np.ndarray: return eigvec @ np.diag(eigval) @ eigvec.T +def unobserved_value_vars(model): + vars = [] + transformed_rvs = [] + for rv in model.free_RVs: + value_var = model.rvs_to_values[rv] + transform = model.rvs_to_transforms[rv] + if transform is not None: + transformed_rvs.append(rv) + vars.append(value_var) + + # Remove rvs from untransformed values graph + untransformed_vars = 
model.replace_rvs_by_values(transformed_rvs) + + return vars + untransformed_vars + + def _unconstrained_vector_to_constrained_rvs(model): constrained_rvs, unconstrained_vector = join_nonshared_inputs( model.initial_point(), inputs=model.value_vars, outputs=model.unobserved_value_vars @@ -133,7 +149,9 @@ def jax_fit_mvn_to_MAP( logp = frozen_model.logp(jacobian=True) variables = frozen_model.continuous_value_vars - mu = DictToArrayBijection.map(optimized_point) + variable_names = {var.name for var in variables} + optimized_free_params = {k: v for k, v in optimized_point.items() if k in variable_names} + mu = DictToArrayBijection.map(optimized_free_params) [neg_logp], flat_inputs = join_nonshared_inputs( point=frozen_model.initial_point(), outputs=[-logp], inputs=variables From 2d2140363087e86b59f69f2dc80fe6112a1c1d92 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Tue, 3 Dec 2024 22:51:27 +0800 Subject: [PATCH 10/21] Generalize code to use any pytensor backend --- .../{jax_find_map.py => find_map.py} | 361 +++++++++++++----- tests/test_jax_find_map.py | 44 ++- 2 files changed, 292 insertions(+), 113 deletions(-) rename pymc_experimental/inference/{jax_find_map.py => find_map.py} (60%) diff --git a/pymc_experimental/inference/jax_find_map.py b/pymc_experimental/inference/find_map.py similarity index 60% rename from pymc_experimental/inference/jax_find_map.py rename to pymc_experimental/inference/find_map.py index 1554ad36..bcbcad23 100644 --- a/pymc_experimental/inference/jax_find_map.py +++ b/pymc_experimental/inference/find_map.py @@ -24,8 +24,8 @@ from pymc.model.transform.conditioning import remove_value_transforms from pymc.model.transform.optimization import freeze_dims_and_data from pymc.pytensorf import join_nonshared_inputs -from pymc.sampling.jax import get_jaxified_graph from pymc.util import get_default_varnames +from pytensor.compile import Function from pytensor.tensor import TensorVariable from scipy import stats from scipy.optimize import OptimizeResult @@ -58,22 +58,6 @@ def get_nearest_psd(A: np.ndarray) -> np.ndarray: return eigvec @ np.diag(eigval) @ eigvec.T -def unobserved_value_vars(model): - vars = [] - transformed_rvs = [] - for rv in model.free_RVs: - value_var = model.rvs_to_values[rv] - transform = model.rvs_to_transforms[rv] - if transform is not None: - transformed_rvs.append(rv) - vars.append(value_var) - - # Remove rvs from untransformed values graph - untransformed_vars = model.replace_rvs_by_values(transformed_rvs) - - return vars + untransformed_vars - - def _unconstrained_vector_to_constrained_rvs(model): constrained_rvs, unconstrained_vector = join_nonshared_inputs( model.initial_point(), inputs=model.value_vars, outputs=model.unobserved_value_vars @@ -100,13 +84,220 @@ def _create_transformed_draws(H_inv, slices, out_shapes, posterior_draws, model, return f_untransform(posterior_draws) -def jax_fit_mvn_to_MAP( +def _compile_jax_gradients( + f_loss: Function, use_hess: bool, use_hessp: bool +) -> tuple[Callable | None, Callable | None]: + """ + Compile loss function gradients using JAX. + + Parameters + ---------- + f_loss: Function + The loss function to compile gradients for. Expected to be a pytensor function that returns a scalar loss, + compiled with mode="JAX". + use_hess: bool + Whether to compile a function to compute the hessian of the loss function. + use_hessp: bool + Whether to compile a function to compute the hessian-vector product of the loss function. 
+ + Returns + ------- + f_loss_and_grad: Callable + The compiled loss function and gradient function. + f_hess: Callable | None + The compiled hessian function, or None if use_hess is False. + f_hessp: Callable | None + The compiled hessian-vector product function, or None if use_hessp is False. + """ + f_hess = None + f_hessp = None + + orig_loss_fn = f_loss.vm.jit_fn + + @jax.jit + def loss_fn_jax_grad(x, *shared): + return jax.value_and_grad(lambda x: orig_loss_fn(x)[0])(x) + + f_loss_and_grad = loss_fn_jax_grad + + if use_hessp: + + def f_hessp_jax(x, p): + y, u = jax.jvp(lambda x: f_loss_and_grad(x)[1], (x,), (p,)) + return jax.numpy.stack(u) + + f_hessp = jax.jit(f_hessp_jax) + + if use_hess: + _f_hess_jax = jax.jacfwd(lambda x: f_loss_and_grad(x)[1]) + + def f_hess_jax(x): + return jax.numpy.stack(_f_hess_jax(x)) + + f_hess = jax.jit(f_hess_jax) + + return f_loss_and_grad, f_hess, f_hessp + + +def _compile_functions( + loss: TensorVariable, + inputs: list[TensorVariable], + compute_grad: bool, + compute_hess: bool, + compute_hessp: bool, + compile_kwargs: dict | None = None, +) -> list[Function] | list[Function, Function | None, Function | None]: + """ + Compile loss functions for use with scipy.optimize.minimize. + + Parameters + ---------- + loss: TensorVariable + The loss function to compile. + inputs: list[TensorVariable] + A single flat vector input variable, collecting all inputs to the loss function. Scipy optimize routines + expect the function signature to be f(x, *args), where x is a 1D array of parameters. + compute_grad: bool + Whether to compile a function that computes the gradients of the loss function. + compute_hess: bool + Whether to compile a function that computes the Hessian of the loss function. + compute_hessp: bool + Whether to compile a function that computes the Hessian-vector product of the loss function. + compile_kwargs: dict, optional + Additional keyword arguments to pass to the ``pm.compile_pymc`` function. + + Returns + ------- + f_loss: Function + + f_hess: Function | None + f_hessp: Function | None + """ + loss = pm.pytensorf.rewrite_pregrad(loss) + f_hess = None + f_hessp = None + + if compute_grad: + grads = pytensor.gradient.grad(loss, inputs) + grad = pt.concatenate([grad.ravel() for grad in grads]) + f_loss_and_grad = pm.compile_pymc(inputs, [loss, grad], **compile_kwargs) + else: + f_loss = pm.compile_pymc(inputs, loss, **compile_kwargs) + return [f_loss] + + if compute_hess: + hess = pytensor.gradient.jacobian(grad, inputs)[0] + f_hess = pm.compile_pymc(inputs, hess, **compile_kwargs) + + if compute_hessp: + p = pt.tensor("p", shape=inputs[0].type.shape) + hessp = pytensor.gradient.hessian_vector_product(loss, inputs, p) + f_hessp = pm.compile_pymc([*inputs, p], hessp[0], **compile_kwargs) + + return [f_loss_and_grad, f_hess, f_hessp] + + +def scipy_optimize_funcs_from_loss( + loss: TensorVariable, + inputs: list[TensorVariable], + initial_point_dict: dict[str, np.ndarray | float | int], + use_grad: bool, + use_hess: bool, + use_hessp: bool, + use_jax_gradients: bool = False, + compile_kwargs: dict | None = None, +) -> tuple[Callable, ...]: + """ + Compile loss functions for use with scipy.optimize.minimize. + + + Parameters + ---------- + loss: TensorVariable + The loss function to compile. + inputs: list[TensorVariable] + The input variables to the loss function. + initial_point_dict: dict[str, np.ndarray | float | int] + Dictionary mapping variable names to initial values. Used to determine the shapes of the input variables. 
+ use_grad: bool + Whether to compile a function that computes the gradients of the loss function. + use_hess: bool + Whether to compile a function that computes the Hessian of the loss function. + use_hessp: bool + Whether to compile a function that computes the Hessian-vector product of the loss function. + use_jax_gradients: bool + If True, use JAX to compute gradients. This is only possible when ``compile_kwargs["mode"]`` is set to "JAX". + compile_kwargs: + Additional keyword arguments to pass to the ``pm.compile_pymc`` function. + + Returns + ------- + f_loss: Callable + The compiled loss function. + f_hess: Callable | None + The compiled hessian function, or None if use_hess is False. + f_hessp: Callable | None + The compiled hessian-vector product function, or None if use_hessp is False. + """ + + compile_kwargs = {} if compile_kwargs is None else compile_kwargs + + if (use_hess or use_hessp) and not use_grad: + raise ValueError( + "Cannot compute hessian or hessian-vector product without also computing the gradient" + ) + + use_jax_gradients = use_jax_gradients and use_grad + + mode = compile_kwargs.get("mode", None) + if mode is None and use_jax_gradients: + compile_kwargs["mode"] = "JAX" + elif mode != "JAX" and use_jax_gradients: + raise ValueError( + 'jax gradients can only be used when ``compile_kwargs["mode"]`` is set to "JAX"' + ) + + if not isinstance(inputs, list): + inputs = [inputs] + + [loss], flat_input = join_nonshared_inputs( + point=initial_point_dict, outputs=[loss], inputs=inputs + ) + + compute_grad = use_grad and not use_jax_gradients + compute_hess = use_hess and not use_jax_gradients + compute_hessp = use_hessp and not use_jax_gradients + + funcs = _compile_functions( + loss=loss, + inputs=[flat_input], + compute_grad=compute_grad, + compute_hess=compute_hess, + compute_hessp=compute_hessp, + compile_kwargs=compile_kwargs, + ) + + # f_loss here is f_loss_and_grad if compute_grad = True. The name is unchanged to simplify the return values + f_loss = funcs.pop(0) + f_hess = funcs.pop(0) if compute_grad else None + f_hessp = funcs.pop(0) if compute_grad else None + + if use_jax_gradients: + # f_loss here is f_loss_and_grad; the name is unchanged to simplify the return values + f_loss, f_hess, f_hessp = _compile_jax_gradients(f_loss, use_hess, use_hessp) + + return f_loss, f_hess, f_hessp + + +def fit_mvn_to_MAP( optimized_point: dict[str, np.ndarray], model: pm.Model, on_bad_cov: Literal["warn", "error", "ignore"] = "ignore", transform_samples: bool = True, + use_jax_gradients: bool = False, zero_tol: float = 1e-8, diag_jitter: float | None = 1e-8, + compile_kwargs: dict | None = None, ) -> tuple[RaveledVars, np.ndarray]: """ Create a multivariate normal distribution using the inverse of the negative Hessian matrix of the log-posterior @@ -139,6 +330,7 @@ def jax_fit_mvn_to_MAP( inverse_hessian: np.ndarray The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate. 
""" + compile_kwargs = {} if compile_kwargs is None else compile_kwargs frozen_model = freeze_dims_and_data(model) if not transform_samples: @@ -153,12 +345,15 @@ def jax_fit_mvn_to_MAP( optimized_free_params = {k: v for k, v in optimized_point.items() if k in variable_names} mu = DictToArrayBijection.map(optimized_free_params) - [neg_logp], flat_inputs = join_nonshared_inputs( - point=frozen_model.initial_point(), outputs=[-logp], inputs=variables - ) - - f_logp, f_grad, f_hess, f_hessp = make_jax_funcs_from_graph( - neg_logp, use_grad=True, use_hess=True, use_hessp=False, inputs=[flat_inputs] + _, f_hess, _ = scipy_optimize_funcs_from_loss( + loss=-logp, + inputs=variables, + initial_point_dict=frozen_model.initial_point(), + use_grad=True, + use_hess=True, + use_hessp=False, + use_jax_gradients=use_jax_gradients, + compile_kwargs=compile_kwargs, ) H = -f_hess(mu.data) @@ -186,7 +381,7 @@ def stabilize(x, jitter): return mu, H_inv -def jax_laplace( +def laplace( mu: RaveledVars, H_inv: np.ndarray, model: pm.Model, @@ -194,6 +389,7 @@ def jax_laplace( draws: int = 500, transform_samples: bool = True, progressbar: bool = True, + **compile_kwargs, ) -> az.InferenceData: """ @@ -220,9 +416,19 @@ def jax_laplace( if transform_samples: constrained_rvs, unconstrained_vector = _unconstrained_vector_to_constrained_rvs(model) - f_constrain = get_jaxified_graph(inputs=[unconstrained_vector], outputs=constrained_rvs) + batched_values = pt.tensor( + "batched_values", + shape=(chains, draws, *unconstrained_vector.type.shape), + dtype=unconstrained_vector.type.dtype, + ) + batched_rvs = pytensor.graph.vectorize_graph( + constrained_rvs, replace={unconstrained_vector: batched_values} + ) - posterior_draws = jax.jit(jax.vmap(jax.vmap(f_constrain)))(posterior_draws) + f_constrain = pm.compile_pymc( + inputs=[batched_values], outputs=batched_rvs, **compile_kwargs + ) + posterior_draws = f_constrain(posterior_draws) else: info = mu.point_map_info @@ -269,6 +475,7 @@ def make_rv_dims(name): model=model, merge_dataset=True, progressbar=progressbar, + compile_kwargs=compile_kwargs, ) observed_data = dict_to_dataset( @@ -306,73 +513,21 @@ def fit_laplace( zero_tol: float = 1e-8, diag_jitter: float | None = 1e-8, progressbar: bool = True, + compile_kwargs: dict | None = None, ) -> az.InferenceData: - mu, H_inv = jax_fit_mvn_to_MAP( - optimized_point, - model, - on_bad_cov, - transform_samples, - zero_tol, - diag_jitter, + compile_kwargs = {} if compile_kwargs is None else compile_kwargs + + mu, H_inv = fit_mvn_to_MAP( + optimized_point=optimized_point, + model=model, + on_bad_cov=on_bad_cov, + transform_samples=transform_samples, + zero_tol=zero_tol, + diag_jitter=diag_jitter, + compile_kwargs=compile_kwargs, ) - return jax_laplace(mu, H_inv, model, chains, draws, transform_samples, progressbar) - - -def make_jax_funcs_from_graph( - graph: TensorVariable, - use_grad: bool, - use_hess: bool, - use_hessp: bool, - inputs: list[TensorVariable] | None = None, -) -> tuple[Callable, ...]: - if inputs is None: - from pymc.pytensorf import inputvars - - inputs = inputvars(graph) - if not isinstance(inputs, list): - inputs = [inputs] - - f_tuple = cast(Callable, get_jaxified_graph(inputs=inputs, outputs=[graph])) - - def f(*args, **kwargs): - return f_tuple(*args, **kwargs)[0] - - f_logp = jax.jit(f) - - f_grad = None - f_hess = None - f_hessp = None - - if use_grad: - _f_grad_jax = jax.grad(f) - - def f_grad_jax(x): - return jax.numpy.stack(_f_grad_jax(x)) - - f_grad = jax.jit(f_grad_jax) - - if use_hessp: - if not 
use_grad: - raise ValueError("Cannot ask for Hessian without asking for Gradients") - - def f_hessp_jax(x, p): - y, u = jax.jvp(f_grad_jax, (x,), (p,)) - return jax.numpy.stack(u) - - f_hessp = jax.jit(f_hessp_jax) - - if use_hess: - if not use_grad: - raise ValueError("Cannot ask for Hessian without asking for Gradients") - _f_hess_jax = jax.jacfwd(f_grad_jax) - - def f_hess_jax(x): - return jax.numpy.stack(_f_hess_jax(x)) - - f_hess = jax.jit(f_hess_jax) - - return f_logp, f_grad, f_hess, f_hessp + return laplace(mu, H_inv, model, chains, draws, transform_samples, progressbar) def find_MAP( @@ -388,6 +543,8 @@ def find_MAP( jitter_rvs: list[TensorVariable] | None = None, progressbar: bool = True, include_transformed: bool = True, + use_jax_gradients: bool = False, + compile_kwargs: dict | None = None, **optimizer_kwargs, ) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], OptimizeResult]: """ @@ -433,8 +590,8 @@ def find_MAP( model = pm.modelcontext(model) frozen_model = freeze_dims_and_data(model) - if jitter_rvs is None: - jitter_rvs = [] + jitter_rvs = [] if jitter_rvs is None else jitter_rvs + compile_kwargs = {} if compile_kwargs is None else compile_kwargs ipfn = make_initial_point_fn( model=frozen_model, @@ -449,21 +606,25 @@ def find_MAP( {var_name: value for var_name, value in start_dict.items() if var_name in vars_dict} ) - [neg_logp], inputs = join_nonshared_inputs( - point=start_dict, outputs=[-frozen_model.logp()], inputs=frozen_model.continuous_value_vars - ) - - f_logp, f_grad, f_hess, f_hessp = make_jax_funcs_from_graph( - neg_logp, use_grad, use_hess, use_hessp, inputs=[inputs] + f_logp, f_hess, f_hessp = scipy_optimize_funcs_from_loss( + loss=-frozen_model.logp(), + inputs=frozen_model.continuous_value_vars + frozen_model.discrete_value_vars, + initial_point_dict=start_dict, + use_grad=use_grad, + use_hess=use_hess, + use_hessp=use_hessp, + use_jax_gradients=use_jax_gradients, + compile_kwargs=compile_kwargs, ) args = optimizer_kwargs.pop("args", None) + # better_optimize.minimize will check if f_logp is a fused loss+grad Op, and automatically assign the jac argument + # if so. That is why it is not set here, regardless of user settings. 
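    # Illustrative aside (an editor's sketch, not a line in this patch): when use_grad=True, f_logp is the
    # fused loss-and-gradient callable compiled above, so a single evaluation returns both values, e.g.
    #     neg_logp_val, grad_val = f_logp(initial_params.data)
    # better_optimize is expected to detect that two-output signature and forward the gradient to
    # scipy.optimize.minimize itself, which is why no separate jac argument is passed below.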
optimizer_result = minimize( f=f_logp, x0=cast(np.ndarray[float], initial_params.data), args=args, - jac=f_grad, hess=f_hess, hessp=f_hessp, progressbar=progressbar, diff --git a/tests/test_jax_find_map.py b/tests/test_jax_find_map.py index 568a6a07..01e05301 100644 --- a/tests/test_jax_find_map.py +++ b/tests/test_jax_find_map.py @@ -3,10 +3,10 @@ import pytensor.tensor as pt import pytest -from pymc_experimental.inference.jax_find_map import ( +from pymc_experimental.inference.find_map import ( find_MAP, fit_laplace, - make_jax_funcs_from_graph, + scipy_optimize_funcs_from_loss, ) pytest.importorskip("jax") @@ -18,7 +18,8 @@ def rng(): return np.random.default_rng(seed) -def test_jax_functions_from_graph(): +@pytest.mark.parametrize("use_jax_gradients", [True, False], ids=["jax_grad", "pt_grad"]) +def test_jax_functions_from_graph(use_jax_gradients): x = pt.tensor("x", shape=(2,)) def compute_z(x): @@ -27,17 +28,22 @@ def compute_z(x): return z1, z2 z = pt.stack(compute_z(x)) - f_z, f_grad, f_hess, f_hessp = make_jax_funcs_from_graph( - z.sum(), use_grad=True, use_hess=True, use_hessp=True + f_loss, f_hess, f_hessp = scipy_optimize_funcs_from_loss( + loss=z.sum(), + inputs=[x], + initial_point_dict={"x": np.array([1.0, 2.0])}, + use_grad=True, + use_hess=True, + use_hessp=True, + use_jax_gradients=use_jax_gradients, + compile_kwargs=dict(mode="JAX"), ) x_val = np.array([1.0, 2.0]) expected_z = sum(compute_z(x_val)) - z_jax = f_z(x_val) + z_jax, grad_val = f_loss(x_val) np.testing.assert_allclose(z_jax, expected_z) - - grad_val = np.array(f_grad(x_val)) np.testing.assert_allclose(grad_val.squeeze(), np.array([2 * x_val[0] + x_val[1], x_val[0]])) hess_val = np.array(f_hess(x_val)) @@ -64,7 +70,8 @@ def compute_z(x): ("trust-constr", True, True), ], ) -def test_JAX_map(method, use_grad, use_hess, rng): +@pytest.mark.parametrize("use_jax_gradients", [True, False], ids=["jax_grad", "pt_grad"]) +def test_JAX_map(method, use_grad, use_hess, use_jax_gradients, rng): extra_kwargs = {} if method == "dogleg": # HACK -- dogleg requires that the hessian of the objective function is PSD, so we have to pick a point @@ -77,7 +84,13 @@ def test_JAX_map(method, use_grad, use_hess, rng): pm.Normal("y_hat", mu=mu, sigma=sigma, observed=rng.normal(loc=3, scale=1.5, size=100)) optimized_point = find_MAP( - method=method, **extra_kwargs, use_grad=use_grad, use_hess=use_hess, progressbar=False + method=method, + **extra_kwargs, + use_grad=use_grad, + use_hess=use_hess, + progressbar=False, + use_jax_gradients=use_jax_gradients, + compile_kwargs={"mode": "JAX"}, ) mu_hat, log_sigma_hat = optimized_point["mu"], optimized_point["sigma_log__"] @@ -90,7 +103,8 @@ def test_JAX_map(method, use_grad, use_hess, rng): [True, False], ids=["transformed", "untransformed"], ) -def test_fit_laplace_coords(rng, transform_samples): +@pytest.mark.parametrize("mode", ["JAX", None], ids=["jax", "pytensor"]) +def test_fit_laplace_coords(rng, transform_samples, mode): coords = {"city": ["A", "B", "C"], "obs_idx": np.arange(100)} with pm.Model(coords=coords) as model: mu = pm.Normal("mu", mu=3, sigma=0.5, dims=["city"]) @@ -104,9 +118,12 @@ def test_fit_laplace_coords(rng, transform_samples): ) optimized_point = find_MAP( - method="Newton-CG", + method="trust-ncg", use_grad=True, + use_hessp=True, progressbar=False, + compile_kwargs=dict(mode=mode), + use_jax_gradients=mode == "JAX", ) for value in optimized_point.values(): @@ -117,9 +134,10 @@ def test_fit_laplace_coords(rng, transform_samples): model, 
transform_samples=transform_samples, progressbar=False, + compile_kwargs=dict(mode=mode), ) - np.testing.assert_allclose(np.mean(idata.posterior.mu, axis=1), np.full((2, 3), 3), atol=0.3) + np.testing.assert_allclose(np.mean(idata.posterior.mu, axis=1), np.full((2, 3), 3), atol=0.5) np.testing.assert_allclose( np.mean(idata.posterior.sigma, axis=1), np.full((2, 3), 1.5), atol=0.3 ) From 4c2529d67290dd00be54cb104dffcc4b1ca084a2 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Wed, 4 Dec 2024 02:42:17 +0800 Subject: [PATCH 11/21] Reconcile the two laplace approximation functions --- pymc_experimental/inference/find_map.py | 277 +----------- pymc_experimental/inference/fit.py | 4 +- pymc_experimental/inference/laplace.py | 568 +++++++++++++++++++----- tests/test_jax_find_map.py | 111 ----- tests/test_laplace.py | 210 ++++++--- 5 files changed, 642 insertions(+), 528 deletions(-) diff --git a/pymc_experimental/inference/find_map.py b/pymc_experimental/inference/find_map.py index bcbcad23..95877b27 100644 --- a/pymc_experimental/inference/find_map.py +++ b/pymc_experimental/inference/find_map.py @@ -1,38 +1,41 @@ import logging from collections.abc import Callable -from typing import Literal, cast +from typing import cast -import arviz as az import jax import numpy as np import pymc as pm import pytensor import pytensor.tensor as pt -import xarray as xr -from arviz import dict_to_dataset from better_optimize import minimize -from better_optimize.constants import minimize_method -from pymc.backends.arviz import ( - coords_and_dims_for_inferencedata, - find_constants, - find_observations, -) +from better_optimize.constants import MINIMIZE_MODE_KWARGS, minimize_method from pymc.blocking import DictToArrayBijection, RaveledVars from pymc.initial_point import make_initial_point_fn -from pymc.model.transform.conditioning import remove_value_transforms from pymc.model.transform.optimization import freeze_dims_and_data from pymc.pytensorf import join_nonshared_inputs from pymc.util import get_default_varnames from pytensor.compile import Function from pytensor.tensor import TensorVariable -from scipy import stats from scipy.optimize import OptimizeResult _log = logging.getLogger(__name__) +def set_optimizer_function_defaults(method, use_grad, use_hess, use_hessp): + method_info = MINIMIZE_MODE_KWARGS[method].copy() + + use_grad = use_grad if use_grad is not None else method_info["uses_grad"] + use_hess = use_hess if use_hess is not None else method_info["uses_hess"] + use_hessp = use_hessp if use_hessp is not None else method_info["uses_hessp"] + + if use_hess and use_hessp: + use_hess = False + + return use_grad, use_hess, use_hessp + + def get_nearest_psd(A: np.ndarray) -> np.ndarray: """ Compute the nearest positive semi-definite matrix to a given matrix. 
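As an illustrative aside (not part of the patch itself), the new helper above is expected to be used roughly as follows. The import path is the `find_map` module this diff modifies; the exact flag names inside `MINIMIZE_MODE_KWARGS` are taken from the helper's own body and otherwise assumed from `better_optimize`:

    from pymc_experimental.inference.find_map import set_optimizer_function_defaults

    # Leave all three flags as None to fall back to what the chosen method supports.
    use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
        "trust-ncg", use_grad=None, use_hess=None, use_hessp=None
    )

    # Assuming trust-ncg advertises gradient, Hessian, and Hessian-vector-product support,
    # the helper keeps the cheaper hessp and drops the full Hessian:
    print(use_grad, use_hess, use_hessp)  # expected: True False True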
@@ -60,7 +63,9 @@ def get_nearest_psd(A: np.ndarray) -> np.ndarray: def _unconstrained_vector_to_constrained_rvs(model): constrained_rvs, unconstrained_vector = join_nonshared_inputs( - model.initial_point(), inputs=model.value_vars, outputs=model.unobserved_value_vars + model.initial_point(), + inputs=model.value_vars, + outputs=get_default_varnames(model.unobserved_value_vars, include_transformed=False), ) unconstrained_vector.name = "unconstrained_vector" @@ -289,247 +294,6 @@ def scipy_optimize_funcs_from_loss( return f_loss, f_hess, f_hessp -def fit_mvn_to_MAP( - optimized_point: dict[str, np.ndarray], - model: pm.Model, - on_bad_cov: Literal["warn", "error", "ignore"] = "ignore", - transform_samples: bool = True, - use_jax_gradients: bool = False, - zero_tol: float = 1e-8, - diag_jitter: float | None = 1e-8, - compile_kwargs: dict | None = None, -) -> tuple[RaveledVars, np.ndarray]: - """ - Create a multivariate normal distribution using the inverse of the negative Hessian matrix of the log-posterior - evaluated at the MAP estimate. This is the basis of the Laplace approximation. - - Parameters - ---------- - optimized_point : dict[str, np.ndarray] - Local maximum a posteriori (MAP) point returned from pymc.find_MAP or jax_tools.fit_map - model : Model - A PyMC model - on_bad_cov : str, one of 'ignore', 'warn', or 'error', default: 'ignore' - What to do when ``H_inv`` (inverse Hessian) is not positive semi-definite. - If 'ignore' or 'warn', the closest positive-semi-definite matrix to ``H_inv`` (in L1 norm) will be returned. - If 'error', an error will be raised. - transform_samples : bool - Whether to transform the samples back to the original parameter space. Default is True. - zero_tol: float - Value below which an element of the Hessian matrix is counted as 0. - This is used to stabilize the computation of the inverse Hessian matrix. Default is 1e-8. - diag_jitter: float | None - A small value added to the diagonal of the inverse Hessian matrix to ensure it is positive semi-definite. - If None, no jitter is added. Default is 1e-8. - - Returns - ------- - map_estimate: RaveledVars - The MAP estimate of the model parameters, raveled into a 1D array. - - inverse_hessian: np.ndarray - The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate. 
- """ - compile_kwargs = {} if compile_kwargs is None else compile_kwargs - frozen_model = freeze_dims_and_data(model) - - if not transform_samples: - untransformed_model = remove_value_transforms(frozen_model) - logp = untransformed_model.logp(jacobian=False) - variables = untransformed_model.continuous_value_vars - else: - logp = frozen_model.logp(jacobian=True) - variables = frozen_model.continuous_value_vars - - variable_names = {var.name for var in variables} - optimized_free_params = {k: v for k, v in optimized_point.items() if k in variable_names} - mu = DictToArrayBijection.map(optimized_free_params) - - _, f_hess, _ = scipy_optimize_funcs_from_loss( - loss=-logp, - inputs=variables, - initial_point_dict=frozen_model.initial_point(), - use_grad=True, - use_hess=True, - use_hessp=False, - use_jax_gradients=use_jax_gradients, - compile_kwargs=compile_kwargs, - ) - - H = -f_hess(mu.data) - H_inv = np.linalg.pinv(np.where(np.abs(H) < zero_tol, 0, -H)) - - def stabilize(x, jitter): - return x + np.eye(x.shape[0]) * jitter - - H_inv = H_inv if diag_jitter is None else stabilize(H_inv, diag_jitter) - - try: - np.linalg.cholesky(H_inv) - except np.linalg.LinAlgError: - if on_bad_cov == "error": - raise np.linalg.LinAlgError( - "Inverse Hessian not positive-semi definite at the provided point" - ) - H_inv = get_nearest_psd(H_inv) - if on_bad_cov == "warn": - _log.warning( - "Inverse Hessian is not positive semi-definite at the provided point, using the closest PSD " - "matrix in L1-norm instead" - ) - - return mu, H_inv - - -def laplace( - mu: RaveledVars, - H_inv: np.ndarray, - model: pm.Model, - chains: int = 2, - draws: int = 500, - transform_samples: bool = True, - progressbar: bool = True, - **compile_kwargs, -) -> az.InferenceData: - """ - - Parameters - ---------- - mu - H_inv - model : Model - A PyMC model - chains : int - The number of sampling chains running in parallel. Default is 2. - draws : int - The number of samples to draw from the approximated posterior. Default is 500. - transform_samples : bool - Whether to transform the samples back to the original parameter space. Default is True. - - Returns - ------- - idata: az.InferenceData - An InferenceData object containing the approximated posterior samples. 
- """ - posterior_dist = stats.multivariate_normal(mean=mu.data, cov=H_inv, allow_singular=True) - posterior_draws = posterior_dist.rvs(size=(chains, draws)) - - if transform_samples: - constrained_rvs, unconstrained_vector = _unconstrained_vector_to_constrained_rvs(model) - batched_values = pt.tensor( - "batched_values", - shape=(chains, draws, *unconstrained_vector.type.shape), - dtype=unconstrained_vector.type.dtype, - ) - batched_rvs = pytensor.graph.vectorize_graph( - constrained_rvs, replace={unconstrained_vector: batched_values} - ) - - f_constrain = pm.compile_pymc( - inputs=[batched_values], outputs=batched_rvs, **compile_kwargs - ) - posterior_draws = f_constrain(posterior_draws) - - else: - info = mu.point_map_info - flat_shapes = [np.prod(shape).astype(int) for _, shape, _ in info] - slices = [ - slice(sum(flat_shapes[:i]), sum(flat_shapes[: i + 1])) for i in range(len(flat_shapes)) - ] - - posterior_draws = [ - posterior_draws[..., idx].reshape((chains, draws, *shape)).astype(dtype) - for idx, (name, shape, dtype) in zip(slices, info) - ] - - def make_rv_coords(name): - coords = {"chain": range(chains), "draw": range(draws)} - extra_dims = model.named_vars_to_dims.get(name) - if extra_dims is None: - return coords - return coords | {dim: list(model.coords[dim]) for dim in extra_dims} - - def make_rv_dims(name): - dims = ["chain", "draw"] - extra_dims = model.named_vars_to_dims.get(name) - if extra_dims is None: - return dims - return dims + list(extra_dims) - - idata = { - name: xr.DataArray( - data=draws.squeeze(), - coords=make_rv_coords(name), - dims=make_rv_dims(name), - name=name, - ) - for (name, _, _), draws in zip(mu.point_map_info, posterior_draws) - } - - coords, dims = coords_and_dims_for_inferencedata(model) - idata = az.convert_to_inference_data(idata, coords=coords, dims=dims) - - if model.deterministics: - idata.posterior = pm.compute_deterministics( - idata.posterior, - model=model, - merge_dataset=True, - progressbar=progressbar, - compile_kwargs=compile_kwargs, - ) - - observed_data = dict_to_dataset( - find_observations(model), - library=pm, - coords=coords, - dims=dims, - default_dims=[], - ) - - constant_data = dict_to_dataset( - find_constants(model), - library=pm, - coords=coords, - dims=dims, - default_dims=[], - ) - - idata.add_groups( - {"observed_data": observed_data, "constant_data": constant_data}, - coords=coords, - dims=dims, - ) - - return idata - - -def fit_laplace( - optimized_point: dict[str, np.ndarray], - model: pm.Model, - chains: int = 2, - draws: int = 500, - on_bad_cov: Literal["warn", "error", "ignore"] = "ignore", - transform_samples: bool = True, - zero_tol: float = 1e-8, - diag_jitter: float | None = 1e-8, - progressbar: bool = True, - compile_kwargs: dict | None = None, -) -> az.InferenceData: - compile_kwargs = {} if compile_kwargs is None else compile_kwargs - - mu, H_inv = fit_mvn_to_MAP( - optimized_point=optimized_point, - model=model, - on_bad_cov=on_bad_cov, - transform_samples=transform_samples, - zero_tol=zero_tol, - diag_jitter=diag_jitter, - compile_kwargs=compile_kwargs, - ) - - return laplace(mu, H_inv, model, chains, draws, transform_samples, progressbar) - - def find_MAP( method: minimize_method, *, @@ -605,9 +369,12 @@ def find_MAP( initial_params = DictToArrayBijection.map( {var_name: value for var_name, value in start_dict.items() if var_name in vars_dict} ) + use_grad, use_hess, use_hessp = set_optimizer_function_defaults( + method, use_grad, use_hess, use_hessp + ) f_logp, f_hess, f_hessp = 
scipy_optimize_funcs_from_loss( - loss=-frozen_model.logp(), + loss=-frozen_model.logp(jacobian=False), inputs=frozen_model.continuous_value_vars + frozen_model.discrete_value_vars, initial_point_dict=start_dict, use_grad=use_grad, diff --git a/pymc_experimental/inference/fit.py b/pymc_experimental/inference/fit.py index f6c87d90..7897aeed 100644 --- a/pymc_experimental/inference/fit.py +++ b/pymc_experimental/inference/fit.py @@ -39,6 +39,6 @@ def fit(method, **kwargs): return fit_pathfinder(**kwargs) if method == "laplace": - from pymc_experimental.inference.laplace import laplace + from pymc_experimental.inference.laplace import fit_laplace - return laplace(**kwargs) + return fit_laplace(**kwargs) diff --git a/pymc_experimental/inference/laplace.py b/pymc_experimental/inference/laplace.py index 7d7beb59..80fd7873 100644 --- a/pymc_experimental/inference/laplace.py +++ b/pymc_experimental/inference/laplace.py @@ -12,156 +12,192 @@ # See the License for the specific language governing permissions and # limitations under the License. -import warnings -from collections.abc import Sequence +from functools import reduce +from itertools import product +from typing import Literal import arviz as az import numpy as np import pymc as pm +import pytensor +import pytensor.tensor as pt import xarray as xr from arviz import dict_to_dataset +from better_optimize.constants import minimize_method +from inference.find_map import ( + _log, + _unconstrained_vector_to_constrained_rvs, + find_MAP, + get_nearest_psd, + scipy_optimize_funcs_from_loss, +) +from pymc import DictToArrayBijection from pymc.backends.arviz import ( coords_and_dims_for_inferencedata, find_constants, find_observations, ) +from pymc.blocking import RaveledVars from pymc.model.transform.conditioning import remove_value_transforms -from pymc.util import RandomSeed -from pytensor import Variable +from pymc.model.transform.optimization import freeze_dims_and_data +from pymc.util import get_default_varnames +from scipy import stats -def laplace( - vars: Sequence[Variable], - draws: int | None = 1000, - model=None, - random_seed: RandomSeed | None = None, - progressbar=True, -): +def laplace_draws_to_inferencedata( + posterior_draws: list[np.ndarray[float | int]], model: pm.Model | None = None +) -> az.InferenceData: """ - Create a Laplace (quadratic) approximation for a posterior distribution. + Convert draws from a posterior estimated with the Laplace approximation to an InferenceData object. - This function generates a Laplace approximation for a given posterior distribution using a specified - number of draws. This is useful for obtaining a parametric approximation to the posterior distribution - that can be used for further analysis. Parameters ---------- - vars : Sequence[Variable] - A sequence of variables for which the Laplace approximation of the posterior distribution - is to be created. - draws : Optional[int] with default=1_000 - The number of draws to sample from the posterior distribution for creating the approximation. - For draws=None only the fit of the Laplace approximation is returned - model : object, optional, default=None - The model object that defines the posterior distribution. If None, the default model will be used. - random_seed : Optional[RandomSeed], optional, default=None - An optional random seed to ensure reproducibility of the draws. If None, the draws will be - generated using the current random state. - progressbar: bool, optional defaults to True - Whether to display a progress bar in the command line. 
+ posterior_draws: list of np.ndarray + A list of arrays containing the posterior draws. Each array should have shape (chains, draws, *shape), where + shape is the shape of the variable in the posterior. + model: Model, optional + A PyMC model. If None, the model is taken from the current model context. Returns ------- - arviz.InferenceData - An `InferenceData` object from the `arviz` library containing the Laplace - approximation of the posterior distribution. The inferenceData object also - contains constant and observed data as well as deterministic variables. - InferenceData also contains a group 'fit' with the mean and covariance - for the Laplace approximation. - - Examples - -------- - >>> import numpy as np - >>> import pymc as pm - >>> import arviz as az - >>> from pymc_experimental.inference.laplace import laplace - >>> y = np.array([2642, 3503, 4358]*10) - >>> with pm.Model() as m: - >>> logsigma = pm.Uniform("logsigma", 1, 100) - >>> mu = pm.Uniform("mu", -10000, 10000) - >>> yobs = pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y) - >>> idata = laplace([mu, logsigma], model=m) - - Notes - ----- - This method of approximation may not be suitable for all types of posterior distributions, - especially those with significant skewness or multimodality. - - See Also - -------- - fit : Calling the inference function 'fit' like pmx.fit(method="laplace", vars=[mu, logsigma], model=m) - will forward the call to 'laplace'. - + idata: az.InferenceData + An InferenceData object containing the approximated posterior samples """ - - rng = np.random.default_rng(seed=random_seed) - - transformed_m = pm.modelcontext(model) - - if len(vars) != len(transformed_m.free_RVs): - warnings.warn( - "Number of variables in vars does not eqaul the number of variables in the model.", - UserWarning, + model = pm.modelcontext(model) + chains, draws, *_ = posterior_draws[0].shape + + def make_rv_coords(name): + coords = {"chain": range(chains), "draw": range(draws)} + extra_dims = model.named_vars_to_dims.get(name) + if extra_dims is None: + return coords + return coords | {dim: list(model.coords[dim]) for dim in extra_dims} + + def make_rv_dims(name): + dims = ["chain", "draw"] + extra_dims = model.named_vars_to_dims.get(name) + if extra_dims is None: + return dims + return dims + list(extra_dims) + + names = [ + x.name for x in get_default_varnames(model.unobserved_value_vars, include_transformed=False) + ] + idata = { + name: xr.DataArray( + data=draws, + coords=make_rv_coords(name), + dims=make_rv_dims(name), + name=name, ) + for name, draws in zip(names, posterior_draws) + } - map = pm.find_MAP(vars=vars, progressbar=progressbar, model=transformed_m) + coords, dims = coords_and_dims_for_inferencedata(model) + idata = az.convert_to_inference_data(idata, coords=coords, dims=dims) - # See https://www.pymc.io/projects/docs/en/stable/api/model/generated/pymc.model.transform.conditioning.remove_value_transforms.html - untransformed_m = remove_value_transforms(transformed_m) - untransformed_vars = [untransformed_m[v.name] for v in vars] - hessian = pm.find_hessian(point=map, vars=untransformed_vars, model=untransformed_m) + return idata - if np.linalg.det(hessian) == 0: - raise np.linalg.LinAlgError("Hessian is singular.") - cov = np.linalg.inv(hessian) - mean = np.concatenate([np.atleast_1d(map[v.name]) for v in vars]) +def add_fit_to_inferencedata( + idata: az.InferenceData, mu: RaveledVars, H_inv: np.ndarray, model: pm.Model | None = None +) -> az.InferenceData: + """ + Add the mean vector and 
covariance matrix of the Laplace approximation to an InferenceData object. - chains = 1 - if draws is not None: - samples = rng.multivariate_normal(mean, cov, size=(chains, draws)) + Parameters + ---------- + idata: az.InfereceData + An InferenceData object containing the approximated posterior samples. + mu: RaveledVars + The MAP estimate of the model parameters. + H_inv: np.ndarray + The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate. + model: Model, optional + A PyMC model. If None, the model is taken from the current model context. - data_vars = {} - for i, var in enumerate(vars): - data_vars[str(var)] = xr.DataArray(samples[:, :, i], dims=("chain", "draw")) + Returns + ------- + idata: az.InferenceData + The provided InferenceData, with the mean vector and covariance matrix added to the "fit" group. + """ + model = pm.modelcontext(model) + coords = model.coords - coords = {"chain": np.arange(chains), "draw": np.arange(draws)} - ds = xr.Dataset(data_vars, coords=coords) + variable_names, *_ = zip(*mu.point_map_info) - idata = az.convert_to_inference_data(ds) - idata = addDataToInferenceData(model, idata, progressbar) - else: - idata = az.InferenceData() + def make_unpacked_variable_names(name): + value_to_dim = { + x.name: model.named_vars_to_dims.get(model.values_to_rvs[x].name, None) + for x in model.value_vars + } + value_to_dim = {k: v for k, v in value_to_dim.items() if v is not None} - idata = addFitToInferenceData(vars, idata, mean, cov) + rv_to_dim = model.named_vars_to_dims + dims_dict = rv_to_dim | value_to_dim - return idata + dims = dims_dict.get(name) + if dims is None: + return [name] + labels = product(*(coords[dim] for dim in dims)) + return [f"{name}[{','.join(map(str, label))}]" for label in labels] + unpacked_variable_names = reduce( + lambda lst, name: lst + make_unpacked_variable_names(name), variable_names, [] + ) -def addFitToInferenceData(vars, idata, mean, covariance): - coord_names = [v.name for v in vars] - # Convert to xarray DataArray - mean_dataarray = xr.DataArray(mean, dims=["rows"], coords={"rows": coord_names}) + mean_dataarray = xr.DataArray(mu.data, dims=["rows"], coords={"rows": unpacked_variable_names}) cov_dataarray = xr.DataArray( - covariance, dims=["rows", "columns"], coords={"rows": coord_names, "columns": coord_names} + H_inv, + dims=["rows", "columns"], + coords={"rows": unpacked_variable_names, "columns": unpacked_variable_names}, ) - # Create xarray dataset dataset = xr.Dataset({"mean_vector": mean_dataarray, "covariance_matrix": cov_dataarray}) - idata.add_groups(fit=dataset) return idata -def addDataToInferenceData(model, trace, progressbar): - # Add deterministic variables to inference data - trace.posterior = pm.compute_deterministics( - trace.posterior, model=model, merge_dataset=True, progressbar=progressbar - ) +def add_data_to_inferencedata( + idata: az.InferenceData, + progressbar: bool = True, + model: pm.Model | None = None, + compile_kwargs: dict | None = None, +) -> az.InferenceData: + """ + Add observed and constant data to an InferenceData object. + + Parameters + ---------- + idata: az.InferenceData + An InferenceData object containing the approximated posterior samples. + progressbar: bool + Whether to display a progress bar during computations. Default is True. + model: Model, optional + A PyMC model. If None, the model is taken from the current model context. + compile_kwargs: dict, optional + Additional keyword arguments to pass to pytensor.function. 
+ + Returns + ------- + idata: az.InferenceData + The provided InferenceData, with observed and constant data added. + """ + model = pm.modelcontext(model) + + if model.deterministics: + idata.posterior = pm.compute_deterministics( + idata.posterior, + model=model, + merge_dataset=True, + progressbar=progressbar, + compile_kwargs=compile_kwargs, + ) coords, dims = coords_and_dims_for_inferencedata(model) @@ -181,10 +217,334 @@ def addDataToInferenceData(model, trace, progressbar): default_dims=[], ) - trace.add_groups( + idata.add_groups( {"observed_data": observed_data, "constant_data": constant_data}, coords=coords, dims=dims, ) - return trace + return idata + + +def fit_mvn_to_MAP( + optimized_point: dict[str, np.ndarray], + model: pm.Model | None = None, + on_bad_cov: Literal["warn", "error", "ignore"] = "ignore", + transform_samples: bool = False, + use_jax_gradients: bool = False, + zero_tol: float = 1e-8, + diag_jitter: float | None = 1e-8, + compile_kwargs: dict | None = None, +) -> tuple[RaveledVars, np.ndarray]: + """ + Create a multivariate normal distribution using the inverse of the negative Hessian matrix of the log-posterior + evaluated at the MAP estimate. This is the basis of the Laplace approximation. + + Parameters + ---------- + optimized_point : dict[str, np.ndarray] + Local maximum a posteriori (MAP) point returned from pymc.find_MAP or jax_tools.fit_map + model : Model, optional + A PyMC model. If None, the model is taken from the current model context. + on_bad_cov : str, one of 'ignore', 'warn', or 'error', default: 'ignore' + What to do when ``H_inv`` (inverse Hessian) is not positive semi-definite. + If 'ignore' or 'warn', the closest positive-semi-definite matrix to ``H_inv`` (in L1 norm) will be returned. + If 'error', an error will be raised. + transform_samples : bool + Whether to transform the samples back to the original parameter space. Default is True. + zero_tol: float + Value below which an element of the Hessian matrix is counted as 0. + This is used to stabilize the computation of the inverse Hessian matrix. Default is 1e-8. + diag_jitter: float | None + A small value added to the diagonal of the inverse Hessian matrix to ensure it is positive semi-definite. + If None, no jitter is added. Default is 1e-8. + + Returns + ------- + map_estimate: RaveledVars + The MAP estimate of the model parameters, raveled into a 1D array. + + inverse_hessian: np.ndarray + The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate. 
+ """ + model = pm.modelcontext(model) + compile_kwargs = {} if compile_kwargs is None else compile_kwargs + frozen_model = freeze_dims_and_data(model) + + if not transform_samples: + untransformed_model = remove_value_transforms(frozen_model) + logp = untransformed_model.logp(jacobian=False) + variables = untransformed_model.continuous_value_vars + else: + logp = frozen_model.logp(jacobian=True) + variables = frozen_model.continuous_value_vars + + variable_names = {var.name for var in variables} + optimized_free_params = {k: v for k, v in optimized_point.items() if k in variable_names} + mu = DictToArrayBijection.map(optimized_free_params) + + _, f_hess, _ = scipy_optimize_funcs_from_loss( + loss=-logp, + inputs=variables, + initial_point_dict=optimized_free_params, + use_grad=True, + use_hess=True, + use_hessp=False, + use_jax_gradients=use_jax_gradients, + compile_kwargs=compile_kwargs, + ) + + H = -f_hess(mu.data) + H_inv = np.linalg.pinv(np.where(np.abs(H) < zero_tol, 0, -H)) + + def stabilize(x, jitter): + return x + np.eye(x.shape[0]) * jitter + + H_inv = H_inv if diag_jitter is None else stabilize(H_inv, diag_jitter) + + try: + np.linalg.cholesky(H_inv) + except np.linalg.LinAlgError: + if on_bad_cov == "error": + raise np.linalg.LinAlgError( + "Inverse Hessian not positive-semi definite at the provided point" + ) + H_inv = get_nearest_psd(H_inv) + if on_bad_cov == "warn": + _log.warning( + "Inverse Hessian is not positive semi-definite at the provided point, using the closest PSD " + "matrix in L1-norm instead" + ) + + return mu, H_inv + + +def laplace( + mu: RaveledVars, + H_inv: np.ndarray, + model: pm.Model | None = None, + chains: int = 2, + draws: int = 500, + transform_samples: bool = False, + progressbar: bool = True, + random_seed: int | np.random.Generator | None = None, + compile_kwargs: dict | None = None, +) -> az.InferenceData: + """ + Generate samples from a multivariate normal distribution with mean `mu` and inverse covariance matrix `H_inv`. + + Parameters + ---------- + mu + H_inv + model : Model + A PyMC model + chains : int + The number of sampling chains running in parallel. Default is 2. + draws : int + The number of samples to draw from the approximated posterior. Default is 500. + transform_samples : bool + Whether to transform the samples back to the original parameter space. Default is True. + progressbar : bool + Whether to display a progress bar during computations. Default is True. + random_seed: int | np.random.Generator | None + Seed for the random number generator or a numpy Generator for reproducibility + + Returns + ------- + idata: az.InferenceData + An InferenceData object containing the approximated posterior samples. 
+ """ + model = pm.modelcontext(model) + compile_kwargs = {} if compile_kwargs is None else compile_kwargs + rng = np.random.default_rng(random_seed) + + posterior_dist = stats.multivariate_normal( + mean=mu.data, cov=H_inv, allow_singular=True, seed=rng + ) + posterior_draws = posterior_dist.rvs(size=(chains, draws)) + + if transform_samples: + constrained_rvs, unconstrained_vector = _unconstrained_vector_to_constrained_rvs(model) + batched_values = pt.tensor( + "batched_values", + shape=(chains, draws, *unconstrained_vector.type.shape), + dtype=unconstrained_vector.type.dtype, + ) + batched_rvs = pytensor.graph.vectorize_graph( + constrained_rvs, replace={unconstrained_vector: batched_values} + ) + + f_constrain = pm.compile_pymc( + inputs=[batched_values], outputs=batched_rvs, **compile_kwargs + ) + posterior_draws = f_constrain(posterior_draws) + + else: + info = mu.point_map_info + flat_shapes = [np.prod(shape).astype(int) for _, shape, _ in info] + slices = [ + slice(sum(flat_shapes[:i]), sum(flat_shapes[: i + 1])) for i in range(len(flat_shapes)) + ] + + posterior_draws = [ + posterior_draws[..., idx].reshape((chains, draws, *shape)).astype(dtype) + for idx, (name, shape, dtype) in zip(slices, info) + ] + + idata = laplace_draws_to_inferencedata(posterior_draws, model) + idata = add_fit_to_inferencedata(idata, mu, H_inv) + idata = add_data_to_inferencedata(idata, progressbar, model, compile_kwargs) + + return idata + + +def fit_laplace( + optimize_method: minimize_method = "BFGS", + *, + model: pm.Model | None = None, + use_grad: bool | None = None, + use_hessp: bool | None = None, + use_hess: bool | None = None, + initvals: dict | None = None, + random_seed: int | np.random.Generator | None = None, + return_raw: bool = False, + jitter_rvs: list[pt.TensorVariable] | None = None, + progressbar: bool = True, + include_transformed: bool = True, + use_jax_gradients: bool = False, + chains: int = 2, + draws: int = 500, + on_bad_cov: Literal["warn", "error", "ignore"] = "ignore", + transform_samples: bool = False, + zero_tol: float = 1e-8, + diag_jitter: float | None = 1e-8, + optimizer_kwargs: dict | None = None, + compile_kwargs: dict | None = None, +) -> az.InferenceData: + """ + Create a Laplace (quadratic) approximation for a posterior distribution. + + This function generates a Laplace approximation for a given posterior distribution using a specified + number of draws. This is useful for obtaining a parametric approximation to the posterior distribution + that can be used for further analysis. + + Parameters + ---------- + model : pm.Model + The PyMC model to be fit. If None, the current model context is used. + optimize_method : str + The optimization method to use. See scipy.optimize.minimize documentation for details. + use_grad : bool | None, optional + Whether to use gradients in the optimization. Defaults to None, which determines this automatically based on + the ``method``. + use_hessp : bool | None, optional + Whether to use Hessian-vector products in the optimization. Defaults to None, which determines this automatically based on + the ``method``. + use_hess : bool | None, optional + Whether to use the Hessian matrix in the optimization. Defaults to None, which determines this automatically based on + the ``method``. + initvals : None | dict, optional + Initial values for the model parameters, as str:ndarray key-value pairs. Paritial initialization is permitted. + If None, the model's default initial values are used. 
+    random_seed : None | int | np.random.Generator, optional
+        Seed for the random number generator or a numpy Generator for reproducibility
+    return_raw: bool, optional
+        Whether to also return the full output of `scipy.optimize.minimize`. Defaults to False.
+    jitter_rvs : list of TensorVariables, optional
+        Variables whose initial values should be jittered. If None, all variables are jittered.
+    progressbar : bool, optional
+        Whether to display a progress bar during optimization. Defaults to True.
+    include_transformed: bool, optional
+        Whether to include transformed variable values in the returned dictionary. Defaults to True.
+    use_jax_gradients: bool, optional
+        Whether to use JAX for gradient calculations. Defaults to False.
+    chains: int, default: 2
+        The number of sampling chains running in parallel.
+    draws: int, default: 500
+        The number of samples to draw from the approximated posterior.
+    on_bad_cov : str, one of 'ignore', 'warn', or 'error', default: 'ignore'
+        What to do when ``H_inv`` (inverse Hessian) is not positive semi-definite.
+        If 'ignore' or 'warn', the closest positive-semi-definite matrix to ``H_inv`` (in L1 norm) will be returned.
+        If 'error', an error will be raised.
+    transform_samples : bool
+        Whether to transform the samples back to the original parameter space. Default is False.
+    zero_tol: float
+        Value below which an element of the Hessian matrix is counted as 0.
+        This is used to stabilize the computation of the inverse Hessian matrix. Default is 1e-8.
+    diag_jitter: float | None
+        A small value added to the diagonal of the inverse Hessian matrix to ensure it is positive semi-definite.
+        If None, no jitter is added. Default is 1e-8.
+    compile_kwargs: dict, optional
+        Additional keyword arguments to pass to pytensor.function.
+
+    Returns
+    -------
+    idata: az.InferenceData
+        An InferenceData object containing the approximated posterior samples.
+
+    Examples
+    --------
+    >>> from inference.laplace import fit_laplace
+    >>> import numpy as np
+    >>> import pymc as pm
+    >>> import arviz as az
+    >>> y = np.array([2642, 3503, 4358]*10)
+    >>> with pm.Model() as m:
+    >>>     logsigma = pm.Uniform("logsigma", 1, 100)
+    >>>     mu = pm.Uniform("mu", -10000, 10000)
+    >>>     yobs = pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y)
+    >>>     idata = fit_laplace()
+
+    Notes
+    -----
+    This method of approximation may not be suitable for all types of posterior distributions,
+    especially those with significant skewness or multimodality.
+
+    See Also
+    --------
+    fit : Calling the inference function 'fit' like pmx.fit(method="laplace", model=m)
+        will forward the call to 'fit_laplace'.
+ + """ + compile_kwargs = {} if compile_kwargs is None else compile_kwargs + optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs + + optimized_point = find_MAP( + method=optimize_method, + model=model, + use_grad=use_grad, + use_hessp=use_hessp, + use_hess=use_hess, + initvals=initvals, + random_seed=random_seed, + return_raw=return_raw, + jitter_rvs=jitter_rvs, + progressbar=progressbar, + include_transformed=include_transformed, + use_jax_gradients=use_jax_gradients, + compile_kwargs=compile_kwargs, + **optimizer_kwargs, + ) + + mu, H_inv = fit_mvn_to_MAP( + optimized_point=optimized_point, + model=model, + on_bad_cov=on_bad_cov, + transform_samples=transform_samples, + zero_tol=zero_tol, + diag_jitter=diag_jitter, + compile_kwargs=compile_kwargs, + ) + + return laplace( + mu=mu, + H_inv=H_inv, + model=model, + chains=chains, + draws=draws, + transform_samples=transform_samples, + progressbar=progressbar, + random_seed=random_seed, + compile_kwargs=compile_kwargs, + ) diff --git a/tests/test_jax_find_map.py b/tests/test_jax_find_map.py index 01e05301..cba9c503 100644 --- a/tests/test_jax_find_map.py +++ b/tests/test_jax_find_map.py @@ -5,7 +5,6 @@ from pymc_experimental.inference.find_map import ( find_MAP, - fit_laplace, scipy_optimize_funcs_from_loss, ) @@ -96,113 +95,3 @@ def test_JAX_map(method, use_grad, use_hess, use_jax_gradients, rng): assert np.isclose(mu_hat, 3, atol=0.5) assert np.isclose(np.exp(log_sigma_hat), 1.5, atol=0.5) - - -@pytest.mark.parametrize( - "transform_samples", - [True, False], - ids=["transformed", "untransformed"], -) -@pytest.mark.parametrize("mode", ["JAX", None], ids=["jax", "pytensor"]) -def test_fit_laplace_coords(rng, transform_samples, mode): - coords = {"city": ["A", "B", "C"], "obs_idx": np.arange(100)} - with pm.Model(coords=coords) as model: - mu = pm.Normal("mu", mu=3, sigma=0.5, dims=["city"]) - sigma = pm.Normal("sigma", mu=1.5, sigma=0.5, dims=["city"]) - obs = pm.Normal( - "obs", - mu=mu, - sigma=sigma, - observed=rng.normal(loc=3, scale=1.5, size=(100, 3)), - dims=["obs_idx", "city"], - ) - - optimized_point = find_MAP( - method="trust-ncg", - use_grad=True, - use_hessp=True, - progressbar=False, - compile_kwargs=dict(mode=mode), - use_jax_gradients=mode == "JAX", - ) - - for value in optimized_point.values(): - assert value.shape == (3,) - - idata = fit_laplace( - optimized_point, - model, - transform_samples=transform_samples, - progressbar=False, - compile_kwargs=dict(mode=mode), - ) - - np.testing.assert_allclose(np.mean(idata.posterior.mu, axis=1), np.full((2, 3), 3), atol=0.5) - np.testing.assert_allclose( - np.mean(idata.posterior.sigma, axis=1), np.full((2, 3), 1.5), atol=0.3 - ) - - -def test_fit_laplace_ragged_coords(rng): - coords = {"city": ["A", "B", "C"], "feature": [0, 1], "obs_idx": np.arange(100)} - with pm.Model(coords=coords) as ragged_dim_model: - X = pm.Data("X", np.ones((100, 2)), dims=["obs_idx", "feature"]) - beta = pm.Normal( - "beta", mu=[[-100.0, 100.0], [-100.0, 100.0], [-100.0, 100.0]], dims=["city", "feature"] - ) - mu = pm.Deterministic( - "mu", (X[:, None, :] * beta[None]).sum(axis=-1), dims=["obs_idx", "feature"] - ) - sigma = pm.Normal("sigma", mu=1.5, sigma=0.5, dims=["city"]) - - obs = pm.Normal( - "obs", - mu=mu, - sigma=sigma, - observed=rng.normal(loc=3, scale=1.5, size=(100, 3)), - dims=["obs_idx", "city"], - ) - - optimized_point, _ = find_MAP( - method="Newton-CG", use_grad=True, progressbar=False, return_raw=True - ) - - idata = fit_laplace(optimized_point, ragged_dim_model, 
progressbar=False) - - assert idata["posterior"].beta.shape[-2:] == (3, 2) - assert idata["posterior"].sigma.shape[-1:] == (3,) - - # Check that everything got unraveled correctly -- feature 0 should be strictly negative, feature 1 - # strictly positive - assert (idata["posterior"].beta.sel(feature=0).to_numpy() < 0).all() - assert (idata["posterior"].beta.sel(feature=1).to_numpy() > 0).all() - - -@pytest.mark.parametrize( - "transform_samples", - [True, False], - ids=["transformed", "untransformed"], -) -def test_fit_laplace(transform_samples): - with pm.Model() as simp_model: - mu = pm.Normal("mu", mu=3, sigma=0.5) - sigma = pm.Normal("sigma", mu=1.5, sigma=0.5) - obs = pm.Normal( - "obs", - mu=mu, - sigma=sigma, - observed=np.random.default_rng().normal(loc=3, scale=1.5, size=(10000,)), - ) - - optimized_point = find_MAP( - method="Newton-CG", - use_grad=True, - progressbar=False, - ) - - idata = fit_laplace( - optimized_point, simp_model, transform_samples=transform_samples, progressbar=False - ) - - np.testing.assert_allclose(np.mean(idata.posterior.mu, axis=1), np.full((2,), 3), atol=0.1) - np.testing.assert_allclose(np.mean(idata.posterior.sigma, axis=1), np.full((2,), 1.5), atol=0.1) diff --git a/tests/test_laplace.py b/tests/test_laplace.py index 3fefe3f7..424e41e4 100644 --- a/tests/test_laplace.py +++ b/tests/test_laplace.py @@ -19,6 +19,15 @@ import pymc_experimental as pmx +from pymc_experimental.inference.find_map import find_MAP +from pymc_experimental.inference.laplace import fit_laplace, fit_mvn_to_MAP, laplace + + +@pytest.fixture(scope="session") +def rng(): + seed = sum(map(ord, "test_laplace")) + return np.random.default_rng(seed) + @pytest.mark.filterwarnings( "ignore:hessian will stop negating the output in a future version of PyMC.\n" @@ -35,18 +44,15 @@ def test_laplace(): draws = 100000 with pm.Model() as m: - logsigma = pm.Uniform("logsigma", 1, 100) mu = pm.Uniform("mu", -10000, 10000) + logsigma = pm.Uniform("logsigma", 1, 100) + yobs = pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y) vars = [mu, logsigma] - idata = pmx.fit( - method="laplace", - vars=vars, - model=m, - draws=draws, - random_seed=173300, - ) + idata = pmx.fit( + method="laplace", optimize_method="trust-ncg", draws=draws, random_seed=173300, chains=1 + ) assert idata.posterior["mu"].shape == (1, draws) assert idata.posterior["logsigma"].shape == (1, draws) @@ -57,14 +63,10 @@ def test_laplace(): bda_map = [y.mean(), np.log(y.std())] bda_cov = np.array([[y.var() / n, 0], [0, 1 / (2 * n)]]) - assert np.allclose(idata.fit["mean_vector"].values, bda_map) - assert np.allclose(idata.fit["covariance_matrix"].values, bda_cov, atol=1e-4) + np.testing.assert_allclose(idata.fit["mean_vector"].values, bda_map) + np.testing.assert_allclose(idata.fit["covariance_matrix"].values, bda_cov, atol=1e-4) -@pytest.mark.filterwarnings( - "ignore:hessian will stop negating the output in a future version of PyMC.\n" - + "To suppress this warning set `negate_output=False`:FutureWarning", -) def test_laplace_only_fit(): # Example originates from Bayesian Data Analyses, 3rd Edition # By Andrew Gelman, John Carlin, Hal Stern, David Dunson, @@ -80,55 +82,151 @@ def test_laplace_only_fit(): yobs = pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y) vars = [mu, logsigma] - idata = pmx.fit( - method="laplace", - vars=vars, - draws=None, - model=m, - random_seed=173300, - ) + idata = pmx.fit( + method="laplace", + optimize_method="BFGS", + progressbar=True, + use_jax_gradients=True, + 
compile_kwargs={"mode": "JAX"}, + optimizer_kwargs=dict(maxiter=100_000, gtol=1e-100), + random_seed=173300, + ) assert idata.fit["mean_vector"].shape == (len(vars),) assert idata.fit["covariance_matrix"].shape == (len(vars), len(vars)) - bda_map = [y.mean(), np.log(y.std())] - bda_cov = np.array([[y.var() / n, 0], [0, 1 / (2 * n)]]) + bda_map = [np.log(y.std()), y.mean()] + bda_cov = np.array([[1 / (2 * n), 0], [0, y.var() / n]]) - assert np.allclose(idata.fit["mean_vector"].values, bda_map) - assert np.allclose(idata.fit["covariance_matrix"].values, bda_cov, atol=1e-4) + np.testing.assert_allclose(idata.fit["mean_vector"].values, bda_map) + np.testing.assert_allclose(idata.fit["covariance_matrix"].values, bda_cov, atol=1e-4) -@pytest.mark.filterwarnings( - "ignore:hessian will stop negating the output in a future version of PyMC.\n" - + "To suppress this warning set `negate_output=False`:FutureWarning", +@pytest.mark.parametrize( + "transform_samples", + [True, False], + ids=["transformed", "untransformed"], ) -def test_laplace_subset_of_rv(recwarn): - # Example originates from Bayesian Data Analyses, 3rd Edition - # By Andrew Gelman, John Carlin, Hal Stern, David Dunson, - # Aki Vehtari, and Donald Rubin. - # See section. 4.1 - - y = np.array([2642, 3503, 4358], dtype=np.float64) - n = y.size - - with pm.Model() as m: - logsigma = pm.Uniform("logsigma", 1, 100) - mu = pm.Uniform("mu", -10000, 10000) - yobs = pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y) - vars = [mu] - - idata = pmx.fit( - method="laplace", - vars=vars, - draws=None, - model=m, - random_seed=173300, +@pytest.mark.parametrize("mode", ["JAX", None], ids=["jax", "pytensor"]) +def test_fit_laplace_coords(rng, transform_samples, mode): + coords = {"city": ["A", "B", "C"], "obs_idx": np.arange(100)} + with pm.Model(coords=coords) as model: + mu = pm.Normal("mu", mu=3, sigma=0.5, dims=["city"]) + sigma = pm.Exponential("sigma", 1, dims=["city"]) + obs = pm.Normal( + "obs", + mu=mu, + sigma=sigma, + observed=rng.normal(loc=3, scale=1.5, size=(100, 3)), + dims=["obs_idx", "city"], + ) + + optimized_point = find_MAP( + method="trust-ncg", + use_grad=True, + use_hessp=True, + progressbar=False, + compile_kwargs=dict(mode=mode), + use_jax_gradients=mode == "JAX", + ) + + for value in optimized_point.values(): + assert value.shape == (3,) + + mu, H_inv = fit_mvn_to_MAP( + optimized_point=optimized_point, + model=model, + transform_samples=transform_samples, + ) + + idata = laplace(mu=mu, H_inv=H_inv, model=model, transform_samples=transform_samples) + + np.testing.assert_allclose(np.mean(idata.posterior.mu, axis=1), np.full((2, 3), 3), atol=0.5) + np.testing.assert_allclose( + np.mean(idata.posterior.sigma, axis=1), np.full((2, 3), 1.5), atol=0.3 ) - assert len(recwarn) == 3 - w = recwarn.pop(UserWarning) - assert issubclass(w.category, UserWarning) - assert ( - str(w.message) - == "Number of variables in vars does not eqaul the number of variables in the model." 
- ) + suffix = "_log__" if transform_samples else "" + assert idata.fit.rows.values.tolist() == [ + "mu[A]", + "mu[B]", + "mu[C]", + f"sigma{suffix}[A]", + f"sigma{suffix}[B]", + f"sigma{suffix}[C]", + ] + + +def test_fit_laplace_ragged_coords(rng): + coords = {"city": ["A", "B", "C"], "feature": [0, 1], "obs_idx": np.arange(100)} + with pm.Model(coords=coords) as ragged_dim_model: + X = pm.Data("X", np.ones((100, 2)), dims=["obs_idx", "feature"]) + beta = pm.Normal( + "beta", mu=[[-100.0, 100.0], [-100.0, 100.0], [-100.0, 100.0]], dims=["city", "feature"] + ) + mu = pm.Deterministic( + "mu", (X[:, None, :] * beta[None]).sum(axis=-1), dims=["obs_idx", "city"] + ) + sigma = pm.Normal("sigma", mu=1.5, sigma=0.5, dims=["city"]) + + obs = pm.Normal( + "obs", + mu=mu, + sigma=sigma, + observed=rng.normal(loc=3, scale=1.5, size=(100, 3)), + dims=["obs_idx", "city"], + ) + + idata = fit_laplace( + optimize_method="Newton-CG", + progressbar=False, + use_grad=True, + use_hessp=True, + use_jax_gradients=True, + compile_kwargs={"mode": "JAX"}, + ) + + assert idata["posterior"].beta.shape[-2:] == (3, 2) + assert idata["posterior"].sigma.shape[-1:] == (3,) + + # Check that everything got unraveled correctly -- feature 0 should be strictly negative, feature 1 + # strictly positive + assert (idata["posterior"].beta.sel(feature=0).to_numpy() < 0).all() + assert (idata["posterior"].beta.sel(feature=1).to_numpy() > 0).all() + + +@pytest.mark.parametrize( + "transform_samples", + [True, False], + ids=["transformed", "untransformed"], +) +def test_fit_laplace(transform_samples): + with pm.Model() as simp_model: + mu = pm.Normal("mu", mu=3, sigma=0.5) + sigma = pm.Exponential("sigma", 1) + obs = pm.Normal( + "obs", + mu=mu, + sigma=sigma, + observed=np.random.default_rng().normal(loc=3, scale=1.5, size=(10000,)), + ) + + idata = fit_laplace( + optimize_method="trust-ncg", + use_grad=True, + use_hessp=True, + transform_samples=transform_samples, + optimizer_kwargs=dict(maxiter=100_000, tol=1e-100), + ) + + np.testing.assert_allclose(np.mean(idata.posterior.mu, axis=1), np.full((2,), 3), atol=0.1) + np.testing.assert_allclose( + np.mean(idata.posterior.sigma, axis=1), np.full((2,), 1.5), atol=0.1 + ) + + if transform_samples: + assert idata.fit.rows.values.tolist() == ["mu", "sigma_log__"] + np.testing.assert_allclose(idata.fit.mean_vector.values, np.array([3.0, 0.4]), atol=0.1) + else: + assert idata.fit.rows.values.tolist() == ["mu", "sigma"] + np.testing.assert_allclose(idata.fit.mean_vector.values, np.array([3.0, 1.5]), atol=0.1) From 07ebe40f49e9d286b6bbcd4093dc8ddc9da460b0 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Wed, 4 Dec 2024 02:48:30 +0800 Subject: [PATCH 12/21] Use absolute import in doctest --- pymc_experimental/inference/laplace.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc_experimental/inference/laplace.py b/pymc_experimental/inference/laplace.py index 80fd7873..b832475d 100644 --- a/pymc_experimental/inference/laplace.py +++ b/pymc_experimental/inference/laplace.py @@ -485,7 +485,7 @@ def fit_laplace( Examples -------- - >>> from inference.laplace import fit_laplace + >>> from pymc_experimental.inference.laplace import fit_laplace >>> import numpy as np >>> import pymc as pm >>> import arviz as az From b40e101103faa30c5823f8e82aa230b4b6b05120 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Wed, 4 Dec 2024 11:43:40 +0800 Subject: [PATCH 13/21] Fix imports --- pymc_experimental/inference/laplace.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 
7 deletions(-) diff --git a/pymc_experimental/inference/laplace.py b/pymc_experimental/inference/laplace.py index b832475d..9ea64043 100644 --- a/pymc_experimental/inference/laplace.py +++ b/pymc_experimental/inference/laplace.py @@ -13,6 +13,8 @@ # limitations under the License. +import logging + from functools import reduce from itertools import product from typing import Literal @@ -26,13 +28,6 @@ from arviz import dict_to_dataset from better_optimize.constants import minimize_method -from inference.find_map import ( - _log, - _unconstrained_vector_to_constrained_rvs, - find_MAP, - get_nearest_psd, - scipy_optimize_funcs_from_loss, -) from pymc import DictToArrayBijection from pymc.backends.arviz import ( coords_and_dims_for_inferencedata, @@ -45,6 +40,15 @@ from pymc.util import get_default_varnames from scipy import stats +from pymc_experimental.inference.find_map import ( + _unconstrained_vector_to_constrained_rvs, + find_MAP, + get_nearest_psd, + scipy_optimize_funcs_from_loss, +) + +_log = logging.getLogger(__name__) + def laplace_draws_to_inferencedata( posterior_draws: list[np.ndarray[float | int]], model: pm.Model | None = None From bc340c257f191d0fa933d73fcec7cb74ace8467e Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Wed, 4 Dec 2024 12:46:10 +0800 Subject: [PATCH 14/21] Fix unrelated statespace test --- tests/statespace/test_ETS.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/statespace/test_ETS.py b/tests/statespace/test_ETS.py index b56a3581..b9c15e5e 100644 --- a/tests/statespace/test_ETS.py +++ b/tests/statespace/test_ETS.py @@ -408,4 +408,4 @@ def test_ETS_stationary_initialization(): R, Q = outputs["selection"], outputs["state_cov"] P0_expected = linalg.solve_discrete_lyapunov(T_stationary, R @ Q @ R.T) - assert_allclose(outputs["initial_state_cov"], P0_expected) + assert_allclose(outputs["initial_state_cov"], P0_expected, rtol=1e-8, atol=1e-8) From da338bf73754a957814119ac2b966f68bd66b6b2 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Wed, 4 Dec 2024 19:04:13 +0800 Subject: [PATCH 15/21] - Rename argument `use_jax_gradients` -> `gradient_backend` - Rename function `laplace` -> `sample_laplace_posterior` --- pymc_experimental/inference/find_map.py | 33 +++++++++++++++++++------ pymc_experimental/inference/laplace.py | 23 ++++++++++------- tests/test_jax_find_map.py | 13 +++++----- tests/test_laplace.py | 16 ++++++++---- 4 files changed, 57 insertions(+), 28 deletions(-) diff --git a/pymc_experimental/inference/find_map.py b/pymc_experimental/inference/find_map.py index 95877b27..d1044b20 100644 --- a/pymc_experimental/inference/find_map.py +++ b/pymc_experimental/inference/find_map.py @@ -1,7 +1,7 @@ import logging from collections.abc import Callable -from typing import cast +from typing import Literal, cast, get_args import jax import numpy as np @@ -17,11 +17,15 @@ from pymc.pytensorf import join_nonshared_inputs from pymc.util import get_default_varnames from pytensor.compile import Function +from pytensor.compile.mode import Mode from pytensor.tensor import TensorVariable from scipy.optimize import OptimizeResult _log = logging.getLogger(__name__) +GradientBackend = Literal["pytensor", "jax"] +VALID_BACKENDS = get_args(GradientBackend) + def set_optimizer_function_defaults(method, use_grad, use_hess, use_hessp): method_info = MINIMIZE_MODE_KWARGS[method].copy() @@ -85,7 +89,11 @@ def _create_transformed_draws(H_inv, slices, out_shapes, posterior_draws, model, out.append(untransformed_X) - f_untransform = pytensor.function([X], 
out, mode="JAX") + f_untransform = pytensor.function( + inputs=[pytensor.In(X, borrow=True)], + outputs=pytensor.Out(out, borrow=True), + mode=Mode(linker="py", optimizer=None), + ) return f_untransform(posterior_draws) @@ -209,7 +217,7 @@ def scipy_optimize_funcs_from_loss( use_grad: bool, use_hess: bool, use_hessp: bool, - use_jax_gradients: bool = False, + gradient_backend: GradientBackend = "pytensor", compile_kwargs: dict | None = None, ) -> tuple[Callable, ...]: """ @@ -230,8 +238,8 @@ def scipy_optimize_funcs_from_loss( Whether to compile a function that computes the Hessian of the loss function. use_hessp: bool Whether to compile a function that computes the Hessian-vector product of the loss function. - use_jax_gradients: bool - If True, use JAX to compute gradients. This is only possible when ``compile_kwargs["mode"]`` is set to "JAX". + gradient_backend: str, one of "jax" or "pytensor" + Which backend to use to compute gradients. compile_kwargs: Additional keyword arguments to pass to the ``pm.compile_pymc`` function. @@ -252,7 +260,12 @@ def scipy_optimize_funcs_from_loss( "Cannot compute hessian or hessian-vector product without also computing the gradient" ) - use_jax_gradients = use_jax_gradients and use_grad + if gradient_backend not in VALID_BACKENDS: + raise ValueError( + f"Invalid gradient backend: {gradient_backend}. Must be one of {VALID_BACKENDS}" + ) + + use_jax_gradients = (gradient_backend == "jax") and use_grad mode = compile_kwargs.get("mode", None) if mode is None and use_jax_gradients: @@ -307,7 +320,7 @@ def find_MAP( jitter_rvs: list[TensorVariable] | None = None, progressbar: bool = True, include_transformed: bool = True, - use_jax_gradients: bool = False, + gradient_backend: GradientBackend = "pytensor", compile_kwargs: dict | None = None, **optimizer_kwargs, ) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], OptimizeResult]: @@ -342,6 +355,10 @@ def find_MAP( Whether to display a progress bar during optimization. Defaults to True. include_transformed: bool, optional Whether to include transformed variable values in the returned dictionary. Defaults to True. + gradient_backend: str, default "pytensor" + Which backend to use to compute gradients. Must be one of "pytensor" or "jax". + compile_kwargs: dict, optional + Additional options to pass to the ``pytensor.function`` function when compiling loss functions. **optimizer_kwargs Additional keyword arguments to pass to the ``scipy.optimize.minimize`` function. @@ -380,7 +397,7 @@ def find_MAP( use_grad=use_grad, use_hess=use_hess, use_hessp=use_hessp, - use_jax_gradients=use_jax_gradients, + gradient_backend=gradient_backend, compile_kwargs=compile_kwargs, ) diff --git a/pymc_experimental/inference/laplace.py b/pymc_experimental/inference/laplace.py index 9ea64043..2813ef12 100644 --- a/pymc_experimental/inference/laplace.py +++ b/pymc_experimental/inference/laplace.py @@ -41,6 +41,7 @@ from scipy import stats from pymc_experimental.inference.find_map import ( + GradientBackend, _unconstrained_vector_to_constrained_rvs, find_MAP, get_nearest_psd, @@ -235,7 +236,7 @@ def fit_mvn_to_MAP( model: pm.Model | None = None, on_bad_cov: Literal["warn", "error", "ignore"] = "ignore", transform_samples: bool = False, - use_jax_gradients: bool = False, + gradient_backend: GradientBackend = "pytensor", zero_tol: float = 1e-8, diag_jitter: float | None = 1e-8, compile_kwargs: dict | None = None, @@ -256,12 +257,16 @@ def fit_mvn_to_MAP( If 'error', an error will be raised. 
transform_samples : bool Whether to transform the samples back to the original parameter space. Default is True. + gradient_backend: str, default "pytensor" + The backend to use for gradient computations. Must be one of "pytensor" or "jax". zero_tol: float Value below which an element of the Hessian matrix is counted as 0. This is used to stabilize the computation of the inverse Hessian matrix. Default is 1e-8. diag_jitter: float | None A small value added to the diagonal of the inverse Hessian matrix to ensure it is positive semi-definite. If None, no jitter is added. Default is 1e-8. + compile_kwargs: dict, optional + Additional keyword arguments to pass to pytensor.function when compiling loss functions Returns ------- @@ -294,7 +299,7 @@ def fit_mvn_to_MAP( use_grad=True, use_hess=True, use_hessp=False, - use_jax_gradients=use_jax_gradients, + gradient_backend=gradient_backend, compile_kwargs=compile_kwargs, ) @@ -323,7 +328,7 @@ def stabilize(x, jitter): return mu, H_inv -def laplace( +def sample_laplace_posterior( mu: RaveledVars, H_inv: np.ndarray, model: pm.Model | None = None, @@ -416,7 +421,7 @@ def fit_laplace( jitter_rvs: list[pt.TensorVariable] | None = None, progressbar: bool = True, include_transformed: bool = True, - use_jax_gradients: bool = False, + gradient_backend: GradientBackend = "pytensor", chains: int = 2, draws: int = 500, on_bad_cov: Literal["warn", "error", "ignore"] = "ignore", @@ -461,8 +466,8 @@ def fit_laplace( Whether to display a progress bar during optimization. Defaults to True. include_transformed: bool, optional Whether to include transformed variable values in the returned dictionary. Defaults to True. - use_jax_gradients: bool, optional - Whether to use JAX for gradient calculations. Defaults to False. + gradient_backend: str, default "pytensor" + The backend to use for gradient computations. Must be one of "pytensor" or "jax". chains: int, default: 2 The number of sampling chains running in parallel. 
draws: int, default: 500 @@ -489,7 +494,7 @@ def fit_laplace( Examples -------- - >>> from pymc_experimental.inference.laplace import fit_laplace + >>> from pymc_experimental.inference.sample_laplace_posterior import fit_laplace >>> import numpy as np >>> import pymc as pm >>> import arviz as az @@ -526,7 +531,7 @@ def fit_laplace( jitter_rvs=jitter_rvs, progressbar=progressbar, include_transformed=include_transformed, - use_jax_gradients=use_jax_gradients, + gradient_backend=gradient_backend, compile_kwargs=compile_kwargs, **optimizer_kwargs, ) @@ -541,7 +546,7 @@ def fit_laplace( compile_kwargs=compile_kwargs, ) - return laplace( + return sample_laplace_posterior( mu=mu, H_inv=H_inv, model=model, diff --git a/tests/test_jax_find_map.py b/tests/test_jax_find_map.py index cba9c503..6b2c029a 100644 --- a/tests/test_jax_find_map.py +++ b/tests/test_jax_find_map.py @@ -4,6 +4,7 @@ import pytest from pymc_experimental.inference.find_map import ( + GradientBackend, find_MAP, scipy_optimize_funcs_from_loss, ) @@ -17,8 +18,8 @@ def rng(): return np.random.default_rng(seed) -@pytest.mark.parametrize("use_jax_gradients", [True, False], ids=["jax_grad", "pt_grad"]) -def test_jax_functions_from_graph(use_jax_gradients): +@pytest.mark.parametrize("gradient_backend", ["jax", "pytensor"], ids=str) +def test_jax_functions_from_graph(gradient_backend: GradientBackend): x = pt.tensor("x", shape=(2,)) def compute_z(x): @@ -34,7 +35,7 @@ def compute_z(x): use_grad=True, use_hess=True, use_hessp=True, - use_jax_gradients=use_jax_gradients, + gradient_backend=gradient_backend, compile_kwargs=dict(mode="JAX"), ) @@ -69,8 +70,8 @@ def compute_z(x): ("trust-constr", True, True), ], ) -@pytest.mark.parametrize("use_jax_gradients", [True, False], ids=["jax_grad", "pt_grad"]) -def test_JAX_map(method, use_grad, use_hess, use_jax_gradients, rng): +@pytest.mark.parametrize("gradient_backend", ["jax", "pytensor"], ids=str) +def test_JAX_map(method, use_grad, use_hess, gradient_backend: GradientBackend, rng): extra_kwargs = {} if method == "dogleg": # HACK -- dogleg requires that the hessian of the objective function is PSD, so we have to pick a point @@ -88,7 +89,7 @@ def test_JAX_map(method, use_grad, use_hess, use_jax_gradients, rng): use_grad=use_grad, use_hess=use_hess, progressbar=False, - use_jax_gradients=use_jax_gradients, + gradient_backend=gradient_backend, compile_kwargs={"mode": "JAX"}, ) mu_hat, log_sigma_hat = optimized_point["mu"], optimized_point["sigma_log__"] diff --git a/tests/test_laplace.py b/tests/test_laplace.py index 424e41e4..811ab51f 100644 --- a/tests/test_laplace.py +++ b/tests/test_laplace.py @@ -20,7 +20,11 @@ import pymc_experimental as pmx from pymc_experimental.inference.find_map import find_MAP -from pymc_experimental.inference.laplace import fit_laplace, fit_mvn_to_MAP, laplace +from pymc_experimental.inference.laplace import ( + fit_laplace, + fit_mvn_to_MAP, + sample_laplace_posterior, +) @pytest.fixture(scope="session") @@ -86,7 +90,7 @@ def test_laplace_only_fit(): method="laplace", optimize_method="BFGS", progressbar=True, - use_jax_gradients=True, + gradient_backend="jax", compile_kwargs={"mode": "JAX"}, optimizer_kwargs=dict(maxiter=100_000, gtol=1e-100), random_seed=173300, @@ -127,7 +131,7 @@ def test_fit_laplace_coords(rng, transform_samples, mode): use_hessp=True, progressbar=False, compile_kwargs=dict(mode=mode), - use_jax_gradients=mode == "JAX", + gradient_backend="jax" if mode == "JAX" else "pytensor", ) for value in optimized_point.values(): @@ -139,7 +143,9 @@ def 
test_fit_laplace_coords(rng, transform_samples, mode): transform_samples=transform_samples, ) - idata = laplace(mu=mu, H_inv=H_inv, model=model, transform_samples=transform_samples) + idata = sample_laplace_posterior( + mu=mu, H_inv=H_inv, model=model, transform_samples=transform_samples + ) np.testing.assert_allclose(np.mean(idata.posterior.mu, axis=1), np.full((2, 3), 3), atol=0.5) np.testing.assert_allclose( @@ -182,7 +188,7 @@ def test_fit_laplace_ragged_coords(rng): progressbar=False, use_grad=True, use_hessp=True, - use_jax_gradients=True, + gradient_backend="jax", compile_kwargs={"mode": "JAX"}, ) From 3ebbf20646f63cd1cb433894abd00dd80d6bea32 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Wed, 4 Dec 2024 19:05:46 +0800 Subject: [PATCH 16/21] Fix typo introduced by rename refactor --- pymc_experimental/inference/laplace.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc_experimental/inference/laplace.py b/pymc_experimental/inference/laplace.py index 2813ef12..39a7d220 100644 --- a/pymc_experimental/inference/laplace.py +++ b/pymc_experimental/inference/laplace.py @@ -494,7 +494,7 @@ def fit_laplace( Examples -------- - >>> from pymc_experimental.inference.sample_laplace_posterior import fit_laplace + >>> from pymc_experimental.inference.laplace import fit_laplace >>> import numpy as np >>> import pymc as pm >>> import arviz as az From 2035202dd72f976c1e347d2276fc7ca954b29675 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Wed, 4 Dec 2024 19:09:07 +0800 Subject: [PATCH 17/21] use `mode=FAST_COMPILE` to get `unobserved_value_vars` after MAP optimization --- pymc_experimental/inference/find_map.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc_experimental/inference/find_map.py b/pymc_experimental/inference/find_map.py index d1044b20..ff1d5b68 100644 --- a/pymc_experimental/inference/find_map.py +++ b/pymc_experimental/inference/find_map.py @@ -418,7 +418,7 @@ def find_MAP( raveled_optimized = RaveledVars(optimizer_result.x, initial_params.point_map_info) unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed) - unobserved_vars_values = model.compile_fn(unobserved_vars)( + unobserved_vars_values = model.compile_fn(unobserved_vars, mode="FAST_COMPILE")( DictToArrayBijection.rmap(raveled_optimized) ) From f2504e98addc8c082038866dac3ead34925eecbc Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Wed, 4 Dec 2024 19:09:45 +0800 Subject: [PATCH 18/21] Rename `test_jax_find_map.py` -> `test_find_map.py` --- tests/{test_jax_find_map.py => test_find_map.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/{test_jax_find_map.py => test_find_map.py} (100%) diff --git a/tests/test_jax_find_map.py b/tests/test_find_map.py similarity index 100% rename from tests/test_jax_find_map.py rename to tests/test_find_map.py From a81079b28748a963993b5867d0db4dc913059bbf Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Wed, 4 Dec 2024 19:19:12 +0800 Subject: [PATCH 19/21] Improve docstring for `fit_laplace` --- pymc_experimental/inference/laplace.py | 27 ++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/pymc_experimental/inference/laplace.py b/pymc_experimental/inference/laplace.py index 39a7d220..24a72c0f 100644 --- a/pymc_experimental/inference/laplace.py +++ b/pymc_experimental/inference/laplace.py @@ -425,7 +425,7 @@ def fit_laplace( chains: int = 2, draws: int = 500, on_bad_cov: Literal["warn", "error", "ignore"] = "ignore", - transform_samples: bool = 
False, + fit_in_unconstrained_space: bool = False, zero_tol: float = 1e-8, diag_jitter: float | None = 1e-8, optimizer_kwargs: dict | None = None, @@ -464,8 +464,17 @@ def fit_laplace( Variables whose initial values should be jittered. If None, all variables are jittered. progressbar : bool, optional Whether to display a progress bar during optimization. Defaults to True. - include_transformed: bool, optional - Whether to include transformed variable values in the returned dictionary. Defaults to True. + fit_in_unconstrained_space: bool, default False + Whether to fit the Laplace approximation in the unconstrained parameter space. If True, samples will be drawn + from a mean and covariance matrix computed at a point in the **unconstrained** parameter space. Samples will + then be transformed back to the original parameter space. This will guarantee that the samples will respect + the domain of prior distributions (for example, samples from a Beta distribution will be strictly between 0 + and 1). + + .. warning:: + This argument should be considered highly experimental. It has not been verified if this method produces + valid draws from the posterior. **Use at your own risk**. + gradient_backend: str, default "pytensor" The backend to use for gradient computations. Must be one of "pytensor" or "jax". chains: int, default: 2 The number of sampling chains running in parallel. Default is 2. draws: int, default: 500 The number of samples to draw from the approximated posterior. Default is 500. on_bad_cov : str, one of 'ignore', 'warn', or 'error', default: 'ignore' What to do when ``H_inv`` (inverse Hessian) is not positive semi-definite. If 'ignore' or 'warn', the closest positive-semi-definite matrix to ``H_inv`` (in L1 norm) will be returned. If 'error', an error will be raised. - transform_samples : bool - Whether to transform the samples back to the original parameter space. Default is True. zero_tol: float Value below which an element of the Hessian matrix is counted as 0. This is used to stabilize the computation of the inverse Hessian matrix. Default is 1e-8. diag_jitter: float | None A small value added to the diagonal of the inverse Hessian matrix to ensure it is positive semi-definite. If None, no jitter is added. Default is 1e-8. - compile_kwargs: optional + optimizer_kwargs: dict, optional + Additional keyword arguments to pass to scipy.minimize. See the documentation for scipy.optimize.minimize for + details. Arguments that are typically passed via ``options`` will be automatically extracted without the need + to use a nested dictionary. + compile_kwargs: dict, optional Additional keyword arguments to pass to pytensor.function.
Returns @@ -540,7 +551,7 @@ def fit_laplace( optimized_point=optimized_point, model=model, on_bad_cov=on_bad_cov, - transform_samples=transform_samples, + transform_samples=fit_in_unconstrained_space, zero_tol=zero_tol, diag_jitter=diag_jitter, compile_kwargs=compile_kwargs, @@ -552,7 +563,7 @@ def fit_laplace( model=model, chains=chains, draws=draws, - transform_samples=transform_samples, + transform_samples=fit_in_unconstrained_space, progressbar=progressbar, random_seed=random_seed, compile_kwargs=compile_kwargs, From 4d88343e77ba6193fa2cb60cc1a1edb3a4f24d81 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Wed, 4 Dec 2024 19:21:06 +0800 Subject: [PATCH 20/21] Update tests to match new signature --- tests/test_laplace.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_laplace.py b/tests/test_laplace.py index 811ab51f..a11ee59e 100644 --- a/tests/test_laplace.py +++ b/tests/test_laplace.py @@ -202,11 +202,11 @@ def test_fit_laplace_ragged_coords(rng): @pytest.mark.parametrize( - "transform_samples", + "fit_in_unconstrained_space", [True, False], ids=["transformed", "untransformed"], ) -def test_fit_laplace(transform_samples): +def test_fit_laplace(fit_in_unconstrained_space): with pm.Model() as simp_model: mu = pm.Normal("mu", mu=3, sigma=0.5) sigma = pm.Exponential("sigma", 1) @@ -221,7 +221,7 @@ def test_fit_laplace(transform_samples): optimize_method="trust-ncg", use_grad=True, use_hessp=True, - transform_samples=transform_samples, + fit_in_unconstrained_space=fit_in_unconstrained_space, optimizer_kwargs=dict(maxiter=100_000, tol=1e-100), ) @@ -230,7 +230,7 @@ def test_fit_laplace(transform_samples): np.mean(idata.posterior.sigma, axis=1), np.full((2,), 1.5), atol=0.1 ) - if transform_samples: + if fit_in_unconstrained_space: assert idata.fit.rows.values.tolist() == ["mu", "sigma_log__"] np.testing.assert_allclose(idata.fit.mean_vector.values, np.array([3.0, 0.4]), atol=0.1) else: From 9b1cd0ed7c9f66840d3c2b2c99aef430ca7e8fce Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Wed, 4 Dec 2024 19:28:03 +0800 Subject: [PATCH 21/21] Update docstring --- pymc_experimental/inference/find_map.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pymc_experimental/inference/find_map.py b/pymc_experimental/inference/find_map.py index ff1d5b68..72ce3b19 100644 --- a/pymc_experimental/inference/find_map.py +++ b/pymc_experimental/inference/find_map.py @@ -92,7 +92,7 @@ def _create_transformed_draws(H_inv, slices, out_shapes, posterior_draws, model, f_untransform = pytensor.function( inputs=[pytensor.In(X, borrow=True)], outputs=pytensor.Out(out, borrow=True), - mode=Mode(linker="py", optimizer=None), + mode=Mode(linker="py", optimizer="FAST_COMPILE"), ) return f_untransform(posterior_draws) @@ -223,7 +223,6 @@ def scipy_optimize_funcs_from_loss( """ Compile loss functions for use with scipy.optimize.minimize. - Parameters ---------- loss: TensorVariable @@ -238,8 +237,8 @@ def scipy_optimize_funcs_from_loss( Whether to compile a function that computes the Hessian of the loss function. use_hessp: bool Whether to compile a function that computes the Hessian-vector product of the loss function. - gradient_backend: str, one of "jax" or "pytensor" - Which backend to use to compute gradients. + gradient_backend: str, default "pytensor" + Which backend to use to compute gradients. Must be one of "jax" or "pytensor" compile_kwargs: Additional keyword arguments to pass to the ``pm.compile_pymc`` function.