Add find_MAP with close JAX integration and fix bug with Laplace fit #385

Merged (21 commits) on Dec 4, 2024. The diff below shows the changes from one commit.

Commits:
6aa20f7
Add JAX-based `find_MAP`
jessegrabowski Oct 27, 2024
7ed3b2f
add `better_optimize` to CI envs
jessegrabowski Oct 27, 2024
e412f6f
Fix relative import
jessegrabowski Oct 27, 2024
f9b6258
Remove `find_MAP` import from module-level `__init__.py`
jessegrabowski Oct 27, 2024
ad3abd9
Update docstring
jessegrabowski Oct 27, 2024
be1d790
Allow calling `find_MAP` inside model context without model argument
jessegrabowski Oct 27, 2024
923eb26
Required patched better_optimize
jessegrabowski Oct 27, 2024
f705d43
in-progress refactor
jessegrabowski Nov 30, 2024
a23762b
More refactor
jessegrabowski Dec 3, 2024
2d21403
Generalize code to use any pytensor backend
jessegrabowski Dec 3, 2024
4c2529d
Reconcile the two laplace approximation functions
jessegrabowski Dec 3, 2024
07ebe40
Use absolute import in doctest
jessegrabowski Dec 3, 2024
b40e101
Fix imports
jessegrabowski Dec 4, 2024
bc340c2
Fix unrelated statespace test
jessegrabowski Dec 4, 2024
da338bf
- Rename argument `use_jax_gradients` -> `gradient_backend`
jessegrabowski Dec 4, 2024
3ebbf20
Fix typo introduced by rename refactor
jessegrabowski Dec 4, 2024
2035202
use `mode=FAST_COMPILE` to get `unobserved_value_vars` after MAP opti…
jessegrabowski Dec 4, 2024
f2504e9
Rename `test_jax_find_map.py` -> `test_find_map.py`
jessegrabowski Dec 4, 2024
a81079b
Improve docstring for `fit_laplace`
jessegrabowski Dec 4, 2024
4d88343
Update tests to match new signature
jessegrabowski Dec 4, 2024
9b1cd0e
Update docstring
jessegrabowski Dec 4, 2024
Reconcile the two laplace approximation functions
jessegrabowski committed Dec 3, 2024
commit 4c2529d67290dd00be54cb104dffcc4b1ca084a2
277 changes: 22 additions & 255 deletions pymc_experimental/inference/find_map.py
@@ -1,38 +1,41 @@
import logging

from collections.abc import Callable
-from typing import Literal, cast
+from typing import cast

import arviz as az
import jax
import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt
import xarray as xr

from arviz import dict_to_dataset
from better_optimize import minimize
-from better_optimize.constants import minimize_method
-from pymc.backends.arviz import (
-    coords_and_dims_for_inferencedata,
-    find_constants,
-    find_observations,
-)
+from better_optimize.constants import MINIMIZE_MODE_KWARGS, minimize_method
from pymc.blocking import DictToArrayBijection, RaveledVars
from pymc.initial_point import make_initial_point_fn
from pymc.model.transform.conditioning import remove_value_transforms
from pymc.model.transform.optimization import freeze_dims_and_data
from pymc.pytensorf import join_nonshared_inputs
from pymc.util import get_default_varnames
from pytensor.compile import Function
from pytensor.tensor import TensorVariable
from scipy import stats
from scipy.optimize import OptimizeResult

_log = logging.getLogger(__name__)


def set_optimizer_function_defaults(method, use_grad, use_hess, use_hessp):
    method_info = MINIMIZE_MODE_KWARGS[method].copy()

    use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
    use_hess = use_hess if use_hess is not None else method_info["uses_hess"]
    use_hessp = use_hessp if use_hessp is not None else method_info["uses_hessp"]

    if use_hess and use_hessp:
        use_hess = False

    return use_grad, use_hess, use_hessp

Review comment (Contributor), on `if use_hess and use_hessp:`: I was going through all the methods, thinking about when you would need `hess` versus `hessp`, and then came back to this. I would probably warn the user, or not let them pass both `use_hess` and `use_hessp`.
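A minimal sketch of how these defaults resolve, assuming `better_optimize` is installed; the `trust-ncg` method name is just an illustrative choice:

```python
from pymc_experimental.inference.find_map import set_optimizer_function_defaults

# All three flags left as None: each falls back to the per-method defaults
# recorded in better_optimize's MINIMIZE_MODE_KWARGS table.
use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
    "trust-ncg", use_grad=None, use_hess=None, use_hessp=None
)

# Passing both use_hess=True and use_hessp=True: the function silently
# prefers hessp and flips use_hess back to False, the behavior the review
# comment above suggests should warn instead.
print(set_optimizer_function_defaults("trust-ncg", True, True, True))
# -> (True, False, True)
```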


def get_nearest_psd(A: np.ndarray) -> np.ndarray:
    """
    Compute the nearest positive semi-definite matrix to a given matrix.
@@ -60,7 +63,9 @@ def get_nearest_psd(A: np.ndarray) -> np.ndarray:

def _unconstrained_vector_to_constrained_rvs(model):
    constrained_rvs, unconstrained_vector = join_nonshared_inputs(
-        model.initial_point(), inputs=model.value_vars, outputs=model.unobserved_value_vars
+        model.initial_point(),
+        inputs=model.value_vars,
+        outputs=get_default_varnames(model.unobserved_value_vars, include_transformed=False),
    )

    unconstrained_vector.name = "unconstrained_vector"
@@ -289,247 +294,6 @@ def scipy_optimize_funcs_from_loss(
    return f_loss, f_hess, f_hessp


def fit_mvn_to_MAP(
    optimized_point: dict[str, np.ndarray],
    model: pm.Model,
    on_bad_cov: Literal["warn", "error", "ignore"] = "ignore",
    transform_samples: bool = True,
    use_jax_gradients: bool = False,
    zero_tol: float = 1e-8,
    diag_jitter: float | None = 1e-8,
    compile_kwargs: dict | None = None,
) -> tuple[RaveledVars, np.ndarray]:
    """
    Create a multivariate normal distribution using the inverse of the negative Hessian matrix of the log-posterior
    evaluated at the MAP estimate. This is the basis of the Laplace approximation.

    Parameters
    ----------
    optimized_point : dict[str, np.ndarray]
        Local maximum a posteriori (MAP) point returned from pymc.find_MAP or jax_tools.fit_map
    model : Model
        A PyMC model
    on_bad_cov : str, one of 'ignore', 'warn', or 'error', default: 'ignore'
        What to do when ``H_inv`` (inverse Hessian) is not positive semi-definite.
        If 'ignore' or 'warn', the closest positive-semi-definite matrix to ``H_inv`` (in L1 norm) will be returned.
        If 'error', an error will be raised.
    transform_samples : bool
        Whether to transform the samples back to the original parameter space. Default is True.
    use_jax_gradients : bool
        Whether to use JAX to compute the loss and Hessian functions. Default is False.
    zero_tol : float
        Value below which an element of the Hessian matrix is counted as 0.
        This is used to stabilize the computation of the inverse Hessian matrix. Default is 1e-8.
    diag_jitter : float | None
        A small value added to the diagonal of the inverse Hessian matrix to ensure it is positive semi-definite.
        If None, no jitter is added. Default is 1e-8.
    compile_kwargs : dict, optional
        Additional keyword arguments passed to pytensor when compiling the loss and Hessian functions.

    Returns
    -------
    map_estimate: RaveledVars
        The MAP estimate of the model parameters, raveled into a 1D array.

    inverse_hessian: np.ndarray
        The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate.
    """
    compile_kwargs = {} if compile_kwargs is None else compile_kwargs
    frozen_model = freeze_dims_and_data(model)

    if not transform_samples:
        untransformed_model = remove_value_transforms(frozen_model)
        logp = untransformed_model.logp(jacobian=False)
        variables = untransformed_model.continuous_value_vars
    else:
        logp = frozen_model.logp(jacobian=True)
        variables = frozen_model.continuous_value_vars

    variable_names = {var.name for var in variables}
    optimized_free_params = {k: v for k, v in optimized_point.items() if k in variable_names}
    mu = DictToArrayBijection.map(optimized_free_params)

    _, f_hess, _ = scipy_optimize_funcs_from_loss(
        loss=-logp,
        inputs=variables,
        initial_point_dict=frozen_model.initial_point(),
        use_grad=True,
        use_hess=True,
        use_hessp=False,
        use_jax_gradients=use_jax_gradients,
        compile_kwargs=compile_kwargs,
    )

    H = -f_hess(mu.data)
    H_inv = np.linalg.pinv(np.where(np.abs(H) < zero_tol, 0, -H))

    def stabilize(x, jitter):
        return x + np.eye(x.shape[0]) * jitter

    H_inv = H_inv if diag_jitter is None else stabilize(H_inv, diag_jitter)

    try:
        np.linalg.cholesky(H_inv)
    except np.linalg.LinAlgError:
        if on_bad_cov == "error":
            raise np.linalg.LinAlgError(
                "Inverse Hessian not positive-semi definite at the provided point"
            )
        H_inv = get_nearest_psd(H_inv)
        if on_bad_cov == "warn":
            _log.warning(
                "Inverse Hessian is not positive semi-definite at the provided point, using the closest PSD "
                "matrix in L1-norm instead"
            )

    return mu, H_inv
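For reference, the construction above is the standard Laplace approximation: writing $\hat{\theta}$ for the MAP estimate and $H$ for the Hessian of the negative log-posterior at $\hat{\theta}$,

$$
p(\theta \mid y) \;\approx\; \mathcal{N}\!\left(\hat{\theta},\; H^{-1}\right),
\qquad
H = -\nabla_{\theta}^{2} \log p(\theta \mid y)\,\Big|_{\theta = \hat{\theta}}.
$$

The `zero_tol` and `diag_jitter` arguments exist because the numerically computed $H^{-1}$ can fail to be positive semi-definite even when the true posterior curvature is fine.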


def laplace(
    mu: RaveledVars,
    H_inv: np.ndarray,
    model: pm.Model,
    chains: int = 2,
    draws: int = 500,
    transform_samples: bool = True,
    progressbar: bool = True,
    **compile_kwargs,
) -> az.InferenceData:
    """
    Draw samples from a multivariate normal (Laplace) approximation to the posterior.

    Parameters
    ----------
    mu : RaveledVars
        The MAP estimate of the model parameters, raveled into a 1D array.
    H_inv : np.ndarray
        The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate, used as the covariance
        of the approximating normal distribution.
    model : Model
        A PyMC model
    chains : int
        The number of sampling chains running in parallel. Default is 2.
    draws : int
        The number of samples to draw from the approximated posterior. Default is 500.
    transform_samples : bool
        Whether to transform the samples back to the original parameter space. Default is True.

    Returns
    -------
    idata: az.InferenceData
        An InferenceData object containing the approximated posterior samples.
    """
    posterior_dist = stats.multivariate_normal(mean=mu.data, cov=H_inv, allow_singular=True)
    posterior_draws = posterior_dist.rvs(size=(chains, draws))

    if transform_samples:
        constrained_rvs, unconstrained_vector = _unconstrained_vector_to_constrained_rvs(model)
        batched_values = pt.tensor(
            "batched_values",
            shape=(chains, draws, *unconstrained_vector.type.shape),
            dtype=unconstrained_vector.type.dtype,
        )
        batched_rvs = pytensor.graph.vectorize_graph(
            constrained_rvs, replace={unconstrained_vector: batched_values}
        )

        f_constrain = pm.compile_pymc(
            inputs=[batched_values], outputs=batched_rvs, **compile_kwargs
        )
        posterior_draws = f_constrain(posterior_draws)

    else:
        info = mu.point_map_info
        flat_shapes = [np.prod(shape).astype(int) for _, shape, _ in info]
        slices = [
            slice(sum(flat_shapes[:i]), sum(flat_shapes[: i + 1])) for i in range(len(flat_shapes))
        ]

        posterior_draws = [
            posterior_draws[..., idx].reshape((chains, draws, *shape)).astype(dtype)
            for idx, (name, shape, dtype) in zip(slices, info)
        ]

    def make_rv_coords(name):
        coords = {"chain": range(chains), "draw": range(draws)}
        extra_dims = model.named_vars_to_dims.get(name)
        if extra_dims is None:
            return coords
        return coords | {dim: list(model.coords[dim]) for dim in extra_dims}

    def make_rv_dims(name):
        dims = ["chain", "draw"]
        extra_dims = model.named_vars_to_dims.get(name)
        if extra_dims is None:
            return dims
        return dims + list(extra_dims)

    idata = {
        name: xr.DataArray(
            data=draws.squeeze(),
            coords=make_rv_coords(name),
            dims=make_rv_dims(name),
            name=name,
        )
        for (name, _, _), draws in zip(mu.point_map_info, posterior_draws)
    }

    coords, dims = coords_and_dims_for_inferencedata(model)
    idata = az.convert_to_inference_data(idata, coords=coords, dims=dims)

    if model.deterministics:
        idata.posterior = pm.compute_deterministics(
            idata.posterior,
            model=model,
            merge_dataset=True,
            progressbar=progressbar,
            compile_kwargs=compile_kwargs,
        )

    observed_data = dict_to_dataset(
        find_observations(model),
        library=pm,
        coords=coords,
        dims=dims,
        default_dims=[],
    )

    constant_data = dict_to_dataset(
        find_constants(model),
        library=pm,
        coords=coords,
        dims=dims,
        default_dims=[],
    )

    idata.add_groups(
        {"observed_data": observed_data, "constant_data": constant_data},
        coords=coords,
        dims=dims,
    )

    return idata
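The `transform_samples` branch above leans on `pytensor.graph.vectorize_graph` to broadcast the unconstrained-to-constrained transform over a `(chains, draws)` batch of draws. A standalone sketch of that trick, with an illustrative `exp` standing in for the model's actual constraining graph:

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

from pytensor.graph import vectorize_graph

# Stand-in "constraining" graph: one unconstrained vector in, one constrained
# vector out. In laplace() this is the unconstrained_vector -> constrained_rvs
# graph built by _unconstrained_vector_to_constrained_rvs.
x = pt.vector("x", shape=(3,))
y = pt.exp(x)

# Swap the input for a (chains, draws, n) batch; vectorize_graph rewrites the
# graph so the same transform broadcasts over the two leading dimensions.
batched_x = pt.tensor("batched_x", shape=(2, 500, 3))
batched_y = vectorize_graph(y, replace={x: batched_x})

f = pytensor.function([batched_x], batched_y)
assert f(np.zeros((2, 500, 3))).shape == (2, 500, 3)
```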


def fit_laplace(
    optimized_point: dict[str, np.ndarray],
    model: pm.Model,
    chains: int = 2,
    draws: int = 500,
    on_bad_cov: Literal["warn", "error", "ignore"] = "ignore",
    transform_samples: bool = True,
    zero_tol: float = 1e-8,
    diag_jitter: float | None = 1e-8,
    progressbar: bool = True,
    compile_kwargs: dict | None = None,
) -> az.InferenceData:
    compile_kwargs = {} if compile_kwargs is None else compile_kwargs

    mu, H_inv = fit_mvn_to_MAP(
        optimized_point=optimized_point,
        model=model,
        on_bad_cov=on_bad_cov,
        transform_samples=transform_samples,
        zero_tol=zero_tol,
        diag_jitter=diag_jitter,
        compile_kwargs=compile_kwargs,
    )

    return laplace(mu, H_inv, model, chains, draws, transform_samples, progressbar)
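For orientation, a minimal sketch of calling this wrapper as defined in the pre-reconciliation `find_map.py` shown above (the version this commit removes). The model is illustrative, and `pm.find_MAP` supplies the `optimized_point` dict, as the `fit_mvn_to_MAP` docstring suggests:

```python
import numpy as np
import pymc as pm

from pymc_experimental.inference.find_map import fit_laplace

with pm.Model() as m:
    mu = pm.Normal("mu", 0.0, 1.0)
    sigma = pm.HalfNormal("sigma", 1.0)
    pm.Normal("y_obs", mu, sigma, observed=np.random.normal(loc=1.0, size=100))

    # Any MAP point keyed by value-variable names works here.
    optimized_point = pm.find_MAP()

# Fits N(MAP, H^{-1}) via fit_mvn_to_MAP, then samples it via laplace().
idata = fit_laplace(optimized_point, model=m, chains=2, draws=500)
```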


def find_MAP(
    method: minimize_method,
    *,
@@ -605,9 +369,12 @@ def find_MAP(
    initial_params = DictToArrayBijection.map(
        {var_name: value for var_name, value in start_dict.items() if var_name in vars_dict}
    )
+    use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
+        method, use_grad, use_hess, use_hessp
+    )

    f_logp, f_hess, f_hessp = scipy_optimize_funcs_from_loss(
-        loss=-frozen_model.logp(),
+        loss=-frozen_model.logp(jacobian=False),
        inputs=frozen_model.continuous_value_vars + frozen_model.discrete_value_vars,
        initial_point_dict=start_dict,
        use_grad=use_grad,
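A note on the `jacobian=False` change above: `model.logp()` defaults to `jacobian=True`, which adds the change-of-variables term for transformed free variables. For a value variable $z$ with backward transform $\theta = f(z)$, that objective is

$$
\log p\!\left(f(z) \mid y\right) + \log\left|\det J_{f}(z)\right|,
$$

whose maximizer is the mode of the transformed density. Dropping the Jacobian term, as this commit does, makes the optimizer find the point where the density of the original parameters $\theta$ is highest, the conventional MAP estimate.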
4 changes: 2 additions & 2 deletions pymc_experimental/inference/fit.py
@@ -39,6 +39,6 @@ def fit(method, **kwargs):
        return fit_pathfinder(**kwargs)

    if method == "laplace":
-        from pymc_experimental.inference.laplace import laplace
+        from pymc_experimental.inference.laplace import fit_laplace

-        return laplace(**kwargs)
+        return fit_laplace(**kwargs)
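A sketch of the dispatch after this change: `fit` forwards all keyword arguments to `fit_laplace`, whose reconciled signature lives in `pymc_experimental/inference/laplace.py` (not shown in this diff), so the keywords below are assumptions based on the functions above:

```python
import numpy as np
import pymc as pm
import pymc_experimental as pmx

with pm.Model() as m:
    x = pm.Normal("x", 0.0, 1.0)
    pm.Normal("obs", x, 1.0, observed=np.zeros(10))

# Equivalent to calling fit_laplace(**kwargs) directly; the kwargs must match
# fit_laplace's signature in pymc_experimental/inference/laplace.py.
idata = pmx.fit(method="laplace", model=m, chains=2, draws=500)
```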