Skip to content

Commit

Permalink
re-enable tests (#41)
Browse files Browse the repository at this point in the history
* re-enable tests

* add StatsFuns as a test dependency

* Test show

* install Zygote and IRTools in tests.

* fixing zygote tests
  • Loading branch information
tpapp authored Mar 10, 2019
1 parent 16d511f commit b877bc8
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 8 deletions.
4 changes: 2 additions & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ matrix:
# apt: # apt-get for linux
# packages:
# - gfortran
before_script:
- julia -e 'using Pkg; pkg"add Zygote#master IRTools#master"'
# before_script:
# - julia -e 'using Pkg; pkg"add Zygote#master IRTools#master"'
jobs:
include:
- stage: "Documentation"
Expand Down
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Distributions", "Documenter", "Flux", "ForwardDiff", "ReverseDiff", "Test", "StatsBase", "Zygote"]
test = ["Distributions", "Documenter", "Flux", "ForwardDiff", "ReverseDiff", "Test", "StatsBase",
"StatsFuns", "Zygote"]
2 changes: 1 addition & 1 deletion src/AD_Zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ Base.show(io::IO, ∇ℓ::ZygoteGradientLogDensity) = print(io, "Zygote AD wrapp
function logdensity(::Type{ValueGradient}, ∇ℓ::ZygoteGradientLogDensity, x::AbstractVector)
@unpack= ∇ℓ
y, back = Zygote.forward(_logdensity_closure(ℓ), x)
gradient = isfinite(y) ? back(Int8(1)) : zeros(typeof(y), length(y))
gradient = isfinite(y) ? back(Int8(1))[1] : zeros(typeof(y), length(y))
ValueGradient(y, gradient)
end

Expand Down
12 changes: 8 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -311,11 +311,15 @@ end

if VERSION v"1.1.0"
# cf https://github.com/FluxML/Zygote.jl/issues/104

import Pkg # use latest versions until tagged
Pkg.add(Pkg.PackageSpec(name = "IRTools", rev = "master"))
Pkg.add(Pkg.PackageSpec(name = "Zygote", rev = "master"))
import Zygote

@testset "AD via Zygote" begin
∇ℓ = ADgradient(:Zygote, TestLogDensity())
= TestLogDensity()
∇ℓ = ADgradient(:Zygote, ℓ)
@test repr(∇ℓ) == ("Zygote AD wrapper for " * repr(ℓ))
@test dimension(∇ℓ) == 3
buffer = randn(3)
vb = ValueGradientBuffer(buffer)
Expand All @@ -324,9 +328,9 @@ if VERSION ≥ v"1.1.0"
@test logdensity(Real, ∇ℓ, x) test_logdensity(x)
@test logdensity(Value, ∇ℓ, x) Value(test_logdensity(x))
vg = ValueGradient(test_logdensity(x), test_gradient(x))
@test_skip logdensity(ValueGradient, ∇ℓ, x) vg
@test logdensity(ValueGradient, ∇ℓ, x) vg
# NOTE don't test buffer ≡, as that is not implemented for Zygote
@test_skip logdensity(vb, ∇ℓ, x) vg
@test logdensity(vb, ∇ℓ, x) vg
end
end
end
Expand Down

0 comments on commit b877bc8

Please sign in to comment.