
Store number of labels as a field #20

Merged · 7 commits · Oct 12, 2022

Conversation

@darsnack (Member) commented Oct 6, 2022

As suggested on Slack, storing the number of labels as a field instead of in the type results in type stable code. The original PR that I based my overhaul of one-hot arrays on mentions memory consumption as the main reason for keeping the label count in the type. While that certainly uses significantly less memory than the very old Flux implementation, compared to this PR the overhead is only a constant 8 bytes per array.

I benchmarked to see whether storing the count in the type resulted in downstream performance improvements (which would be a reason to reject this PR). The results are mixed, but I think any measurable performance loss is due to a slower path being taken for indexing like ov[1:3]. I haven't figured out how to fix this yet, but I believe it is fixable; the performance loss is not fundamental to where the number-of-labels information is stored.
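To make the tradeoff concrete, here is a minimal sketch of the two designs (simplified stand-in types, not the actual OneHotArrays definitions): with the label count as a type parameter, constructing an array from a runtime count is type unstable, while the field version is a single concrete type at the cost of one extra `Int`:

```julia
# Simplified stand-ins, not the real OneHotArrays types.

# Label count in the type: the constructor's result type depends on a
# runtime value, so code building these from data is type unstable.
struct OneHotVecT{L} <: AbstractVector{Bool}
    ix::UInt32
end
Base.size(::OneHotVecT{L}) where {L} = (L,)
Base.getindex(v::OneHotVecT, i::Integer) = i == v.ix

# Label count as a field: one concrete type for any count, at the cost
# of 8 extra bytes per array (the approach in this PR).
struct OneHotVecF <: AbstractVector{Bool}
    ix::UInt32
    nlabels::Int
end
Base.size(v::OneHotVecF) = (v.nlabels,)
Base.getindex(v::OneHotVecF, i::Integer) = i == v.ix

n = rand(1:10)          # a runtime value
vt = OneHotVecT{n}(3)   # concrete type OneHotVecT{n} depends on n
vf = OneHotVecF(3, n)   # always OneHotVecF: type stable
```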

Benchmarking results and code
| Test target | Code | `@btime` for main | `@btime` for PR |
| --- | --- | --- | --- |
| getindex | `ov[1]` | 2.125 ns (0 allocations: 0 bytes) | 3.041 ns (0 allocations: 0 bytes) |
| getindex | `ov[1:3]` | 18.579 ns (1 allocation: 64 bytes) | 21.523 ns (1 allocation: 64 bytes) |
| getindex | `om[1, 2]` | 2.083 ns (0 allocations: 0 bytes) | 3.000 ns (0 allocations: 0 bytes) |
| getindex | `om[1:2, 3]` | 20.979 ns (1 allocation: 64 bytes) | 21.648 ns (1 allocation: 64 bytes) |
| getindex | `oa[1, 2, 3]` | 2.416 ns (0 allocations: 0 bytes) | 3.000 ns (0 allocations: 0 bytes) |
| getindex | `oa[1, :, 3]` | 201.261 ns (6 allocations: 256 bytes) | 24.556 ns (1 allocation: 64 bytes) |
| getindex | `oa[:, :, :]` | 2.125 ns (0 allocations: 0 bytes) | 2.125 ns (0 allocations: 0 bytes) |
| size | `length(ov)` | 1.167 ns (0 allocations: 0 bytes) | 2.125 ns (0 allocations: 0 bytes) |
| size | `size(oa)` | 2.083 ns (0 allocations: 0 bytes) | 2.084 ns (0 allocations: 0 bytes) |
| concat | `vcat(ov, ov2)` | 31.816 ns (1 allocation: 80 bytes) | 83.678 ns (3 allocations: 208 bytes) |
| concat | `hcat(ov, ov)` | 243.719 ns (2 allocations: 96 bytes) | 17.994 ns (1 allocation: 80 bytes) |
| concat | `hcat(om, om)` | 258.048 ns (2 allocations: 160 bytes) | 31.061 ns (1 allocation: 144 bytes) |
| concat | `vcat(om, om2)` | 186.143 ns (1 allocation: 160 bytes) | 272.052 ns (3 allocations: 384 bytes) |
| concat | `cat(oa, oa, oa; dims = 2)` | 2.093 μs (51 allocations: 1.98 KiB) | 2.528 μs (64 allocations: 2.66 KiB) |
| broadcast | `2 .* om` | 42.886 ns (1 allocation: 496 bytes) | 53.245 ns (1 allocation: 496 bytes) |
| broadcast | `om .* om .+ ov` | 72.852 ns (1 allocation: 496 bytes) | 116.178 ns (1 allocation: 496 bytes) |
| argmax | `argmax(om; dims = 1)` | 34.876 ns (3 allocations: 240 bytes) | 35.211 ns (3 allocations: 240 bytes) |
| argmax | `argmax(oa; dims = 2)` | 3.391 μs (28 allocations: 3.23 KiB) | 3.828 μs (28 allocations: 3.12 KiB) |
| onehot | `onehot(rand(1:10), 1:10)` | 744.877 ns (8 allocations: 368 bytes) | 3.000 ns (0 allocations: 0 bytes) |
| onehot | `onehotbatch(rand(1:10, 50), 1:10)` | 628.188 ns (4 allocations: 512 bytes) | 441.919 ns (4 allocations: 528 bytes) |
| onecold | `onecold(ov)` | 2.778 μs (19 allocations: 1.69 KiB) | 2.602 μs (19 allocations: 1.58 KiB) |
| onecold | `onecold(om)` | 22.590 ns (1 allocation: 96 bytes) | 22.757 ns (1 allocation: 96 bytes) |
| onecold | `onecold(oa)` | 34.743 ns (1 allocation: 256 bytes) | 34.660 ns (1 allocation: 256 bytes) |
| matmul | `rand(5, 10) * om` | 55.752 ns (1 allocation: 256 bytes) | 55.923 ns (1 allocation: 256 bytes) |
| matmul | `rand(5, 10) * ov` | 3.026 μs (28 allocations: 1.92 KiB) | 2.847 μs (28 allocations: 1.81 KiB) |
| matmul | `rand(5, 5) * adjoint(om)` | 53.879 ns (1 allocation: 496 bytes) | 53.583 ns (1 allocation: 496 bytes) |

Created by running:

```julia
using OneHotArrays
using BenchmarkTools

ov = OneHotVector(rand(1:10), 10)
ov2 = OneHotVector(rand(1:11), 11)
om = OneHotMatrix(rand(1:10, 5), 10)
om2 = OneHotMatrix(rand(1:11, 5), 11)
oa = OneHotArray(rand(1:10, 5, 5), 10)

@info "getindex"
@btime $ov[1]
@btime $ov[1:3]
@btime $om[1, 2]
@btime $om[1:2, 3]
@btime $oa[1, 2, 3]
@btime $oa[1, :, 3]
@btime $oa[:, :, :]

@info "size"
@btime length($ov)
@btime size($oa)

@info "concat"
@btime vcat($ov, $ov2)
@btime hcat($ov, $ov)
@btime hcat($om, $om)
@btime vcat($om, $om2)
@btime cat($oa, $oa, $oa; dims = 2)

@info "broadcast"
@btime 2 .* $om
@btime $om .* $om .+ $ov

@info "argmax"
@btime argmax($om; dims = 1)
@btime argmax($oa; dims = 2)

@info "onehot"
@btime onehot($(rand(1:10)), $(1:10))
@btime onehotbatch($(rand(1:10, 50)), $(1:10))

@info "onecold"
@btime onecold($ov)
@btime onecold($om)
@btime onecold($oa)

@info "matmul"
@btime $(rand(5, 10)) * $om
@btime $(rand(5, 10)) * $ov
@btime $(rand(5, 5)) * adjoint($om)
;
```

@darsnack (Member Author) commented Oct 6, 2022

This is a breaking change.

@codecov-commenter commented Oct 6, 2022

Codecov Report

Base: 93.93% // Head: 93.75% // Decreases project coverage by 0.18% ⚠️

Coverage data is based on head (afd2ac0) compared to base (69ca6fa).
Patch coverage: 87.75% of modified lines in pull request are covered.

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main      #20      +/-   ##
==========================================
- Coverage   93.93%   93.75%   -0.19%     
==========================================
  Files           3        3              
  Lines          99      112      +13     
==========================================
+ Hits           93      105      +12     
- Misses          6        7       +1     
```

| Impacted Files | Coverage | Δ |
| --- | --- | --- |
| src/onehot.jl | 94.44% <66.66%> | +0.15% ⬆️ |
| src/array.jl | 91.37% <86.84%> | +0.07% ⬆️ |
| src/linalg.jl | 100.00% <100.00%> | ø |


@ToucheSir (Member) commented:
These look like a nice set of benchmarks for future use too! I wonder if we can automate or at least script them somehow.

Review comment on src/array.jl, lines 102 to 109 (outdated):
```julia
function Base.cat(x::OneHotLike, xs::OneHotLike...; dims::Int)
  if isone(dims) || any(x -> !_isonehot(x), (x, xs...))
    return cat(map(x -> convert(_onehot_bool_type(x), x), (x, xs...))...; dims = dims)
  else
    L = _check_nlabels(x, xs...)
    return OneHotArray(cat(_indices(x), _indices.(xs)...; dims = dims - 1), L)
  end
end
```
Member:
It looks like vcat/cat(...; dims=1) is the only real regression. Cthulhu tells me that the isone check is type unstable here because constprop doesn't pierce far enough. Base does some shenanigans to force it to work, but I think we can just get away with overloading vcat.

Unfortunately GH doesn't let me add a suggestion for all the lines in question, but here it is:

```julia
function Base.cat(x::OneHotLike, xs::OneHotLike...; dims::Int)
  if any(x -> !_isonehot(x), (x, xs...))
    return cat(map(x -> convert(_onehot_bool_type(x), x), (x, xs...))...; dims = dims)
  else
    L = _check_nlabels(x, xs...)
    return OneHotArray(cat(_indices(x), _indices.(xs)...; dims = dims - 1), L)
  end
end

Base.hcat(x::OneHotLike, xs::OneHotLike...) = cat(x, xs...; dims = 2)
Base.vcat(x::OneHotLike, xs::OneHotLike...) =
  vcat(map(x -> convert(_onehot_bool_type(x), x), (x, xs...))...)

# optimized concatenation for matrices and vectors of same parameters
...
```

Perf is still not on par with main, but quite a bit closer:

```julia
julia> @btime vcat($ov, $ov2)
  41.972 ns (1 allocation: 80 bytes) # main
  111.121 ns (3 allocations: 208 bytes)

julia> @btime vcat($om, $om2)
  284.361 ns (1 allocation: 160 bytes) # main
  397.585 ns (3 allocations: 384 bytes)
```

The extra 2 allocations come from converting both arrays to Bool arrays before concatenating. Trying to see if there's a way around that right now.

Member:
Couldn't think of a way to make this faster without adding a bunch of array subtype-specific code for handling GPUs etc (NB: we don't seem to test [hv]cat in https://github.com/FluxML/OneHotArrays.jl/blob/main/test/gpu.jl?). Given this was meant to be a fallback path anyhow, it's probably fine.

Member Author:

As it turns out, this is a place where the type specification made a difference. Because we only hit this path when the number of labels matched, we were hitting the optimized Base path in all other cases. The latest commit tries to replicate some of what makes Base fast. Let me try combining that with the isone changes.

@darsnack (Member Author) commented Oct 11, 2022:

Okay so maybe converting to Bool arrays and calling back into Base is the fastest option.

I tried your suggestion, i.e.

```julia
_cat_fallback(x::OneHotLike, xs::OneHotLike...; dims::Int) =
  cat(map(x -> convert(_onehot_bool_type(x), x), (x, xs...))...; dims = dims)

function Base.cat(x::OneHotLike, xs::OneHotLike...; dims::Int)
  if any(x -> !_isonehot(x), (x, xs...))
    return _cat_fallback(x, xs...; dims = dims)
  else
    L = _nlabels(x, xs...)
    return OneHotArray(cat(_indices(x), _indices.(xs)...; dims = dims - 1), L)
  end
end

Base.hcat(x::OneHotLike, xs::OneHotLike...) = cat(x, xs...; dims = 2)
Base.vcat(x::OneHotLike, xs::OneHotLike...) = _cat_fallback(x, xs...; dims = 1)
```

And I got different results:

```julia
julia> @btime vcat($ov, $ov2)
  31.816 ns (1 allocation: 80 bytes) # main
  808.425 ns (20 allocations: 592 bytes)

julia> @btime vcat($om, $om2)
  186.143 ns (1 allocation: 160 bytes) # main
  1.033 μs (20 allocations: 864 bytes)
```

```julia
julia> versioninfo()
Julia Version 1.8.0
Commit 5544a0fab76 (2022-08-17 13:38 UTC)
Platform Info:
  OS: macOS (arm64-apple-darwin21.3.0)
  CPU: 8 × Apple M1 Pro
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-13.0.1 (ORCJIT, apple-m1)
  Threads: 1 on 6 virtual cores
Environment:
  JULIA_EDITOR = code
  JULIA_NUM_THREADS = 
```

Member:

Is it because of _cat_fallback? I skip that completely in mine because dims=1 doesn't seem to constprop correctly through it. It's simpler to just copy the code into the vcat method, because there we know dims=1 and can thus use a constant literal or call vcat directly.

Member:

Without reading closely: doesn't vcat call cat(...; dims = Val(1))? Might this help?
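As a toy illustration of the constprop point being discussed here (not code from this PR): a plain `Int` dimension leaves the branch to constant propagation, which can fail across keyword arguments and helper functions, while a `Val`-wrapped dimension moves the decision into dispatch:

```julia
# Toy illustration (hypothetical helper, not package code).

# Plain Int: whether the branch folds away depends on constant
# propagation reaching this call with a literal value.
pickdim(dims::Int) = isone(dims) ? :vcat_path : :generic_path

# Val{dims}: the dimension is part of the argument's type, so each
# specialization compiles with the branch already decided.
pickdim(::Val{dims}) where {dims} = isone(dims) ? :vcat_path : :generic_path

pickdim(1)        # runtime check unless constprop succeeds
pickdim(Val(1))   # decided at compile time via dispatch
```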

@darsnack (Member Author) commented Oct 11, 2022:

Okay, copying the code instead of using the fallback does replicate your results. I pushed this and updated the benchmarks above. Using Val didn't seem to make much difference.

Member Author:

4284e73 was my attempt at getting around the cost of converting by directly writing the chunks to the destination array like Base does. But it seems that a simpler, less aggressively optimized implementation like mine still loses out on performance. And it turns out that your implementation is faster even with the conversion.

@ToucheSir (Member) commented Oct 12, 2022:

I didn't even get that far because it seemed liable to cause scalar indexing issues on GPU. The ideal performance picture would be to have two code paths: one which does something like this but only writes the non-zeros, and one which accumulates the indices of all non-zeros and writes them in one go (for GPU compat). Given that vcat was never meant to be a fast or common method on OneHotArrays anyhow, I think we can leave further investigation for future work.

@darsnack (Member Author) commented:

> These look like a nice set of benchmarks for future use too! I wonder if we can automate or at least script them somehow.

We could definitely add a regression test for these. They don't take long to run, so we could wire them up to a GH PR bot.
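One possible shape for scripting them (a sketch using BenchmarkTools' `BenchmarkGroup`; the groups and entries shown are illustrative, not an agreed-upon suite):

```julia
using OneHotArrays
using BenchmarkTools

# Build a nested benchmark suite mirroring the categories in the table above.
suite = BenchmarkGroup()
suite["getindex"] = BenchmarkGroup()
suite["concat"] = BenchmarkGroup()

ov = OneHotVector(rand(1:10), 10)
om = OneHotMatrix(rand(1:10, 5), 10)

suite["getindex"]["ov[1:3]"] = @benchmarkable $ov[1:3]
suite["concat"]["hcat(om, om)"] = @benchmarkable hcat($om, $om)

tune!(suite)          # pick evals/samples for each benchmark
results = run(suite)  # a BenchmarkGroup of Trial results, diffable across runs
```

A bot could then `run` the same suite on main and on the PR branch and compare the two result groups with BenchmarkTools' `judge`.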

@ToucheSir (Member) commented:

I've been looking into https://github.com/tkf/BenchmarkCI.jl for other packages. This one seems like a nice place to test it.

@darsnack darsnack merged commit ddbba63 into FluxML:main Oct 12, 2022
@darsnack darsnack deleted the length-as-field branch October 12, 2022 13:53