LeakyReLU #81
Conversation
Tests fail due to type-instability in Tracker.jl (in particular: FluxML/Tracker.jl#42). The question is: do we care? I'm not sure what the current state of the "should we drop support for Tracker.jl?"-discussion is 😕
Some test errors on Travis are real and not Tracker-related though. Test errors on Zygote should be fixed by the next release of DistributionsAD (see TuringLang/DistributionsAD.jl#108 (comment)).
IMO Tracker is well supported so far (in particular after dropping support for Julia < 1.3), e.g., all supported distributions in DistributionsAD work with Tracker (https://github.com/TuringLang/DistributionsAD.jl/blob/master/test/ad/distributions.jl). So if possible, I think we should support it.
Why do we implement our own instead of using https://github.com/FluxML/NNlib.jl/blob/master/src/activation.jl#L73?
I had no idea this existed! It might also not have existed at the time when I first implemented LeakyReLU. But that's neat! I guess we could just do:

```julia
(b::LeakyReLU)(x) = NNlib.leakyrelu(x, b.α)
(ib::Inverse{<:LeakyReLU})(y) = NNlib.leakyrelu(y, inv(ib.orig.α))
```

I'll try that 👍 Thanks!

Of course `(ib::Inverse{<:LeakyReLU})(y) = NNlib.leakyrelu(y, inv(ib.orig.α))` doesn't work. Their impl assumes that the …
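A small illustration of one way such an inverse can go wrong, assuming `leakyrelu(x, a)` is essentially `max(a*x, x)` (that assumption is mine, not something confirmed in this thread):

```julia
# Stand-in for NNlib.leakyrelu; assumed implementation, may differ from NNlib's code.
leakyrelu(x, a) = max(a * x, x)

α = 0.5
x = 2.0
y = leakyrelu(x, α)      # 2.0: the positive branch is the identity
leakyrelu(y, inv(α))     # 4.0, not 2.0: with slope 1/α > 1 the max picks the wrong branch
```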
Ah, I didn't know that Travis is the only one testing the bijectors at this point... Then there are still issues with Tracker and broadcasting not working together.
@devmotion I see what you meant now, but there are still issues related to Tracker.jl making it type unstable.
src/bijectors/leaky_relu.jl
Outdated
```julia
mask = x .< zero(eltype(x))
J = mask .* b.α .+ (1 .- mask) .* one(eltype(x))
```
Maybe

```julia
J = @. (x < zero(x)) * b.α + (x > zero(x)) * one(x)
```

would be more efficient (since it avoids the allocation of `mask`) and more stable (since it does not require computations based on types).
As a side remark, IMO it is a bit unfortunate that currently so many almost identical implementations of a function are needed to define a bijector. Maybe this could be resolved by defining defaults on a high level for basically all bijectors (or a subtype of BaseBijectors, similar to BaseKernels in KernelFunctions) that just call one or two methods that users would have to implement if their bijectors belong to these simple standard groups. With Julia >= 1.3 it is also possible to define methods for abstract functions such as

```julia
(f::Bijector{T,0})(x::AbstractVecOrMat{<:Real}) where T = map(f, x)
```
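A tiny self-contained demo of that Julia >= 1.3 feature (the type names below are made up for illustration, not Bijectors.jl types):

```julia
# A call method defined on an abstract type gives every concrete subtype
# the "map over vectors" behaviour for free.
abstract type AbstractBij end

(f::AbstractBij)(x::AbstractVector{<:Real}) = map(f, x)

struct Shift <: AbstractBij
    c::Float64
end
(f::Shift)(x::Real) = x + f.c

Shift(1.0)([1.0, 2.0, 3.0])   # [2.0, 3.0, 4.0]
```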
> would be more efficient (since it avoids the allocation of mask) and more stable (since it does not require computations based on types).

Won't this call `zero(x)` n times though, i.e. have the same allocation? Other than that, it seems like a good idea 👍
> As a side remark

100%.
> Maybe this could be resolved by defining defaults on a high level for basically all bijectors (or a subtype of BaseBijectors, similar to BaseKernels in KernelFunctions) that call just in one or two methods that users would have to implement if their bijectors belong to these simple standard groups.

I'm unaware of what KernelFunctions.jl does. Are they simply not making the struct callable? If so, it was a deliberate design decision to go with making structs callable when we started out. We were debating whether or not to do this or include some `transform` method so we could define "one method to rule them all" on an abstract type. Ended up going with the "callable struct" approach, but it definitely has its issues, i.e. redundant code.
I've recently played around a bit with actually using a `transform` method under the hood, but adding a macro which allows you to say `@deftransform function transform(b, x) ... end` and we just add a `(b::Bijector)(x) = transform(b, x)` just after the method declaration. This would also allow us to implement in-place versions of all the methods, i.e. `transform!(b, x, out)`, `logabsdetjac(b, x, out)`, and so on. Thoughts?
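A rough standalone sketch of what such a macro could look like (the names `Bijector`, `transform`, and `Scale` are placeholders for illustration, not the actual Bijectors.jl API):

```julia
abstract type Bijector end

# Hypothetical macro: defines `transform` and also makes the bijector callable.
macro deftransform(def)
    quote
        $(esc(def))
        # forward calls on any bijector instance to `transform`
        (b::Bijector)(x) = transform(b, x)
    end
end

struct Scale{T} <: Bijector
    a::T
end

@deftransform transform(b::Scale, x) = b.a .* x

Scale(2.0)([1.0, 2.0])   # [2.0, 4.0]
```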
> With Julia >= 1.3 it is also possible to define methods for abstract functions such as

Woah, really? If so, that would solve the problem, no?
> I'm unaware of what KernelFunctions.jl does. Are they simply not making the struct callable? If so, it was a deliberate design-decision to go with making structs callable when we started out. We were debating whether or not to do this or include some transform method so we could define "one method to rule them all" on an abstract type. Ended up going with the "callable struct" approach, but it def has its issues, i.e. redundant code.
No, KernelFunctions API is based on callable structs only (using a function was removed a while ago). But e.g. translation-invariant kernels are built all in the same way, usually they use some simple function to evaluate the distance between the inputs (e.g. using Distances) and then they apply some nonlinear mapping afterwards. With this special structure also all kind of optimizations are possible when constructing kernel matrices etc. Hence there is a special type of such kernels (SimpleKernel IIRC), and then users just define their metric and the nonlinear mapping and get everything else for free. There's more information here: https://juliagaussianprocesses.github.io/KernelFunctions.jl/dev/create_kernel/
To add on to the above, because it sucks to have to implement `(b::Bijector)`, `logabsdetjac` and then finally `forward`, would it be an idea to add a macro that would allow you to define all in one go?

EDIT: Just see #137 :)
> But e.g. translation-invariant kernels are built all in the same way, usually they use some simple function to evaluate the distance between the inputs (e.g. using Distances) and then they apply some nonlinear mapping afterwards.

Ah, I see. I don't think it's so easy to do for bijectors as we have much more structure than just "forward" evaluation, no?
Are you sure that this causes the test error? The linked issue seems to be about broadcasting issues with Tracker that result in a MethodError if the method is not defined generically enough (if Tracker doesn't know how to handle it, it falls back to ForwardDiff, but this requires that the method allows for inputs of type ForwardDiff.Dual), whereas here it seems to be an inference problem.
Looking closer, it seems like you're right. It might not be related after all 😕
I had a closer look at the Tracker issue and made the following interesting observation:

```julia
julia> using Tracker, Test

julia> f(x::AbstractVecOrMat) = @. x < zero(x)

julia> @inferred(f(Tracker.param(rand(5,4))))
ERROR: return type BitArray{2} does not match inferred return type Any
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] top-level scope at REPL[21]:1

julia> f(x::AbstractVecOrMat) = @. x < 0

julia> @inferred(f(Tracker.param(rand(5,4))))
5×4 BitArray{2}:
 0  0  0  0
 0  0  0  0
 0  0  0  0
 0  0  0  0
 0  0  0  0

julia> f(x::AbstractVecOrMat) = @. x < $(zero(eltype(x)))

julia> @inferred(f(Tracker.param(rand(5,4))))
5×4 BitArray{2}:
 0  0  0  0
 0  0  0  0
 0  0  0  0
 0  0  0  0
 0  0  0  0

julia> g(x) = x < zero(x)
g (generic function with 1 method)

julia> f(x::AbstractVecOrMat) = g.(x)
f (generic function with 2 methods)

julia> @inferred(f(Tracker.param(rand(5,4))))
5×4 BitArray{2}:
 0  0  0  0
 0  0  0  0
 0  0  0  0
 0  0  0  0
 0  0  0  0
```

I don't know why the first approach does not work, but I guess the Tracker issue could be avoided by defining two functions `isnegative` and `ispositive` and writing

```julia
J = @. isnegative(x) * a * x + ispositive(x) * x
```

instead of

```julia
J = @. (x < zero(x)) * a * x + (x > zero(x)) * x
```
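A minimal, self-contained illustration of that workaround (the helpers `isnegative` and `ispositive` are just the names suggested above, not existing Tracker or Bijectors functions):

```julia
# Scalar helpers so the broadcast itself never contains `zero(x)` on a tracked array.
isnegative(x::Real) = x < zero(x)
ispositive(x::Real) = x > zero(x)

a = 0.1
x = randn(5)

# single fused broadcast, as suggested above
y = @. isnegative(x) * a * x + ispositive(x) * x
```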
This is some strange stuff:

```julia
julia> f(x::AbstractVecOrMat) = @. x < zero(x)

julia> @code_warntype f(x)
Variables
  #self#::Core.Compiler.Const(f, false)
  x::TrackedArray{…,Array{Float64,2}}

Body::Any
1 ─ %1 = Base.broadcasted(Main.zero, x)::Base.Broadcast.Broadcasted{Tracker.TrackedStyle,Nothing,typeof(zero),Tuple{TrackedArray{…,Array{Float64,2}}}}
│   %2 = Base.broadcasted(Main.:<, x, %1)::Base.Broadcast.Broadcasted{Tracker.TrackedStyle,Nothing,typeof(<),Tuple{TrackedArray{…,Array{Float64,2}},Base.Broadcast.Broadcasted{Tracker.TrackedStyle,Nothing,typeof(zero),Tuple{TrackedArray{…,Array{Float64,2}}}}}}
│   %3 = Base.materialize(%2)::Any
└──      return %3

julia> g(x::AbstractVecOrMat) = @. x < 0
g (generic function with 1 method)

julia> @code_warntype g(x)
Variables
  #self#::Core.Compiler.Const(g, false)
  x::TrackedArray{…,Array{Float64,2}}

Body::BitArray{2}
1 ─ %1 = Base.broadcasted(Main.:<, x, 0)::Core.Compiler.PartialStruct(Base.Broadcast.Broadcasted{Tracker.TrackedStyle,Nothing,typeof(<),Tuple{TrackedArray{…,Array{Float64,2}},Int64}}, Any[Core.Compiler.Const(<, false), Core.Compiler.PartialStruct(Tuple{TrackedArray{…,Array{Float64,2}},Int64}, Any[TrackedArray{…,Array{Float64,2}}, Core.Compiler.Const(0, false)]), Core.Compiler.Const(nothing, false)])
│   %2 = Base.materialize(%1)::BitArray{2}
└──      return %2

julia> function h(x::AbstractVecOrMat)
           z = zero(x)
           return @. x < z
       end
h (generic function with 1 method)

julia> @code_warntype h(x)
Variables
  #self#::Core.Compiler.Const(h, false)
  x::TrackedArray{…,Array{Float64,2}}
  z::Array{Float64,2}

Body::BitArray{2}
1 ─      (z = Main.zero(x))
│   %2 = Base.broadcasted(Main.:<, x, z)::Base.Broadcast.Broadcasted{Tracker.TrackedStyle,Nothing,typeof(<),Tuple{TrackedArray{…,Array{Float64,2}},Array{Float64,2}}}
│   %3 = Base.materialize(%2)::BitArray{2}
└──      return %3
```

Sooo I think we can fix it without custom adjoints, i.e. by just not broadcasting the `zero(x)` call.
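A rough standalone illustration of that kind of fix (the function and its return shape are placeholders, not the PR's actual code):

```julia
# Illustration only: hoist `zero` out of the fused broadcast as a plain number
# so the result type stays inferrable under Tracker.
function forward_sketch(α::Real, x::AbstractVector{<:Real})
    z = zero(eltype(x))                 # a plain number; not broadcast over
    J = (x .< z) .* α .+ (x .>= z)      # per-element derivative of leaky ReLU
    return (rv = J .* x, logabsdetjac = log.(abs.(J)))
end

forward_sketch(0.1, [-2.0, 3.0])   # (rv = [-0.2, 3.0], logabsdetjac = [log(0.1), 0.0])
```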
src/bijectors/leaky_relu.jl
Outdated
```julia
# Batched version
function forward(b::LeakyReLU{<:Any, 0}, x::AbstractVector)
    J = let z = zero(x), o = one(x)
```
IMO that's not a good fix since it destroys the whole point of having just one broadcast expression that is fused together. Now we would allocate two additional vectors of ones and zeros just for Tracker.

Did you check if we can work around the Tracker bug by using separate functions `ispositive(x) = x > zero(x)` and `isnegative(x) = x < zero(x)` that we can include in the broadcast expression instead of the explicit `x < zero(x)` etc. statements?
> Now we would allocate two additional vectors of ones and zeros just for Tracker.

Not two vectors; just two numbers, right? (I realize the above impl was wrong, it was supposed to contain `eltype`.)
Also, even if we use `isnegative`, we're running into the same issue because we're broadcasting over `one(x)`, so we would have to also have this in a separate function, e.g. `mul1(x, y) = x * one(y)` or something 😕
With `eltype` it's just a number - IMO it's still unfortunate since there would be no need for this type-based computation if Tracker were fixed (I opened a PR: FluxML/Tracker.jl#85). In general it's better not to use types if not needed, since the instances contain more information (a similar argument to why generated functions should be used only if needed). It should just work with this single broadcast expression here.

Apart from that, are the `let` blocks actually needed?
BTW it seems the tests still fail?
It still breaks, but now the only one failing is `LeakyReLU{..., 1}`. Again it hits a snag at `materialize`:

```julia
Variables
  b::LeakyReLU{Float64,1}
  x::TrackedArray{…,Array{Float64,2}}
  z::Tracker.TrackedReal{Float64}

Body::Any
1 ─ %1 = Bijectors.eltype(x)::Core.Compiler.Const(Tracker.TrackedReal{Float64}, false)
│        (z = Bijectors.zero(%1))
│   %3 = Base.broadcasted(Bijectors.:<, x, z)::Base.Broadcast.Broadcasted{Tracker.TrackedStyle,Nothing,typeof(<),Tuple{TrackedArray{…,Array{Float64,2}},Tracker.TrackedReal{Float64}}}
│   %4 = Base.getproperty(b, :α)::Float64
│   %5 = Base.broadcasted(Bijectors.:*, %3, %4)::Base.Broadcast.Broadcasted{Tracker.TrackedStyle,Nothing,typeof(*),Tuple{Base.Broadcast.Broadcasted{Tracker.TrackedStyle,Nothing,typeof(<),Tuple{TrackedArray{…,Array{Float64,2}},Tracker.TrackedReal{Float64}}},Float64}}
│   %6 = Base.broadcasted(Bijectors.:>, x, z)::Base.Broadcast.Broadcasted{Tracker.TrackedStyle,Nothing,typeof(>),Tuple{TrackedArray{…,Array{Float64,2}},Tracker.TrackedReal{Float64}}}
│   %7 = Base.broadcasted(Bijectors.:*, %6, x)::Base.Broadcast.Broadcasted{Tracker.TrackedStyle,Nothing,typeof(*),Tuple{Base.Broadcast.Broadcasted{Tracker.TrackedStyle,Nothing,typeof(>),Tuple{TrackedArray{…,Array{Float64,2}},Tracker.TrackedReal{Float64}}},TrackedArray{…,Array{Float64,2}}}}
│   %8 = Base.broadcasted(Bijectors.:+, %5, %7)::Base.Broadcast.Broadcasted{Tracker.TrackedStyle,Nothing,typeof(+),Tuple{Base.Broadcast.Broadcasted{Tracker.TrackedStyle,Nothing,typeof(*),Tuple{Base.Broadcast.Broadcasted{Tracker.TrackedStyle,Nothing,typeof(<),Tuple{TrackedArray{…,Array{Float64,2}},Tracker.TrackedReal{Float64}}},Float64}},Base.Broadcast.Broadcasted{Tracker.TrackedStyle,Nothing,typeof(*),Tuple{Base.Broadcast.Broadcasted{Tracker.TrackedStyle,Nothing,typeof(>),Tuple{TrackedArray{…,Array{Float64,2}},Tracker.TrackedReal{Float64}}},TrackedArray{…,Array{Float64,2}}}}}}
│   %9 = Base.materialize(%8)::Any
└──      return %9
```
> IMO it's still unfortunate since there would be no need for this type-based computation if Tracker would be fixed (I opened a PR: FluxML/Tracker.jl#85)

100% agree! And nice!

> Apart from that, are the let blocks actually needed?

Nah, I just added them to make explicit that it's only needed for this particular statement, and not to clutter the names in the full function. Was nicer IMO, but ofc subjective.
Codecov Report

```diff
@@            Coverage Diff             @@
##           master      #81      +/-   ##
==========================================
+ Coverage   53.30%   54.23%   +0.93%     
==========================================
  Files          25       26       +1     
  Lines        1606     1641      +35     
==========================================
+ Hits          856      890      +34     
- Misses        750      751       +1     
```

Continue to review full report at Codecov.
Looks good to me. Seems like David already did the hard work. I left a minor test comment, but I couldn't find anything in the actual code that seems incorrect or lacking.
```julia
using Bijectors: LeakyReLU

using LinearAlgebra
using ForwardDiff
```
I think my only real comment here is on testing other AD backends like ReverseDiff, but I'm not sure how important that is here.
There are some tests for it in `test/bijectors/interface.jl`. But yeah, testing in Bijectors is honestly a bit of a mess atm. In a couple of the other PRs I've added some functionality which makes it easier to use a "standardized" testing suite for a new `Bijector`, so the plan is to use that in the future 👍
Implementation of LeakyReLU as a `Bijector`, which defines the invertible mapping x ↦ x for x ≥ 0 and x ↦ αx for x < 0, where α > 0.
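For reference, a minimal standalone sketch of such a bijector in the scalar case (the struct, field name, and function names below are illustrative, not necessarily identical to the PR's code):

```julia
# Standalone sketch, not the actual Bijectors.jl implementation.
struct LeakyReLUSketch{T<:Real}
    α::T   # slope on the negative part; assumed α > 0
end

# forward map: identity for x ≥ 0, slope α for x < 0
(b::LeakyReLUSketch)(x::Real) = x < zero(x) ? b.α * x : x

# inverse: the same kind of map with slope 1/α
inverse(b::LeakyReLUSketch) = LeakyReLUSketch(inv(b.α))

# log absolute value of the (scalar) Jacobian
logabsdetjac(b::LeakyReLUSketch, x::Real) = x < zero(x) ? log(b.α) : zero(log(b.α))

b = LeakyReLUSketch(0.1)
y = b(-2.0)              # -0.2
inverse(b)(y)            # -2.0
logabsdetjac(b, -2.0)    # log(0.1)
```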
TODOs