icdf of beta distribution is not implemented #1365

Closed
ayaka14732 opened this issue Mar 15, 2022 · 7 comments · Fixed by #1478
Labels: enhancement (New feature or request), good first issue (Good for newcomers)

Comments

@ayaka14732

```python
>>> import numpyro.distributions as dist
>>> d = dist.Beta(1, 5)
>>> d.icdf(0.975)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/ayaka/venv/lib/python3.10/site-packages/numpyro/distributions/distribution.py", line 463, in icdf
    raise NotImplementedError
NotImplementedError
```

NumPyro version: 0.9.1

fehiepsi added the jax label (This issue is specific to JAX) on Mar 15, 2022
@fehiepsi
Member

fehiepsi commented Mar 15, 2022

Hi @ayaka14732, please subscribe to this upstream issue: jax-ml/jax#2399. Just curious, what is your use case? If you don't need differentiation, then we can use the scipy implementation for it.

@ayaka14732
Author

> Just curious, what is your use case? If you don't need differentiation, then we can use the scipy implementation for it.

Just for learning, so I can use the scipy implementation. But for the NumPyro library, I think we should wait for it to be added in JAX, so users won't experience unexpected performance degradation.
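For the non-differentiable case, a minimal sketch of the SciPy workaround applied to the `Beta(1, 5)` example from the report. It assumes plain Python/NumPy inputs only: `scipy.stats.beta.ppf` (SciPy's name for the inverse CDF) cannot be called on JAX tracers, so it will not work inside `jit` or under `grad`. The closed-form cross-check holds only because `a = 1`.

```python
from scipy import stats

# Beta(1, 5) from the report above; SciPy's ppf ("percent point function") is the icdf.
q = 0.975
x = stats.beta(1, 5).ppf(q)

# Cross-check: for Beta(1, b) the CDF is 1 - (1 - x)**b, so the icdf is 1 - (1 - q)**(1 / b).
closed_form = 1.0 - (1.0 - q) ** (1.0 / 5)
print(x, closed_form)
```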

@e-pet
Contributor

e-pet commented Jul 16, 2022

I'd like to use the beta icdf in a pyro.deterministic statement in a model guide. Do I understand correctly that that is not possible by simply using the scipy function because gradient information is required? (I tried doing so and got errors that seemed to indicate this.)

Any recommendations for what would currently be the best / easiest way to circumvent this problem? Can I somehow implement and provide the gradient information by myself?

@fehiepsi
Member

I guess if the beta icdf is only used at deterministic sites for prediction, then you can use host_callback as in the TruncatedDistribution tutorial. If it is used at some downstream observation sites, then we need gradient rules for it. Please see this tutorial for more information.
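If gradients with respect to the probability `q` are needed, one possible sketch pairs a host-side SciPy evaluation with a hand-written derivative. `beta_icdf` is a hypothetical helper, not NumPyro API. Assumptions: `a` and `b` are fixed Python floats treated as non-differentiable, the host call goes through `jax.pure_callback` (used here in place of `host_callback`), and the JVP rule comes from the inverse function theorem, `d/dq F^{-1}(q) = 1 / f(F^{-1}(q))` with `f` the Beta density.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.stats import beta as jbeta
from scipy import stats


@partial(jax.custom_jvp, nondiff_argnums=(1, 2))
def beta_icdf(q, a, b):
    """Hypothetical helper: Beta(a, b) icdf computed on the host by SciPy.

    Assumes a and b are fixed Python floats (no gradient is defined for them).
    """
    q = jnp.asarray(q, dtype=float)
    out_type = jax.ShapeDtypeStruct(q.shape, q.dtype)

    def host_ppf(q_host):
        # Runs outside JAX tracing; convert to NumPy before calling SciPy.
        q_np = np.asarray(q_host)
        return np.asarray(stats.beta.ppf(q_np, a, b), dtype=q_np.dtype)

    return jax.pure_callback(host_ppf, out_type, q)


@beta_icdf.defjvp
def beta_icdf_jvp(a, b, primals, tangents):
    (q,), (q_dot,) = primals, tangents
    x = beta_icdf(q, a, b)
    # Inverse function theorem: dF^{-1}/dq = 1 / pdf(F^{-1}(q)).
    return x, q_dot / jbeta.pdf(x, a, b)


# Usage: gradient of the Beta(1, 5) icdf at q = 0.975.
grad_at_975 = jax.grad(lambda q: beta_icdf(q, 1.0, 5.0))(0.975)
```

The rule gives no gradient for `a` or `b`, and `1 / pdf` becomes numerically unstable as `q` approaches 0 or 1, so this is only a starting point.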

@e-pet
Contributor

e-pet commented Jul 30, 2022

(In case somebody else finds this while looking for a solution: I opted to use the Kumaraswamy distribution instead, as recommended in the JAX issue linked above. It is very similar to, and in some cases identical to, the beta distribution, and it has the icdf implemented.)
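For reference, the Kumaraswamy(a, b) CDF is `F(x) = 1 - (1 - x**a)**b`, so its inverse has the closed form `F^{-1}(q) = (1 - (1 - q)**(1/b))**(1/a)`, which ordinary JAX autodiff can differentiate. With `a = 1` it coincides with Beta(1, b), and with `b = 1` it coincides with Beta(a, 1). A minimal plain-JAX sketch; `kumaraswamy_icdf` is a hypothetical helper, not NumPyro's `Kumaraswamy` class.

```python
import jax
import jax.numpy as jnp


def kumaraswamy_icdf(q, a, b):
    """Closed-form inverse CDF of Kumaraswamy(a, b), where F(x) = 1 - (1 - x**a)**b."""
    q = jnp.asarray(q, dtype=float)
    return (1.0 - (1.0 - q) ** (1.0 / b)) ** (1.0 / a)


# Kumaraswamy(1, 5) equals Beta(1, 5), so this reproduces the icdf from the original report,
# and jax.grad differentiates it with respect to q without any custom rule.
x = kumaraswamy_icdf(0.975, 1.0, 5.0)
dx_dq = jax.grad(kumaraswamy_icdf)(0.975, 1.0, 5.0)
```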

fehiepsi added the enhancement (New feature or request) and good first issue (Good for newcomers) labels and removed the jax (This issue is specific to JAX) label on Aug 29, 2022
@fehiepsi
Member

As suggested in jax-ml/jax#2399, we can use tfp.math.betaincinv in the implementation of this icdf.
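The Beta(a, b) CDF is the regularized incomplete beta function `I_x(a, b)`, so the icdf is its inverse in `x`, which is what `betaincinv` computes. To illustrate that relationship without a TensorFlow Probability dependency, the sketch below swaps in plain bisection on `jax.scipy.special.betainc`. It is not the TFP algorithm; `beta_icdf_bisect` and `num_iters` are hypothetical names, and `a`, `b` are assumed to be positive scalars.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import betainc


def beta_icdf_bisect(q, a, b, num_iters=50):
    """Hypothetical helper: solve I_x(a, b) = q for x in [0, 1] by bisection."""
    q = jnp.asarray(q, dtype=float)
    lo0 = jnp.zeros_like(q)
    hi0 = jnp.ones_like(q)

    def body(_, bounds):
        lo, hi = bounds
        mid = 0.5 * (lo + hi)
        # The Beta CDF increases monotonically in x, so keep the half that still brackets q.
        below = betainc(a, b, mid) < q
        return jnp.where(below, mid, lo), jnp.where(below, hi, mid)

    lo, hi = jax.lax.fori_loop(0, num_iters, body, (lo0, hi0))
    return 0.5 * (lo + hi)


x = beta_icdf_bisect(0.975, 1.0, 5.0)
```

Because `q` only enters through a comparison, `jax.grad` through this loop yields a zero gradient with respect to `q`, so a custom JVP rule (for example `1 / pdf(icdf(q))`) would still be needed for differentiable use.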

@fehiepsi
Member

fehiepsi commented Sep 6, 2022

This can be added in the same way as #1475.

3 participants