Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Laplace (quadratic) approximation #345

Merged
merged 21 commits into from
Jul 1, 2024

Conversation

carsten-j
Copy link
Contributor

This is an early version of q quadratic approximation implementation that I have developed while reading Statistical Rethinking by Richard McElreath.

There is a short discussion about this in the issue and maybe @theorashid can help with feedback of this draft PR.

This work is partly based on the Python package pymc3-quap but pymc3-quap is based on PYMC3 and a lot happend bewteen version 3 and 5 of PYMC. Optimizers works better when provided with a good initial guess and hence a (optional) starting point has been added to function arguments. Please see Github for a discussion about the differences between PYMC version 3 and 5 for computing the Hessian.

@carsten-j
Copy link
Contributor Author

I am looking for the best way to return not just a posterior sample distribution but also the mean vector and covariance matrix of the Gaussian distribution. Any suggestion for this. So far my only idea is to add another section to the inferenceData returned containing this information. Thoughts on this?

@zaxtax
Copy link
Contributor

zaxtax commented Jun 2, 2024 via email

@aloctavodia
Copy link
Member

Historically, Inferencedata has been focused on mcmc. But we have discussed a few times extend it to better handle other inference methods, like SMC or variational methods. It just that there has not been enough momentum to agree and implement and schema that works for those methods.

@carsten-j
Copy link
Contributor Author

@zaxtax and @aloctavodia are you saying that I should not return inferencedata at all or just not return the gaussian mean and covariance in the inferencedata object? I am new to both PYMC and Bayesian statistics so I do not know the history of this package.
Best, Carsten

@zaxtax
Copy link
Contributor

zaxtax commented Jun 2, 2024 via email

@twiecki
Copy link
Member

twiecki commented Jun 3, 2024

CC @ferrine

@aloctavodia
Copy link
Member

Oh, it's more that we haven't decided how to handle this within the library. Don't treat this as a blocker, though we should raise it for discussion more broadly

exactly, just saying that if necessary InferenceData can be extended.

@ricardoV94
Copy link
Member

ricardoV94 commented Jun 5, 2024

Suggestion, include two groups in the returned inferencedata:

  1. A fit group that includes the mean and covariance of the laplace fit, and
  2. a posterior group that includes draws from this fit, and has all the bells and whistles like dimensions, deterministics, etc... Include the default extra groups like observed and constant data. This will look just like a fit from mcmc sampling. This can be disabled by the user by setting draws = 0

We could even try different fits from distinct initialization points (optionally) and save those as distinct "chains" in the fit and corresponding posterior groups. Although usually multiple initialization are used with the goal of finding the best fit, they could still be useful to detect multi-modality / pervasiveness of local optima.

@ricardoV94
Copy link
Member

@carsten-j PR looks great! I left some comment above

@ricardoV94 ricardoV94 added the enhancements New feature or request label Jun 5, 2024
@carsten-j
Copy link
Contributor Author

Thanks you @ricardoV94 and @twiecki for the review comments. I believe that all of them expect one has been fixed. I have not figured out how to use remove_value_transforms. I tried to browse through PYMC source code but that did not really help.

@ricardoV94
Copy link
Member

Thanks you @ricardoV94 and @twiecki for the review comments. I believe that all of them expect one has been fixed. I have not figured out how to use remove_value_transforms. I tried to browse through PYMC source code but that did not really help.

The docs contains code example: https://www.pymc.io/projects/docs/en/stable/api/model/generated/pymc.model.transform.conditioning.remove_value_transforms.html

@carsten-j
Copy link
Contributor Author

I should have mentioned that I did read the doc and looked at the example. But I have not been able to figure out how to apply it to my case. I will try again ...

@ricardoV94
Copy link
Member

To be able to use it inside the model context, it will need this change to get merged first: pymc-devs/pymc#7352

But you should be able to already test by doing the object way with pm.fit(..., model=model) outside of the model context

@carsten-j
Copy link
Contributor Author

@ricardoV94 I figured out how to replace the for loop with remove_value_transforms. Is the PR ready for merge or are there additional review comments?

logsigma = pm.Uniform("logsigma", 1, 100)
mu = pm.Uniform("mu", -10000, 10000)
yobs = pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y)
vars = [mu, logsigma]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question: say you only did vars=[mu], how would the variable logsigma be estimated?

Copy link
Member

@ricardoV94 ricardoV94 Jun 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think find_MAP in that case uses the initial_point for the excluded variable(s). I never found that behavior useful tbh

Edit: Maybe it's fine. Either way it's documented here: https://github.com/pymc-devs/pymc/blob/05b557f6460a10c29c3db33690ee535f5b1ecde0/pymc/tuning/starting.py#L73-L75

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds like we may want to pass a similar start kwarg to laplace to set the value of variables that are not being optimized?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

worth adding a test on this to confirm the behaviour

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure I fully understand this. I will give it a second go with the documentation for find_MAP.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi Carsten, is there anything we can do to help get this over the line?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @theorashid. I am not sure how to handle if only a subset of the variables are passed in, e.g. vars=[mu] and log_sigma is left out. If this should raise a warning I need some way of figuring out the number of model parameters and compare that with the number of parameters in vars. I am not sure how to determine the number of model parameters

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

model.free_RVs

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@theorashid and @ricardoV94, I committed an update that will raise a warning in case number of variables in vars does not equal number of model variables.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@ricardoV94 ricardoV94 changed the title First draft of quadratic approximation Implement Laplace (quadratic) approximation Jun 29, 2024
@ricardoV94
Copy link
Member

@carsten-j tests are no longer failing in main. You can rebase/merge into your branch

ricardoV94 and others added 6 commits June 30, 2024 21:33
* Allow forward sampling of statespace models in JAX mode

Explicitly set data shape to avoid broadcasting error

Better handling of measurement error dims in `SARIMAX` models

Freeze auxiliary models before forward sampling

Bugfixes for posterior predictive sampling helpers

Allow specification of time dimension name when registering data

Save info about exogenous data for post-estimation tasks

Restore `_exog_data_info` member variable

Be more consistent with the names of filter outputs

* Adjust test suite to reflect API changes

Modify structural tests to accommodate deterministic models

Save kalman filter outputs to idata for statespace tests

Remove test related to `add_exogenous`

Adjust structural module tests

* Add JAX test suite

* Bug-fixes and changes to statespace distributions

Remove tests related to the `add_exogenous` method

Add dummy `MvNormalSVDRV` for forward jax sampling with `method="SVD"`

Dynamically generate `LinearGaussianStateSpaceRV` signature from inputs

Add signature and simple test for `SequenceMvNormal`

* Re-run example notebooks

* Add helper function to sample prior/posterior statespace matrices

* fix tests

* Wrap jax MvNormal rewrite in try/except block

* Don't use `action` keyword in `catch_warnings`

* Skip JAX test if `numpyro` is not installed

* Handle batch dims on `SequenceMvNormal`

* Remove unused batch_dim logic in SequenceMvNormal

* Restore `get_support_shape_1d` import
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@carsten-j
Copy link
Contributor Author

@ricardoV94 I have rebased the laplace branch but it looks like someone needs to approve Github worksflows.

@zaxtax
Copy link
Contributor

zaxtax commented Jul 1, 2024

Looks like there are still a few failing tests, but once those pass this is probably good to merge

@carsten-j
Copy link
Contributor Author

@zaxtax failing test has been fixed. Can you approve the waiting workflow?

@carsten-j
Copy link
Contributor Author

@zaxtax, all tests passed. Are you also able to merge the PR? Thanks.

@twiecki twiecki merged commit 87d4aea into pymc-devs:main Jul 1, 2024
8 checks passed
@twiecki
Copy link
Member

twiecki commented Jul 1, 2024

Congrats @carsten-j, this is a big one!

@carsten-j
Copy link
Contributor Author

Thank you @twiecki. Really happy to contribute and thanks to all those that helped. After the summer I will try to work on documentation for building and running locally. I took me some time to figure out how this works!

@zaxtax
Copy link
Contributor

zaxtax commented Jul 1, 2024

Congrats @carsten-j this is really neat!

@theorashid
Copy link
Contributor

Brilliant work @carsten-j . Hope to see you contribute to PyMC again!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancements New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants