Skip to content

Commit

Permalink
correct splatting of tail add first test for multiple dimensions
Browse files Browse the repository at this point in the history
Co-Authored-By: Dhairya Gandhi <[email protected]>
  • Loading branch information
2 people authored and hhaensel committed Jul 1, 2020
1 parent 39b14fb commit d492c0f
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ outdims(m, (10, 2)) == (5, 2)
"""
function outdims(l::Dense, isize)
first(isize) == size(l.W, 2) || throw(DimensionMismatch("input size should equal to ($(size(l.W, 2)), ...), got $isize"))
return (size(l.W, 1), Base.tail(isize))
return (size(l.W, 1), Base.tail(isize)...)
end

"""
Expand Down
1 change: 1 addition & 0 deletions test/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ import Flux: activations
m = Dense(10, 5)
@test_throws DimensionMismatch Flux.outdims(m, (5, 2)) == (5,)
@test Flux.outdims(m, (10,)) == (5,)
@test Flux.outdims(m, (10, 2)) == (5, 2)

m = Chain(Dense(10, 8, σ), Dense(8, 5), Dense(5, 2))
@test Flux.outdims(m, (10,)) == (2,)
Expand Down

0 comments on commit d492c0f

Please sign in to comment.