Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LeakyReLU #81

Merged
merged 29 commits into from
Sep 22, 2020
Merged

LeakyReLU #81

merged 29 commits into from
Sep 22, 2020

Conversation

torfjelde
Copy link
Member

@torfjelde torfjelde commented Feb 11, 2020

Implementation of LeakyReLU as a Bijector, which defines the invertible mapping

x ↦ x if x ≥ 0 else αx

where α > 0.

TODOs

  • Clean up testing a bit and make sure we've captured everything.

@torfjelde
Copy link
Member Author

Tests fail due to type-instability in Tracker.jl (in particular: FluxML/Tracker.jl#42).

The question is: do we care? I'm not sure what the current state of the "should we drop support for Tracker.jl?"-discussion is 😕

@devmotion
Copy link
Member

Some test errors on Travis are real and not Tracker-related though. Test errors on Zygote should be fixed by the next release of DistributionsAD (see TuringLang/DistributionsAD.jl#108 (comment)).

@devmotion
Copy link
Member

The question is: do we care? I'm not sure what the current state of the "should we drop support for Tracker.jl?"-discussion is

IMO Tracker is well supported so far (in particular after dropping support for Julia < 1.3), e.g., all supported distributions in DistributionsAD work with Tracker (https://github.com/TuringLang/DistributionsAD.jl/blob/master/test/ad/distributions.jl). So if possible, I think we should support it.

@xukai92
Copy link
Member

xukai92 commented Sep 10, 2020

Why do we implement our own instead of using https://github.com/FluxML/NNlib.jl/blob/master/src/activation.jl#L73?

@torfjelde
Copy link
Member Author

Why do we implement our own instead of using https://github.com/FluxML/NNlib.jl/blob/master/src/activation.jl#L73?

I had no idea this existed! It might also not have at the time when I first implemented LeakyReLU. But that's neat! I guess we could just do:

(b::LeakyReLU)(x) = NNlib.leakyrelu(x, b.α)
(ib::Inverse{<:LeakyReLU})(y) = NNlib.leakyrelu(y, inv(ib.orig.α))

I'll try that 👍 Thanks!

@torfjelde
Copy link
Member Author

Of course

(ib::Inverse{<:LeakyReLU})(y) = NNlib.leakyrelu(y, inv(ib.orig.α))

doesn't work. Their impl assumes that the α ∈ (0, 1). So this would mean that we could only make use of their impl for the forward pass, which is not that useful, right?

@torfjelde
Copy link
Member Author

torfjelde commented Sep 11, 2020

Tests fail due to type-instability in Tracker.jl (in particular: FluxML/Tracker.jl#42).

So it seems like this doesn't fail after all! It might have been me just using a global scope, whoops...
See below.

@torfjelde torfjelde changed the title [WIP] LeakyReLU LeakyReLU Sep 11, 2020
@torfjelde
Copy link
Member Author

Ah, I didn't know that Travis is the only one testing the bijectors at this point... Then there are still issues with not Tracker and Broadcasting not working togheter.

@torfjelde
Copy link
Member Author

@devmotion I see what you meant now, but there are still issues related to Tracker.jl making it type unstable.

Comment on lines 51 to 52
mask = x .< zero(eltype(x))
J = mask .* b.α .+ (1 .- mask) .* one(eltype(x))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe

Suggested change
mask = x .< zero(eltype(x))
J = mask .* b.α .+ (1 .- mask) .* one(eltype(x))
mask = x .< zero(eltype(x))
J = @. (x < zero(x)) * b.α + (x > zero(x)) * one(x)

would be more efficient (since it avoids the allocation of mask) and more stable (since it does not require computations based on types).

As a side remark, IMO it is a bit unfortunate that currently so many almost identical implementations of a function are needed to define a bijector. Maybe this could be resolved by defining defaults on a high level for basically all bijectors (or a subtype of BaseBijectors, similar to BaseKernels in KernelFunctions) that call just in one or two methods that users would have to implement if their bijectors belong to these simple standard groups. With Julia >= 1.3 it is also possible to define methods for abstract functions such as

(f::Bijector{T,0})(x::AbstractVecOrMat{<:Real}) where T = map(f, x)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would be more efficient (since it avoids the allocation of mask) and more stable (since it does not require computations based on types).

Won't this call zero(x) n times though, i.e. have the same allocation? Other than that, it seems like a good idea 👍

As a side remark

100%.

Maybe this could be resolved by defining defaults on a high level for basically all bijectors (or a subtype of BaseBijectors, similar to BaseKernels in KernelFunctions) that call just in one or two methods that users would have to implement if their bijectors belong to these simple standard groups.

I'm unaware of what KernelFunctions.jl does. Are they simply not making the struct callable? If so, it was a deliberate design-decision to go with making structs callable when we started out. We were debating whether or not to do this or include some transform method so we could define "one method to rule them all" on an abstract type. Ended up going with the "callable struct" approach, but it def has it's issues, i.e. redundant code.

I've recently played around a bit with actually using a transform method under the hood, but adding a macro which allows you to say @deftransform function transform(b, x) ... end and we just add a (b::Bijector)(x) = transform(b, x) just after the method declaration. This would also allow us to implement in-place versions of all the methods, i.e. transform!(b, x, out), logabsdetjac(b, x, out), and so on. Thoughts?

With Julia >= 1.3 it is also possible to define methods for abstract functions such as

Woah, really? If so that would solve the problem, no?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm unaware of what KernelFunctions.jl does. Are they simply not making the struct callable? If so, it was a deliberate design-decision to go with making structs callable when we started out. We were debating whether or not to do this or include some transform method so we could define "one method to rule them all" on an abstract type. Ended up going with the "callable struct" approach, but it def has it's issues, i.e. redundant code.

No, KernelFunctions API is based on callable structs only (using a function was removed a while ago). But e.g. translation-invariant kernels are built all in the same way, usually they use some simple function to evaluate the distance between the inputs (e.g. using Distances) and then they apply some nonlinear mapping afterwards. With this special structure also all kind of optimizations are possible when constructing kernel matrices etc. Hence there is a special type of such kernels (SimpleKernel IIRC), and then users just define their metric and the nonlinear mapping and get everything else for free. There's more information here: https://juliagaussianprocesses.github.io/KernelFunctions.jl/dev/create_kernel/

Copy link
Member Author

@torfjelde torfjelde Sep 11, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To add on to the above, because it sucks to have to implement (b::Bijector), logabsdetjac and then finally forward, would it be an idea to add a macro that would allow you to define all in one go?

EDIT: Just see #137 :)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But e.g. translation-invariant kernels are built all in the same way, usually they use some simple function to evaluate the distance between the inputs (e.g. using Distances) and then they apply some nonlinear mapping afterwards.

Ah, I see. I don't think it's so easy to do for bijectors as we have much more structure than just "forward" evaluation, no?

@devmotion
Copy link
Member

Okay, so now the only issue is FluxML/Tracker.jl#42

Are you sure that this causes the test error? The linked issue seems to be about broadcasting issues with Tracker if the method is not defined generically enough that result in a MethodError (if Tracker doesn't know how to handle it, it falls back to ForwardDiff - but this requires that the method allows for inputs of type ForwardDiff.Dual) whereas here it seems to be an inference problem.

torfjelde and others added 2 commits September 12, 2020 16:49
@torfjelde
Copy link
Member Author

Are you sure that this causes the test error?

Looking closer, it seems like you're right. It might not be related after all 😕

@devmotion
Copy link
Member

I had a closer look at the Tracker issue and made the following interesting observation:

julia> using Tracker, Test

julia> f(x::AbstractVecOrMat) = @. x < zero(x)

julia> @inferred(f(Tracker.param(rand(5,4))))
ERROR: return type BitArray{2} does not match inferred return type Any
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] top-level scope at REPL[21]:1

julia> f(x::AbstractVecOrMat) = @. x < 0

julia> @inferred(f(Tracker.param(rand(5,4))))
5×4 BitArray{2}:
 0  0  0  0
 0  0  0  0
 0  0  0  0
 0  0  0  0
 0  0  0  0

julia> f(x::AbstractVecOrMat) = @. x < $(zero(eltype(x)))

julia> @inferred(f(Tracker.param(rand(5,4))))
5×4 BitArray{2}:
 0  0  0  0
 0  0  0  0
 0  0  0  0
 0  0  0  0
 0  0  0  0

julia> g(x) = x < zero(x)
g (generic function with 1 method)

julia> f(x::AbstractVecOrMat) = g.(x)
f (generic function with 2 methods)

julia> @inferred(f(Tracker.param(rand(5,4))))
5×4 BitArray{2}:
 0  0  0  0
 0  0  0  0
 0  0  0  0
 0  0  0  0
 0  0  0  0

I don't know why the first approach does not work but I guess the Tracker could be fixed by defining two functions isnegative(x) = x < zero(x) and ispositive(x) = x > zero(x), and writing

J = @. isnegative(x) * a * x + ispositive(x) * x

instead of

J = @. (x < zero(x)) * a * x + (x > zero(x)) * x

@torfjelde
Copy link
Member Author

This is some strange stuff:

julia> f(x::AbstractVecOrMat) = @. x < zero(x)

julia> @code_warntype f(x)
Variables
  #self#::Core.Compiler.Const(f, false)
  x::TrackedArray{…,Array{Float64,2}}

Body::Any
1%1 = Base.broadcasted(Main.zero, x)::Base.Broadcast.Broadcasted{Tracker.TrackedStyle,Nothing,typeof(zero),Tuple{TrackedArray{,Array{Float64,2}}}}
│   %2 = Base.broadcasted(Main.:<, x, %1)::Base.Broadcast.Broadcasted{Tracker.TrackedStyle,Nothing,typeof(<),Tuple{TrackedArray{,Array{Float64,2}},Base.Broadcast.Broadcasted{Tracker.TrackedStyle,Nothing,typeof(zero),Tuple{TrackedArray{,Array{Float64,2}}}}}}
│   %3 = Base.materialize(%2)::Any
└──      return %3

julia> g(x::AbstractVecOrMat) = @. x < 0
g (generic function with 1 method)

julia> @code_warntype g(x)
Variables
  #self#::Core.Compiler.Const(g, false)
  x::TrackedArray{…,Array{Float64,2}}

Body::BitArray{2}
1%1 = Base.broadcasted(Main.:<, x, 0)::Core.Compiler.PartialStruct(Base.Broadcast.Broadcasted{Tracker.TrackedStyle,Nothing,typeof(<),Tuple{TrackedArray{,Array{Float64,2}},Int64}}, Any[Core.Compiler.Const(<, false), Core.Compiler.PartialStruct(Tuple{TrackedArray{,Array{Float64,2}},Int64}, Any[TrackedArray{,Array{Float64,2}}, Core.Compiler.Const(0, false)]), Core.Compiler.Const(nothing, false)])
│   %2 = Base.materialize(%1)::BitArray{2}
└──      return %2

julia> function h(x::AbstractVecOrMat)
           z = zero(x)
           return @. x < z
       end
h (generic function with 1 method)

julia> @code_warntype h(x)
Variables
  #self#::Core.Compiler.Const(h, false)
  x::TrackedArray{…,Array{Float64,2}}
  z::Array{Float64,2}

Body::BitArray{2}
1 ─      (z = Main.zero(x))
│   %2 = Base.broadcasted(Main.:<, x, z)::Base.Broadcast.Broadcasted{Tracker.TrackedStyle,Nothing,typeof(<),Tuple{TrackedArray{,Array{Float64,2}},Array{Float64,2}}}
│   %3 = Base.materialize(%2)::BitArray{2}
└──      return %3

Sooo I think we fix it without custom-adjoints, i.e. by just not broadcasting the zero.


# Batched version
function forward(b::LeakyReLU{<:Any, 0}, x::AbstractVector)
J = let z = zero(x), o = one(x)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO that's not a good fix since it destroys the whole point of just having one broadcast expression that is fused together. Now we would allocate two additional vectors of ones and zeros just for Tracker.

Did you check if we can work around the Tracker bug by using a separate function ispositive(x) = x > zero(x) and isnegative(x) = x < zero(x) that we can include in the broadcast expression instead of the explicit x < zero(x) etc statements?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now we would allocate two additional vectors of ones and zeros just for Tracker.

Not two vectors; just two numbers, right? (I realize the above impl was wrong, it was supposed to contain eltype)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, even if we use isnegative, we're running into the same issue because we're broadcasting over one(x), so we would have to also have this in a separate function, e.g. mul1(x, y) = x * one(y) or something 😕

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With eltype it's just a number - IMO it's still unfortunate since there would be no need for this type-based computation if Tracker would be fixed (I opened a PR: FluxML/Tracker.jl#85). In general it's better to not use types if not needed since the instances contain more information (similar argument for why generated functions should be used only if needed). It should just work with this single broadcast expression here.

Apart from that, are the let blocks actually needed?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW it seems the tests still fail?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It still breaks, but now the only one failing is LeakyReLU{..., 1}. Again it hits a snag at materlialize:

Variables
  b::LeakyReLU{Float64,1}
  x::TrackedArray{…,Array{Float64,2}}
  z::Tracker.TrackedReal{Float64}

Body::Any
1%1 = Bijectors.eltype(x)::Core.Compiler.Const(Tracker.TrackedReal{Float64}, false)
│        (z = Bijectors.zero(%1))
│   %3 = Base.broadcasted(Bijectors.:<, x, z)::Base.Broadcast.Broadcasted{Tracker.TrackedStyle,Nothing,typeof(<),Tuple{TrackedArray{,Array{Float64,2}},Tracker.TrackedReal{Float64}}}
│   %4 = Base.getproperty(b, )::Float64%5 = Base.broadcasted(Bijectors.:*, %3, %4)::Base.Broadcast.Broadcasted{Tracker.TrackedStyle,Nothing,typeof(*),Tuple{Base.Broadcast.Broadcasted{Tracker.TrackedStyle,Nothing,typeof(<),Tuple{TrackedArray{,Array{Float64,2}},Tracker.TrackedReal{Float64}}},Float64}}
│   %6 = Base.broadcasted(Bijectors.:>, x, z)::Base.Broadcast.Broadcasted{Tracker.TrackedStyle,Nothing,typeof(>),Tuple{TrackedArray{,Array{Float64,2}},Tracker.TrackedReal{Float64}}}
│   %7 = Base.broadcasted(Bijectors.:*, %6, x)::Base.Broadcast.Broadcasted{Tracker.TrackedStyle,Nothing,typeof(*),Tuple{Base.Broadcast.Broadcasted{Tracker.TrackedStyle,Nothing,typeof(>),Tuple{TrackedArray{,Array{Float64,2}},Tracker.TrackedReal{Float64}}},TrackedArray{,Array{Float64,2}}}}
│   %8 = Base.broadcasted(Bijectors.:+, %5, %7)::Base.Broadcast.Broadcasted{Tracker.TrackedStyle,Nothing,typeof(+),Tuple{Base.Broadcast.Broadcasted{Tracker.TrackedStyle,Nothing,typeof(*),Tuple{Base.Broadcast.Broadcasted{Tracker.TrackedStyle,Nothing,typeof(<),Tuple{TrackedArray{,Array{Float64,2}},Tracker.TrackedReal{Float64}}},Float64}},Base.Broadcast.Broadcasted{Tracker.TrackedStyle,Nothing,typeof(*),Tuple{Base.Broadcast.Broadcasted{Tracker.TrackedStyle,Nothing,typeof(>),Tuple{TrackedArray{,Array{Float64,2}},Tracker.TrackedReal{Float64}}},TrackedArray{,Array{Float64,2}}}}}}
│   %9 = Base.materialize(%8)::Any
└──      return %9

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO it's still unfortunate since there would be no need for this type-based computation if Tracker would be fixed (I opened a PR: FluxML/Tracker.jl#85)

100% agree! And nice!

Apart from that, are the let blocks actually needed?

Nah, I just added them to make explicit that it's only needed for this particular statement, and not to clutter the names in the full function. Was nicer IMO, but ofc subjective.

@codecov
Copy link

codecov bot commented Sep 22, 2020

Codecov Report

Merging #81 into master will increase coverage by 0.93%.
The diff coverage is 94.28%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master      #81      +/-   ##
==========================================
+ Coverage   53.30%   54.23%   +0.93%     
==========================================
  Files          25       26       +1     
  Lines        1606     1641      +35     
==========================================
+ Hits          856      890      +34     
- Misses        750      751       +1     
Impacted Files Coverage Δ
src/interface.jl 63.88% <ø> (+2.77%) ⬆️
src/bijectors/leaky_relu.jl 94.28% <94.28%> (ø)

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 84314e7...410166b. Read the comment docs.

Copy link
Member

@cpfiffer cpfiffer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me. Seems like David already did the hard work. I left a minor test comment, but I couldn't find anything in the actual code that seems incorrect or lacking.

using Bijectors: LeakyReLU

using LinearAlgebra
using ForwardDiff
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think my only real comment here is on testing other AD backends like ReverseDiff, but I'm not sure how important that is here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some tests for it in test/bijectors/interface.jl. But yeah, testing in Bijectors is honestly a bit of mess atm. In a couple of the other PRs I've added some functionality which makes it easier to use a "standardized" testing suite for a new Bijector, so the plan is to use that in the future 👍

@torfjelde torfjelde merged commit f48c1fe into master Sep 22, 2020
@delete-merged-branch delete-merged-branch bot deleted the tor/leaky-relu branch September 22, 2020 20:04
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants