Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix method ambiguity errors with PDMats 0.10 #81

Merged
merged 5 commits into from
Aug 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 5 additions & 68 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,6 @@ version = "2.0.2"
[[Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

[[CEnum]]
git-tree-sha1 = "215a9aa4a1f23fbd05b92769fdd62559488d70e9"
uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
version = "0.4.1"

[[CommonSubexpressions]]
deps = ["MacroTools", "Test"]
git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7"
Expand All @@ -26,12 +21,6 @@ git-tree-sha1 = "7c4f882c41faa72118841185afc58a2eb00ef612"
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
version = "0.3.3+0"

[[CpuId]]
deps = ["Markdown", "Test"]
git-tree-sha1 = "f0464e499ab9973b43c20f8216d088b61fda80c6"
uuid = "adafc99b-e345-5852-983c-f28acb93d879"
version = "0.2.2"

[[Dates]]
deps = ["Printf"]
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
Expand All @@ -52,12 +41,6 @@ version = "1.0.1"
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

[[DocStringExtensions]]
deps = ["LibGit2", "Markdown", "Pkg", "Test"]
git-tree-sha1 = "c5714d9bcdba66389612dc4c47ed827c64112997"
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
version = "0.8.2"

[[ForwardDiff]]
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "NaNMath", "Random", "SpecialFunctions", "StaticArrays"]
git-tree-sha1 = "1d090099fb82223abc48f7ce176d3f7696ede36d"
Expand All @@ -68,12 +51,6 @@ version = "0.10.12"
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"

[[LLVM]]
deps = ["CEnum", "Libdl", "Printf", "Unicode"]
git-tree-sha1 = "a662366a5d485dee882077e8da3e1a95a86d097f"
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
version = "2.0.0"

[[LibGit2]]
deps = ["Printf"]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
Expand All @@ -88,12 +65,6 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
[[Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

[[LoopVectorization]]
deps = ["DocStringExtensions", "LinearAlgebra", "OffsetArrays", "SIMDPirates", "SLEEFPirates", "UnPack", "VectorizationBase"]
git-tree-sha1 = "b595e15d20e45d2eb36c6b4462d2a34143872a45"
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
version = "0.8.15"

[[MacroTools]]
deps = ["Markdown", "Random"]
git-tree-sha1 = "f7d2e3f654af75f01ec49be82c231c382214223a"
Expand All @@ -104,27 +75,16 @@ version = "0.5.5"
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"

[[NNPACK_jll]]
deps = ["Libdl", "Pkg"]
git-tree-sha1 = "c3d1a616362645754b18e12dbba96ec311b0867f"
uuid = "a6bfbf70-4841-5cb9-aa18-3a8ad3c413ee"
version = "2018.6.22+0"

[[NNlib]]
deps = ["Libdl", "LinearAlgebra", "LoopVectorization", "NNPACK_jll", "Pkg", "Requires", "Statistics"]
git-tree-sha1 = "f593bdb98b00a4f5b87cc2c18231b81433111590"
deps = ["Libdl", "LinearAlgebra", "Pkg", "Requires", "Statistics"]
git-tree-sha1 = "8ec4693a5422f0b064ce324f59351f24aa474893"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.7.2"
version = "0.7.4"

[[NaNMath]]
git-tree-sha1 = "928b8ca9b2791081dc71a51c55347c27c618760f"
git-tree-sha1 = "c84c576296d0e2fbb3fc134d3e09086b3ea617cd"
uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
version = "0.3.3"

[[OffsetArrays]]
git-tree-sha1 = "4ba4cd84c88df8340da1c3e2d8dcb9d18dd1b53b"
uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
version = "1.1.1"
version = "0.3.4"

[[OpenSpecFun_jll]]
deps = ["CompilerSupportLibraries_jll", "Libdl", "Pkg"]
Expand Down Expand Up @@ -157,18 +117,6 @@ version = "1.0.1"
[[SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"

[[SIMDPirates]]
deps = ["VectorizationBase"]
git-tree-sha1 = "dae629e96c1819d77882256e6cb29736f493bc30"
uuid = "21efa798-c60a-11e8-04d3-e1a92915a26a"
version = "0.8.13"

[[SLEEFPirates]]
deps = ["Libdl", "SIMDPirates", "VectorizationBase"]
git-tree-sha1 = "c750d618b7c8268a97e55c70e8c88e56080d30fa"
uuid = "476501e8-09a2-5ece-8869-fb82de89a1fa"
version = "0.5.4"

[[Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"

Expand Down Expand Up @@ -203,16 +151,5 @@ uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
deps = ["Random", "SHA"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

[[UnPack]]
git-tree-sha1 = "d4bfa022cd30df012700cf380af2141961bb3bfb"
uuid = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
version = "1.0.1"

[[Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

[[VectorizationBase]]
deps = ["CpuId", "LLVM", "Libdl", "LinearAlgebra"]
git-tree-sha1 = "bb72c58beab6c9e544851f5373fcd72f8f1f157a"
uuid = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
version = "0.12.21"
10 changes: 8 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Tracker"
uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
version = "0.2.10"
version = "0.2.11"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand All @@ -15,7 +15,6 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
Adapt = "1,2"
Expand All @@ -27,3 +26,10 @@ NaNMath = "0"
Requires = "0.5, 1.0"
SpecialFunctions = "0"
julia = "1.3"

[extras]
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["PDMats", "Test"]
1 change: 1 addition & 0 deletions src/Tracker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ include("numeric.jl")
include("lib/real.jl")
include("lib/array.jl")
include("forward.jl")
@init @require PDMats="90014a1f-27ba-587c-ab20-58faa44d9150" include("lib/pdmats.jl")

"""
hook(f, x) -> x′
Expand Down
18 changes: 10 additions & 8 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -329,23 +329,25 @@ end
A::TrackedArray / B::TrackedArray = Tracker.track(/, A, B)
A::AbstractVecOrMat / B::TrackedArray = Tracker.track(/, A, B)
A::TrackedArray / B::AbstractVecOrMat = Tracker.track(/, A, B)
@grad function (A / B)
@grad function Base.:/(A, B)
return Tracker.data(A) / Tracker.data(B), function (Δ)
Binv = inv(B)
∇B = - Binv' * A' * Δ * Binv'
return (Δ * Binv', ∇B)
∇A = Δ / B'
∇B = - (A / B)' * ∇A
return (∇A, ∇B)
end
end

# (\) ldivide (left vec divide needs more work to resolve dispatch ambiguity)
A::TrackedArray \ B::TrackedArray = Tracker.track(\, A, B)
A::AbstractArray \ B::TrackedArray = Tracker.track(\, A, B)
A::TrackedArray \ B::AbstractVecOrMat = Tracker.track(\, A, B)
@grad function (A \ B)
A::AbstractMatrix \ B::TrackedVecOrMat = Tracker.track(\, A, B)
A::TrackedMatrix \ B::TrackedVecOrMat = Tracker.track(\, A, B)
@grad function Base.:\(A, B)
return Tracker.data(A) \ Tracker.data(B), function (Δ)
Ainv = inv(A)
∇A = - Ainv' * Δ * B' * Ainv'
return (∇A, Ainv' * Δ)
∇B = A' \ Δ
∇A = - ∇B * (A \ B)'
return (∇A, ∇B)
end
end

Expand Down
5 changes: 5 additions & 0 deletions src/lib/pdmats.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
using .PDMats

Base.:\(A::PDMat, B::TrackedVecOrMat) = Tracker.track(\, A, B)
Base.:\(A::PDiagMat, B::TrackedVecOrMat) = Tracker.track(\, A, B)
Base.:\(A::ScalMat, B::TrackedVecOrMat) = Tracker.track(\, A, B)
28 changes: 27 additions & 1 deletion test/tracker.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
using Tracker, Test, NNlib
using Tracker: TrackedReal, gradient, gradcheck, grad, checkpoint, forwarddiff
using NNlib: conv, ∇conv_data, depthwiseconv
using PDMats
using Printf: @sprintf
using LinearAlgebra: diagm, dot, LowerTriangular, norm, det, logdet, logabsdet
using LinearAlgebra: diagm, dot, LowerTriangular, norm, det, logdet, logabsdet, I
using Statistics: mean, std
using Random
# using StatsBase
Expand Down Expand Up @@ -152,7 +153,20 @@ end
@test gradtest(W -> inv(log.(W * W)), (5,5))
@test gradtest((A, B) -> A / B , (1,5), (5,5))
@test gradtest((A, B) -> log.(A * A) / exp.(B * B), (5,5), (5,5))
@test gradtest((A, B) -> A \ B, (5, 5), (5,))
@test let A=rand(5, 5)
gradtest(B -> A \ B, (5,))
end
@test let B=rand(5,)
gradtest(A -> A \ B, (5, 5))
end
@test gradtest((A, B) -> log.(A * A) \ exp.(B * B), (5,5), (5,5))
@test let A=rand(5, 5)
gradtest(B -> log.(A * A) \ exp.(B * B), (5, 5))
end
@test let B=rand(5, 5)
gradtest(A -> log.(A * A) \ exp.(B * B), (5, 5))
end

@testset "mean" begin
@test gradtest(mean, rand(2, 3))
Expand Down Expand Up @@ -432,4 +446,16 @@ end
@test back([1, 1]) == (32,)
end

@testset "PDMats" begin
B = rand(5, 5)
S = PDMat(I + B * B')
@test gradtest(A -> S / A, (5, 5))

S = PDiagMat(rand(5))
@test gradtest(A -> S / A, (5, 5))

S = ScalMat(5, rand())
@test gradtest(A -> S / A, (5, 5))
end

end #testset