Skip to content

Commit

Permalink
Add find_MAP with close JAX integration and fix bug with Laplace fit (
Browse files Browse the repository at this point in the history
#385)

* Add JAX-based `find_MAP`

* add `better_optimize` to CI envs

* Fix relative import

* Remove `find_MAP` import from module-level `__init__.py`

* Update docstring

* Allow calling `find_MAP` inside model context without model argument

* Required patched better_optimize

* in-progress refactor

* More refactor

* Generalize code to use any pytensor backend

* Reconcile the two laplace approximation functions

* Use absolute import in doctest

* Fix imports

* Fix unrelated statespace test

* - Rename argument `use_jax_gradients` -> `gradient_backend`
- Rename function `laplace` -> `sample_laplace_posterior`

* Fix typo introduced by rename refactor

* use `mode=FAST_COMPILE` to get `unobserved_value_vars` after MAP optimization

* Rename `test_jax_find_map.py` -> `test_find_map.py`

* Improve docstring for `fit_laplace`

* Update tests to match new signature

* Update docstring
  • Loading branch information
jessegrabowski authored Dec 4, 2024
1 parent 40714de commit 5055262
Show file tree
Hide file tree
Showing 8 changed files with 1,178 additions and 163 deletions.
1 change: 1 addition & 0 deletions conda-envs/environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ dependencies:
- pymc>=5.17.0 # CI was failing to resolve
- blackjax
- scikit-learn
- better_optimize>=0.0.10
1 change: 1 addition & 0 deletions conda-envs/windows-environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ dependencies:
- pymc>=5.17.0 # CI was failing to resolve
- blackjax
- scikit-learn
- better_optimize>=0.0.10
Loading

0 comments on commit 5055262

Please sign in to comment.