truncated(d, l, Inf) fails with AD #1910

Open
penelopeysm opened this issue Nov 5, 2024 · 3 comments

Comments

@penelopeysm

Using `truncated` with ±Inf as the bounds tends to lead to NaNs when using automatic differentiation:

```julia
julia> using Distributions; f(s) = logpdf(truncated(Normal(0.0, s[1]), 0, +Inf), 2)
f (generic function with 1 method)

julia> import ForwardDiff; ForwardDiff.gradient(f, [0.1])
1-element Vector{Float64}:
 NaN
```

Returning NaNs isn't so much a problem on its own; the issue is more that it creates an easy but non-obvious trap for users to fall into, e.g. in #1189, and I've also seen it come up elsewhere in the Turing.jl ecosystem.
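A plausible mechanism for the NaN (an assumption; this has not been traced through Distributions' actual code path, which works on the log scale): the normalizing term of the truncated distribution involves the normal CDF at the standardized bound z = (u - μ)/σ. With u = Inf, the forward-mode chain rule multiplies the density φ(z) = 0 by dz/dσ = -Inf, and 0 · Inf is NaN in IEEE arithmetic. The toy sketch below reproduces only that single step in plain Julia; `Dual`, `phi`, and `dcdf_dsigma` are hypothetical names, not ForwardDiff or Distributions API.

```julia
# Toy forward-mode dual number for illustration only; this is NOT how
# ForwardDiff.jl is implemented, and `Dual` here is a hypothetical type.
struct Dual
    val::Float64  # primal value
    der::Float64  # derivative with respect to sigma
end

# Division of a constant by a dual: d(x / s) = -x * ds / s^2
Base.:/(x::Float64, s::Dual) = Dual(x / s.val, -x * s.der / s.val^2)

# Standard normal density; the chain rule for the CDF is dPhi(z) = phi(z) * dz
phi(x::Float64) = exp(-x^2 / 2) / sqrt(2pi)

# Derivative of Phi((u - 0) / sigma) with respect to sigma, at sigma = 0.1
function dcdf_dsigma(u::Float64)
    sigma = Dual(0.1, 1.0)   # seed: dsigma/dsigma = 1
    z = u / sigma            # standardized bound, assuming mu = 0
    return phi(z.val) * z.der
end

dcdf_dsigma(Inf)    # phi(Inf) == 0.0 and dz == -Inf, so 0.0 * -Inf gives NaN
dcdf_dsigma(10.0)   # phi(100.0) underflows to 0.0 and dz is about -1000, giving -0.0
```

A large but finite bound gives a harmless zero, while an infinite one turns the same product into NaN, which then propagates through the rest of the gradient. Passing `nothing` instead means that term is never computed at all.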

A cheap fix might be:

```diff
diff --git a/src/truncate.jl b/src/truncate.jl
index 48d62b01..bf8d379e 100644
--- a/src/truncate.jl
+++ b/src/truncate.jl
@@ -62,6 +62,8 @@ end
 truncated(d::UnivariateDistribution, ::Nothing, ::Nothing) = d
 function truncated(d::UnivariateDistribution, l::T, u::T) where {T <: Real}
     l <= u || error("the lower bound must be less or equal than the upper bound")
+    l == -Inf && return truncated(d, nothing, u)
+    u == Inf && return truncated(d, l, nothing)
 
     # (log)lcdf = (log) P(X < l) where X ~ d
     loglcdf = _logcdf_noninclusive(d, l)
```

I recognise that in principle it isn't really the job of Distributions to fix this, but the patch is so small that it shouldn't be a maintenance burden, so I figured it was worth suggesting 🙂

@devmotion
Member

This is a known issue, and this change has been suggested before, e.g. in #1730 (also related: #1467). I think the resulting type instability would be quite an unfortunate consequence of such a change: the return type of `truncated` would then depend on the values of the bounds rather than only on their types. An alternative that avoids this problem might be to throw an exception or show a warning when a non-`nothing` bound is equal to an endpoint of the support.
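To make the type-instability concern concrete, here is the general Julia pattern in isolation. The names `upper_or_nothing`, `Bounds`, and `make_bounds` are hypothetical and only loosely mirror how a truncated distribution records the types of its bounds; they are not Distributions' real types. If a function returns `nothing` for some `Float64` inputs and a `Float64` for others, inference can only produce a `Union`, and anything built from that result is no longer concretely typed.

```julia
# Hypothetical stand-in for the proposed patch: an infinite upper bound is
# replaced by `nothing`, so the result type depends on the *value* of `u`.
upper_or_nothing(u::Float64) = u == Inf ? nothing : u

# Hypothetical container whose type parameters record the bound types.
struct Bounds{L,U}
    lower::L
    upper::U
end

make_bounds(l::Float64, u::Float64) = Bounds(l, upper_or_nothing(u))

typeof(make_bounds(0.0, 1.0))   # Bounds{Float64, Float64}
typeof(make_bounds(0.0, Inf))   # Bounds{Float64, Nothing}

# From the argument types alone, inference cannot pick one concrete return type.
only(Base.return_types(upper_or_nothing, (Float64,)))                     # Union{Nothing, Float64}
isconcretetype(only(Base.return_types(make_bounds, (Float64, Float64))))  # false
```

Small unions like this are often handled reasonably well by the compiler, but the instability still propagates into every downstream computation that uses the constructed object.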

@penelopeysm
Author

Hmm, I clearly didn't dig back far enough when searching the issues. Agreed on the importance of type stability, especially given that this isn't an issue with Distributions itself.

Even a warning feels like a somewhat suboptimal solution, though: I assume truncated distributions with ±Inf bounds work perfectly fine within Distributions itself, so the warning would be noise to anyone who was only doing that. It's only when they're used together with other packages (such as AD backends) that one might want to be warned.
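For concreteness, the check suggested earlier in the thread (warn when a non-`nothing` bound equals an endpoint of the support) could look roughly like this. It is a hypothetical helper, not part of Distributions; in practice the support endpoints would presumably come from `minimum(d)` and `maximum(d)`, but here they are passed in explicitly so the sketch runs with Base Julia alone.

```julia
# Hypothetical helper: warn when a bound that is not `nothing` coincides with
# an endpoint of the untruncated support. Returns whether any warning was issued.
function warn_trivial_bounds(lo, hi, support_lo, support_hi)
    warned = false
    if lo !== nothing && lo == support_lo
        # maxlog=1 limits this message to one emission per call site
        @warn "Lower bound equals the minimum of the support; consider `nothing` instead." lo maxlog=1
        warned = true
    end
    if hi !== nothing && hi == support_hi
        @warn "Upper bound equals the maximum of the support; consider `nothing` instead." hi maxlog=1
        warned = true
    end
    return warned
end

warn_trivial_bounds(0.0, Inf, -Inf, Inf)      # warns about the upper bound
warn_trivial_bounds(0.0, nothing, -Inf, Inf)  # no warning
```

Using `maxlog=1` keeps repeated calls (for example inside a sampler loop) from flooding the log, which softens, but does not remove, the noise concern raised above.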

@devmotion
Member

It's generally preferable to use `nothing` rather than an endpoint of the untruncated distribution's support. The former allows the calculations with, and possibly even the return type of, `truncated` to be optimized.
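On the user side, this means the example from the original report can be written with `nothing` for the unbounded side. A minimal sketch, assuming Distributions.jl and ForwardDiff.jl are installed (`g` is just a hypothetical name for the rewritten function):

```julia
using Distributions
import ForwardDiff

# Same model as in the report, but with `nothing` as the upper bound instead of +Inf.
g(s) = logpdf(truncated(Normal(0.0, s[1]), 0.0, nothing), 2.0)

ForwardDiff.gradient(g, [0.1])  # finite (about 3990.0) instead of NaN
```

The expected value comes from differentiating by hand: the untruncated log-density contributes d/dσ[-x²/(2σ²) - log σ] = x²/σ³ - 1/σ = 4000 - 10 = 3990 at x = 2, σ = 0.1, while the truncation term for a lower bound at the mean, log P(X ≥ 0) = log(1/2), does not depend on σ.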

Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants