-
Notifications
You must be signed in to change notification settings - Fork 146
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add FloatNN(::Dual)
#538
base: master
Are you sure you want to change the base?
Add FloatNN(::Dual)
#538
Conversation
Codecov Report
@@ Coverage Diff @@
## master #538 +/- ##
==========================================
+ Coverage 84.83% 84.85% +0.01%
==========================================
Files 9 9
Lines 831 832 +1
==========================================
+ Hits 705 706 +1
Misses 126 126
Continue to review full report at Codecov.
|
Why does Zygote need this? |
It's the diagonal hessian: Error During Test at /Users/mzgubic/JuliaEnvs/Zygote.jl/test/utils.jl:22
Got exception outside of a @test
MethodError: no method matching Float64(::ForwardDiff.Dual{Nothing, Float64, 6})
Closest candidates are:
(::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat at rounding.jl:200
(::Type{T})(::T) where T<:Number at boot.jl:760
(::Type{T})(::AbstractChar) where T<:Union{AbstractChar, Number} at char.jl:50
...
Stacktrace:
[1] convert(#unused#::Type{Float64}, x::ForwardDiff.Dual{Nothing, Float64, 6})
@ Base ./number.jl:7
[2] (::ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}})(dx::ForwardDiff.Dual{Nothing, Float64, 6})
@ ChainRulesCore ~/JuliaEnvs/Zygote.jl/dev/ChainRulesCore/src/projection.jl:144
[3] ^_pullback
@ ~/JuliaEnvs/Zygote.jl/dev/ChainRules/src/rulesets/Base/fastmath_able.jl:172 [inlined] |
I suspect that the correct thing to do here is not to drop the partial part. I think the correct thing to do is to end up with more dual numbers. julia> d = ForwardDiff.Dual(1,2)
Dual{Nothing}(1,2)
julia> AbstractFloat(d)
Dual{Nothing}(1.0,2.0) So matching the behavour of |
I am now more convinced this is correct. It is kinda gross that to do so it has to overload the constructor to return the |
@YingboMa raised the good questions about how many invalidations this would cause. But checking with SnoopCompile (and remembering to
It seems |
This needs a test for when this is a problem for ForwardDiff. What bug does this fix? The test just verifies that the implementation does what the implementation does but it lacks a motivating example. |
Done. |
Late to the party, but this seems a bit odd. If I understand right what's happening is this:
Could it be fixed elsewhere? The reason Or be done through an intermediate function which means "convert the numeric type to Float64, but preserve the Dual / units / etc"? Second, this PR doesn't seem to achieve that, since the FloatN constructor doesn't always change the base numeric type. And it leads to weird error messages:
|
Right now it has some promotion stuff in there.
That error message seems fine to me.
Could be a thing. I do still think this PR makes perfectly fine sense though. From the perspective of trying to do forwards mode AD on some black box ForwardDiff is failing to AD it, because right now it doesn't know what to do when it sees a calm the |
If Crossed out above, but #508 recently made I remain a little concerned that making
Yes, I have no idea where would be the right place for such a "convert the numbers inside this" to live, if it isn't overloaded on the constructors. I do still think the idea of making
My complaint is that a failed typeassert isn't a "you the user gave be bad input" message, it's a "my understanding of what Julia can legally produce here is mistaken" message. Independent of this PR, it's possible that it should say something more helpful. This appears to be the leading way to use ForwardDiff wrong, so possibly some friendly message (explaining that your buffer needs a wide enough type) would be a good idea. But the basic fact that you cannot convert Duals to Floats because this would silently lose derivatives is essential to understanding how this package works. I don't think the non-conversion in this PR can in fact silently lose information; it seems super-important to be sure of that. |
Duals are not numbers. But anyway, that's why it's constructors don't return the right types.
Yeah, I will check tomorrow to see if removing promotion is a thing.
Out of scope for this package though.
That was me. MWE xs, y = randn(2,3), rand()
f34(xs, y) = xs[1] * (sum(xs .^ (1:3)') + y^4) # non-diagonal Hessian, two arguments
dx, dy = diaghessian(f34, xs, y)
@test size(dx) == size(xs)
@test vec(dx) ≈ diag(hessian(x -> f34(x,y), xs)) @test dy ≈ hessian(y -> f34(xs,y), y)
Indeed. |
What is special about Float64? Isn't what you saying that for any type |
Yes. My example above was Vec for SIMD, which is similarly a way of threading more information through code which was written for scalars. Another I just thought of is Measurements.jl, which FWIW does not allow such conversions: julia> λ = measurement(0, 0.1)
0.0 ± 0.1
julia> typeof(λ)
Measurement{Float64}
julia> Float32(λ)
ERROR: MethodError: no method matching Float32(::Measurement{Float64}) |
Wouldn't it be better to add |
Or better yet |
I think the definition of However, I am worried that using ForwardDiff
using Random
struct IsoNormal{V<:AbstractVector}
mu::V
end
function Random.rand!(rng::Random.AbstractRNG, x::AbstractVector, d::IsoNormal)
length(x) == length(d.mu) || throw(DimensionMismatch())
randn!(rng, x)
x .+= d.mu
return x
end
Base.rand(d::IsoNormal) = rand(Float64, d)
Base.rand(::Type{T}, d::IsoNormal) where {T} = Base.rand(Random.GLOBAL_RNG, T, d)
Base.rand(rng::Random.AbstractRNG, d::IsoNormal) = rand(rng, Float64, d)
function Base.rand(rng::AbstractRNG, ::Type{T}, d::IsoNormal) where {T}
return rand!(rng, Vector{T}(undef, length(d.mu)), d)
end
rand(IsoNormal(zeros(2))) # works, returns `Vector{Float64}`
rand(Float32, IsoNormal(zeros(2))) # works, returns `Vector{Float32}`
rand(typeof(ForwardDiff.Dual(0f0)), IsoNormal(zeros(2))) # works (but probably uncommon), returns `Vector{<:Dual}`
rand(IsoNormal([ForwardDiff.Dual(0.0) for _ in 1:2])) # fails since `Float64(d::Dual)` not defined or a `Dual` as in this PR
rand(Float32, IsoNormal([ForwardDiff.Dual(0.0) for _ in 1:2])) # fails since `Float64(d::Dual)` not defined or a `Dual` as in this PR
rand(typeof(ForwardDiff.Dual(0f0)), IsoNormal([ForwardDiff.Dual(0.0) for _ in 1:2])) # works as expected which came up in JuliaStats/Distributions.jl#1433 (it's not clear from this example but intentionally the default type |
Yes, if someone writes explicitly |
I just learnt some days ago that Tracker contains an even more general form of the approach in this PR (but, of course, it's a really nasty type piracy and it also drops the tag): FluxML/Tracker.jl#134 |
Needed for Zygote on ChainRules 1.0 (only the
Float64
case strictly speaking)Similar to the recently merged https://github.com//pull/508/files