Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add
find_MAP
with close JAX integration and fix bug with Laplace fit (
#385) * Add JAX-based `find_MAP` * add `better_optimize` to CI envs * Fix relative import * Remove `find_MAP` import from module-level `__init__.py` * Update docstring * Allow calling `find_MAP` inside model context without model argument * Required patched better_optimize * in-progress refactor * More refactor * Generalize code to use any pytensor backend * Reconcile the two laplace approximation functions * Use absolute import in doctest * Fix imports * Fix unrelated statespace test * - Rename argument `use_jax_gradients` -> `gradient_backend` - Rename function `laplace` -> `sample_laplace_posterior` * Fix typo introduced by rename refactor * use `mode=FAST_COMPILE` to get `unobserved_value_vars` after MAP optimization * Rename `test_jax_find_map.py` -> `test_find_map.py` * Improve docstring for `fit_laplace` * Update tests to match new signature * Update docstring
- Loading branch information