Store number of labels as a field #20
This is a breaking change.
Codecov Report
Base: 93.93% // Head: 93.75% // Decreases project coverage by -0.19%.

Additional details and impacted files:
@@ Coverage Diff @@
## main #20 +/- ##
==========================================
- Coverage 93.93% 93.75% -0.19%
==========================================
Files 3 3
Lines 99 112 +13
==========================================
+ Hits 93 105 +12
- Misses 6 7 +1
☔ View full report at Codecov.
These look like a nice set of benchmarks for future use too! I wonder if we can automate or at least script them somehow.
src/array.jl (outdated)
function Base.cat(x::OneHotLike, xs::OneHotLike...; dims::Int)
  if isone(dims) || any(x -> !_isonehot(x), (x, xs...))
    return cat(map(x -> convert(_onehot_bool_type(x), x), (x, xs...))...; dims = dims)
  else
    L = _check_nlabels(x, xs...)
    return OneHotArray(cat(_indices(x), _indices.(xs)...; dims = dims - 1), L)
  end
end
It looks like vcat/cat(...; dims=1) is the only real regression. Cthulhu tells me that the isone check is type unstable here because constprop doesn't pierce far enough. Base does some shenanigans to force it to work, but I think we can just get away with overloading vcat.
Unfortunately GH doesn't let me add a suggestion for all the lines in question, but here it is:
function Base.cat(x::OneHotLike, xs::OneHotLike...; dims::Int)
if any(x -> !_isonehot(x), (x, xs...))
return cat(map(x -> convert(_onehot_bool_type(x), x), (x, xs...))...; dims = dims)
else
L = _check_nlabels(x, xs...)
return OneHotArray(cat(_indices(x), _indices.(xs)...; dims = dims - 1), L)
end
end
Base.hcat(x::OneHotLike, xs::OneHotLike...) = cat(x, xs...; dims = 2)
Base.vcat(x::OneHotLike, xs::OneHotLike...) =
vcat(map(x -> convert(_onehot_bool_type(x), x), (x, xs...))...)
# optimized concatenation for matrices and vectors of same parameters
...
Perf is still not on par with main, but quite a bit closer:
julia> @btime vcat($ov, $ov2)
41.972 ns (1 allocation: 80 bytes) # main
111.121 ns (3 allocations: 208 bytes)
julia> @btime vcat($om, $om2)
284.361 ns (1 allocation: 160 bytes) # main
397.585 ns (3 allocations: 384 bytes)
The extra 2 allocations come from converting both arrays to Bool arrays before concatenating. Trying to see if there's a way around that right now.
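One conceivable way around those conversion allocations (a hypothetical sketch, not the PR's code; vcat_onehot and its signature are made up for illustration, and it works directly on the raw index vectors rather than on OneHotArrays types) would be to write both arrays' hot entries straight into a Bool destination:

```julia
# Hypothetical sketch: vcat two one-hot matrices that share the same number
# of labels L, writing the Bool output directly instead of materializing
# two intermediate Bool arrays first. ia and ib are the hot-row indices
# (one per column) of the two matrices.
function vcat_onehot(ia::Vector{Int}, ib::Vector{Int}, L::Int)
    B = length(ia)
    @assert length(ib) == B
    out = falses(2L, B)          # BitMatrix destination, all false
    for j in 1:B
        out[ia[j], j] = true     # hot row from the first matrix
        out[L + ib[j], j] = true # hot row from the second, offset by L
    end
    return out
end
```

This assumes both inputs really are one-hot with the same L; a scalar loop like this would also hit the GPU scalar-indexing problem discussed below.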
Couldn't think of a way to make this faster without adding a bunch of array subtype-specific code for handling GPUs etc. (NB: we don't seem to test [hv]cat in https://github.com/FluxML/OneHotArrays.jl/blob/main/test/gpu.jl?). Given this was meant to be a fallback path anyhow, it's probably fine.
As it turns out, this is a place where the type specification made a difference. Because we only hit this path when the number of labels matched, we were hitting the optimized Base path in all other cases. The latest commit tries to replicate some of what makes Base fast. Let me try combining that with the isone changes.
Okay so maybe converting to Bool arrays and calling back into Base is the fastest option.
I tried your suggestion, i.e.
_cat_fallback(x::OneHotLike, xs::OneHotLike...; dims::Int) =
cat(map(x -> convert(_onehot_bool_type(x), x), (x, xs...))...; dims = dims)
function Base.cat(x::OneHotLike, xs::OneHotLike...; dims::Int)
if any(x -> !_isonehot(x), (x, xs...))
return _cat_fallback(x, xs...; dims = dims)
else
L = _nlabels(x, xs...)
return OneHotArray(cat(_indices(x), _indices.(xs)...; dims = dims - 1), L)
end
end
Base.hcat(x::OneHotLike, xs::OneHotLike...) = cat(x, xs...; dims = 2)
Base.vcat(x::OneHotLike, xs::OneHotLike...) = _cat_fallback(x, xs...; dims = 1)
And I got different results:
julia> @btime vcat($ov, $ov2)
31.816 ns (1 allocation: 80 bytes) # main
808.425 ns (20 allocations: 592 bytes)
julia> @btime vcat($om, $om2)
186.143 ns (1 allocation: 160 bytes) # main
1.033 μs (20 allocations: 864 bytes)
julia> versioninfo()
Julia Version 1.8.0
Commit 5544a0fab76 (2022-08-17 13:38 UTC)
Platform Info:
OS: macOS (arm64-apple-darwin21.3.0)
CPU: 8 × Apple M1 Pro
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-13.0.1 (ORCJIT, apple-m1)
Threads: 1 on 6 virtual cores
Environment:
JULIA_EDITOR = code
JULIA_NUM_THREADS =
Is it because of _cat_fallback? I skip that completely in mine because dims=1 doesn't seem to constprop correctly. It's simpler to just copy the code into the vcat method, because there we know dims=1 and can thus use a constant literal / call vcat directly.
Without reading closely: doesn't vcat call cat(...; dims=Val(1))? Might this help?
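For context, the mechanism behind that suggestion: Val(1) lifts the dimension into the type domain, so dispatch can select the dims=1 path at compile time without relying on constant propagation. A minimal standalone illustration (g is a made-up function, not OneHotArrays code):

```julia
# Dispatching on Val makes the dimension part of the type, so the compiler
# picks the method statically instead of branching on a runtime Int.
g(x; dims::Int) = g(x, Val(dims))   # normalize the keyword to a Val
g(x, ::Val{1}) = "fast dims=1 path" # specialized method for dims = 1
g(x, ::Val) = "generic path"        # fallback for any other dimension
```

(As noted below, this didn't end up making much difference in practice here.)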
Okay, copying the code instead of the fallback does replicate your results. I pushed this and updated the benchmarks above. Using Val didn't seem to make much difference.
4284e73 was my attempt at getting around the cost of converting by directly writing the chunks to the destination array like Base does. But it seems that a simpler, less compiler-aggressive implementation like mine still loses out on performance. And it turns out that your implementation is faster even with the conversion.
I didn't even get that far because it seemed liable to cause scalar indexing issues on GPU. The ideal performance picture would be to have two code paths: one which does something like this but only writes the non-zeros, and one which accumulates the indices of all non-zeros and writes them in one go (for GPU compat). Given that vcat was never meant to be a fast or common method on OneHotArrays anyhow, I think we can leave further investigation for future work.
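That second, GPU-friendlier path could be sketched roughly like this (illustrative only; vcat_onehot_bulk is a made-up name, it operates on raw index vectors, and real GPU support would additionally need device-array types): gather the linear indices of every non-zero up front, then perform a single bulk write.

```julia
# Illustrative sketch of the "accumulate indices, write once" idea for
# vcat of two one-hot matrices with L labels each. ia and ib are the
# hot-row indices (one per column) of the two matrices.
function vcat_onehot_bulk(ia::Vector{Int}, ib::Vector{Int}, L::Int)
    B = length(ia)
    out = falses(2L, B)
    # Linear indices of every non-zero in the 2L-by-B output: column j
    # starts at linear offset 2L*(j-1); the second matrix's rows sit L lower.
    hot = vcat(ia .+ 2L .* (0:B-1), L .+ ib .+ 2L .* (0:B-1))
    out[hot] .= true             # one bulk write instead of a scalar loop
    return out
end
```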
We could definitely add a regression test for these. It doesn't take long to run them, so we could hook them up to a GH PR bot.
I've been looking into https://github.com/tkf/BenchmarkCI.jl for other packages. This one seems like a nice place to test it.
As suggested on Slack, storing the number of labels as a field instead of in the type results in type stable code. The original PR that I based my overhaul of one-hot arrays on mentions memory consumption as the main reason for keeping the label count in the type. While that is certainly far less memory than the old, old Flux implementation, the overhead compared to this PR is constant (8 bytes per array).
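Roughly, the change looks like this (a simplified sketch with illustrative names, not the package's actual definitions):

```julia
# Label count carried as a type parameter: every distinct L produces a new
# concrete type, so code where L is not known at compile time becomes
# type unstable.
struct OneHotArrayTyped{L, N, I <: AbstractArray{<:Integer}}
    indices::I
end

# Label count stored as a field: one concrete type regardless of L, at the
# cost of 8 extra bytes per array.
struct OneHotArrayField{N, I <: AbstractArray{<:Integer}}
    indices::I
    nlabels::Int
end
```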
I benchmarked to see if storing the label count in the type resulted in downstream performance improvements (which would be a reason to reject this PR). The results are mixed, but I think any measurable performance loss is due to a slower path being taken for indexing like ov[1:3]. I haven't figured out how to fix this, but I think it is fixable. I don't think the performance loss is fundamental to where the label count is stored.

Benchmarking results and code
@btime comparisons, main vs. PR, for the following operations:

getindex: ov[1]
getindex: ov[1:3]
getindex: om[1, 2]
getindex: om[1:2, 3]
getindex: oa[1, 2, 3]
getindex: oa[1, :, 3]
getindex: oa[:, :, :]
size: length(ov)
size: size(oa)
vcat(ov, ov2)
hcat(ov, ov)
hcat(om, om)
vcat(om, om2)
cat(oa, oa, oa; dims = 2)
2 .* om
om .* om .+ ov
argmax(om; dims = 1)
argmax(oa; dims = 2)
onehot(rand(1:10), 1:10)
onehotbatch(rand(1:10, 50), 1:10)
onecold(ov)
onecold(om)
onecold(oa)
rand(5, 10) * om
rand(5, 10) * ov
rand(5, 5) * adjoint(om)
Created by running: