Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use fallback for reshape/cat OneHotArray #1459

Merged
merged 5 commits into from
Jan 29, 2021

Conversation

darsnack
Copy link
Member

@darsnack darsnack commented Jan 9, 2021

This falls back to reshaping a Bool array whenever reshaping the first dimension of a OneHotArray.

@DhairyaLGandhi @CarloLucibello @simeonschaub

PR Checklist

  • Tests are added
  • Entry in NEWS.md
  • Documentation, if applicable

@simeonschaub
Copy link
Member

simeonschaub commented Jan 9, 2021

Sorry, if this is a dumb question, but I didn't quite understand the reasoning in #1448 why the default behavior for AbstractArrays in Base was bad.

This means that if you collect the lazy iterator, you'll currently get an Array{Bool} when we want to return a OneHotArray (see the failing test cases).

Isn't this what collect is supposed to do? Or may this even be fixed by JuliaLang/julia#38965?

@darsnack
Copy link
Member Author

darsnack commented Jan 9, 2021

That actually wasn't the reason for throwing the error. I guess the actual reason might have been that when you reshape the first dimension, the array is no longer one-hot, so we erred on the side of alerting the user instead of silently converting (the same was done for cating along the first dimension).

With respect to the part you quoted, that is referring to the non-error branch where the first dimension is the same and the other dimensions are reshaped. In this case, the array is still a OneHotArray, so we'd like to return that type. But, due to JuliaLang/julia#39123, we get a ReshapedArray out. Collecting the ReshapedArray will index the parent array by element which returns a Bool resulting in an Array{Bool} out instead of the OneHotArray we wanted.

But I do think JuliaLang/julia#38965 might allow us to fix this. I can test it out on this branch and commit it if that's the desired behavior.

Side note: do you have similar requests about cat(xs...; dims = 1)/vcat for OneHotArrays?

@simeonschaub
Copy link
Member

simeonschaub commented Jan 9, 2021

Ah, I see, so the goal here is just to preserve one-hotness as much as we can, to take advantage of fast-paths. Since OnehotArray is an AbstractArray, I think we should always try to stick to the standard AbstractArray interfaces. Having special-cased overloads for Base functions, which may also return more convenient types is completely fine, but it should just be an optimization, I don't think we should deviate from the semantics of these functions and add stricter requirements. That's also why I am a bit skeptical of the proposal in this PR, since it implicitly collects in some cases, which goes against the usual semantics of reshape that it is lazy and will not allocate. I wonder if we maybe could adapt some of the fast paths to ReshapedArray{T,N,<:OnehotArray} as well and thereby avoid rewrapping in reshape altogether?

Side note: do you have similar requests about cat(xs...; dims = 1)/vcat for OneHotArrays?

Yes, I think by the same reasoning that should be allowed as well. In that case, I don't think the type instabilities would be that big of a deal, since cat in Base is type unstable already and typically dims will be a constant.

@darsnack
Copy link
Member Author

darsnack commented Jan 9, 2021

I wonder if we maybe could adapt some of the fast paths to ReshapedArray{T,N,<:OnehotArray} as well and thereby avoid rewrapping in reshape altogether?

That makes sense. Let me look into what that would entail for the paths we want to hit.

@DhairyaLGandhi are you okay with allowing vcat for OneHotArrays that returns a Bool array? Previously, you had some reservations.

@CarloLucibello
Copy link
Member

CarloLucibello commented Jan 9, 2021

reshape is not part of the abstract array interface (nor is vcat) so we don't necessarily have to support all kinds of reshapes, especially if not semantically meaningful. I really wonder what was @simeonschaub's specific use case for reshaping the first dimension, I couldn't guess

@darsnack
Copy link
Member Author

darsnack commented Jan 9, 2021

That's true. I am pretty sure for these use cases, collect before reshape should fix this issue for Flux#master. I'm mentioning this for anyone else's benefit that comes across this.

Also, we can should merge this PR with the simple fallback fixes so that they make it in for v0.12. I can submit an additional PR extending ReshapedArray.

@mcabbott
Copy link
Member

mcabbott commented Jan 9, 2021

vcat follows from indexing, it seems pretty odd to break it, for something <: AbstractArray. Was there an argument for breaking that, instead of just leaving the default?

julia> struct Mine <: AbstractVector{Bool} x end
       Base.size(x::Mine) = size(x.x)
       Base.getindex(x::Mine, i::Integer) = x.x[i]
       vcat(Mine(randn(2).>0), Mine(randn(2).>0))
4-element Array{Bool,1}:
 1
 0
 0
 1

@darsnack
Copy link
Member Author

darsnack commented Jan 9, 2021

I think the argument is the same:

I guess the actual reason might have been that when you reshape the first dimension, the array is no longer one-hot, so we erred on the side of alerting the user instead of silently converting (the same was done for cating along the first dimension).

But I feel now that these fallback cases are never on a critical path, so it isn't so bad to silently do what's expected instead of forcing a manual collect.

@darsnack darsnack changed the title Use fallback for reshaping OneHotArray Use fallback for reshape/cat OneHotArray Jan 9, 2021
@mcabbott
Copy link
Member

mcabbott commented Jan 9, 2021

I guess the alternative would not to pretend to be an array, and then it would only ever work on the intended straight-and-narrow. But if it is an array, then it seems like it ought to work in weird things people think up. Although where it's better to collect vs. producing some multiply wrapped thing I don't know.

@darsnack
Copy link
Member Author

darsnack commented Jan 9, 2021

So FYI, supporting these fallback paths is not hard to add, and in both cases we just convert to a Bool array (so after the conversion Base supports what happens).

Most operations on a OneHotArray will work as a result of getindex. I now agree that if we specialize a few operations for speed, we should continue to support the slower fallback path.

@darsnack
Copy link
Member Author

darsnack commented Jan 9, 2021

For the reshape case, it would be nice to have a lazy iterator to avoid data movement. Supporting it would help both the fallback and non-fallback cases.

@mcabbott
Copy link
Member

mcabbott commented Jan 9, 2021

Maybe this was said, but the one downside of simple fallbacks is type-stability. I guess cat isn't type-stable anyway but hcat could be; reshape I don't know if there's an easy solution, you could have a different function reshape1 which does insist on making a onehotarray? Not sure how much this matters.

@darsnack
Copy link
Member Author

darsnack commented Jan 9, 2021

Yeah, hcat can easily be made type stable in this PR. I guess it depends where these operations show up in practice. The reason for pushing for OneHotArrays as much as possible is because they were in #1447. If that's on a critical path then forcing collect might be a better option. We can make the error message more informative.

@simeonschaub
Copy link
Member

reshape is not part of the abstract array interface (nor is vcat) so we don't necessarily have to support all kinds of reshapes, especially if not semantically meaningful.

This part in the manual just describes a minimal number of functions new subtypes of AbstractArray should implement so generic functions for AbstractArrays work like they are supposed to. Interfaces in Julia currently aren't very formalized, but I definitely believe it's bad style to break (although maybe only implicitly assumed by everyone) promises of generic function. I think most users would assume reshape(a, dims...) to mean "lazily give me the underlying column-major data of a, but indexed as if it had size dims". If a has special properties, those might not necessarily be preserved under reshape, for example, we allow reshape(::Transpose, ...) to work, even though the result might not be representable as a Transpose object. For generic code, it is often critical, that these promises still hold, since at the end of the day subtypes of AbstractArray should just be indexable containers – any additional properties are secondary.

I really wonder what was @simeonschaub's specific use case for reshaping the first dimension, I couldn't guess

In my case, I added a singleton dimension in front, so I could broadcast with another array. This is something that's definitely not rare in generic code and it is quite annoying to work around if it doesn't work in all cases.

@CarloLucibello
Copy link
Member

@simeonschaub we lose type stability for reshape though, is allowing to reshape the first dimension worth it?

@simeonschaub
Copy link
Member

Not necessarily, see my first comment, where I proposed not overloading reshape at all, but widening some fast paths to accept ReshapedArray as well.

@DhairyaLGandhi
Copy link
Member

I think that's the better approach

@CarloLucibello
Copy link
Member

Not necessarily, see my first comment, where I proposed not overloading reshape at all, but widening some fast paths to accept ReshapedArray as well.

Are you suggesting to always wrap in a ReshapedArray when reshaping? Array wrappers are always a bit annoying to work with, and they sometimes interact badly with the cuda infrastructure AFAIK. I'm not super familiar with the topic, but there has to be some reason why reshaping an Array returns an Array, while ReshapedArray is used as a generic fallback for AbstractArrays

@mcabbott
Copy link
Member

widening some fast paths to accept ReshapedArray

What fast paths are there, BTW? Just * and argmax, or have I missed some?

I wondered whether onehot .* matrix ought to specialise, and onehot .* fun.(matrix) could too, but not sure how much you gain especially once Zygote is involved.

@DhairyaLGandhi
Copy link
Member

bumping @darsnack to widen the types and get rid of the reshape overload.

@darsnack
Copy link
Member Author

I am assuming you are referring to the ReshapedArray implementation. Planning on working on that today after I address the outstanding Metalhead.jl PR which is blocking other work.

@darsnack
Copy link
Member Author

Sorry for the delay. I swapped overloading reshape to allowing ReshapedArray{T, L, <:OneHotArray} to hit the fast paths too. This works by widening the types of the fast path functions and using the reshaped x.indices for reshaped one hot arrays. To prevent reshaped one hot arrays that are no longer one hot (i.e. first dim is reshaped) from hitting the fast paths, the array is converted to Array{Bool}/CuArray{Bool} before reshaping. It will still be a wrapped ReshapedArray, just not treated as one hot.

@darsnack
Copy link
Member Author

Bump

src/onehot.jl Outdated Show resolved Hide resolved
Copy link
Member

@simeonschaub simeonschaub left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just some small suggestions, sorry for holding this up. Otherwise looks good!

src/onehot.jl Outdated Show resolved Hide resolved
src/onehot.jl Show resolved Hide resolved
src/onehot.jl Show resolved Hide resolved
@darsnack
Copy link
Member Author

Thanks, think I need an approval from @CarloLucibello or @DhairyaLGandhi

Copy link
Member

@DhairyaLGandhi DhairyaLGandhi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall I find the implementation to have more "helper" methods than I'd care for, seems like we ought to make oharray code more generic, but I guess that's going to have to wait for a different day

src/onehot.jl Show resolved Hide resolved
# use this type so reshaped arrays hit fast paths
# e.g. argmax
const OneHotLike{T, L, N, var"N+1", I} =
Union{OneHotArray{T, L, N, var"N+1", I},
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Man this N+1 is tripping me up, I would say we need to remove this soon. Where is it used exactly?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think we could calculate var"N+1" during runtime?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like it either! It can't be done at runtime since N and var"N+1" are used in the type specification. N is used to specify the type of the index array, and var"N+1" is used to inherit from AbstractArray{Bool, var"N+1"}. Neither is evaluated at runtime.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could change it to another variable. I don't have strong feelings, but a part of me says that at least this naming signals the intent of the type parameter.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be fair, I did mean we would have to switch it out during construction, because I don't think it's any better for dispatch to have to do checks on ints than types. To me it suggests that it is a preknown quantity so adding it to the type doesn't win us much.

src/onehot.jl Show resolved Hide resolved
@darsnack
Copy link
Member Author

Ah apparently I don't have merge rights. So you will still need to merge this.

@DhairyaLGandhi
Copy link
Member

bors r+

bors bot added a commit that referenced this pull request Jan 29, 2021
1459: Use fallback for reshape/cat OneHotArray r=DhairyaLGandhi a=darsnack

This falls back to reshaping a `Bool` array whenever reshaping the first dimension of a `OneHotArray`.

@DhairyaLGandhi @CarloLucibello @simeonschaub 

### PR Checklist

- [x] Tests are added
- [ ] ~~Entry in NEWS.md~~
- [x] Documentation, if applicable


Co-authored-by: Kyle Daruwalla <[email protected]>
Co-authored-by: Kyle Daruwalla <[email protected]>
@bors
Copy link
Contributor

bors bot commented Jan 29, 2021

This PR was included in a batch that successfully built, but then failed to merge into master (it was a non-fast-forward update). It will be automatically retried.

@DhairyaLGandhi
Copy link
Member

Hmm..

Let try again

bors r+

@bors
Copy link
Contributor

bors bot commented Jan 29, 2021

Already running a review

@bors
Copy link
Contributor

bors bot commented Jan 29, 2021

Build succeeded:

@bors bors bot merged commit cf0042c into FluxML:master Jan 29, 2021
@darsnack darsnack deleted the darsnack/onehot-reshape-fix branch June 15, 2021 19:04
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants