diff --git a/src/enumeration.jl b/src/enumeration.jl index c220a3ff4..a725b518f 100644 --- a/src/enumeration.jl +++ b/src/enumeration.jl @@ -51,7 +51,7 @@ all_subtypes(ts::Vector, scfg :: SearchCfg, result :: Channel) = begin # Skip and push a marker describing the case to the caller. # # - skip unionalls due to search config - unionalls = filter(t -> typeof(t) == UnionAll, tv) + unionalls = filter(t -> t isa UnionAll, tv) if scfg.skip_unionalls && !isempty(unionalls) put!(result, SkipMandatory(Tuple(unionalls))) continue @@ -109,10 +109,12 @@ generate_subtypes(ts1::Vector, scfg :: SearchCfg) = begin end @debug "generate_subtypes of head: $(ss_first)" - # no subtypes may mean it's a UnionAll requiring special handling + # no subtypes may mean it's a UnionAll or a Union requiring special handling if isempty(ss_first) - if typeof(t) == UnionAll + if t isa UnionAll ss_first = subtype_unionall(t, scfg) + elseif t isa Union + ss_first = subtype_union(t) end end @@ -190,3 +192,21 @@ subtype_unionall(u :: UnionAll, scfg :: SearchCfg) = begin end res end + +# +# subtype_union: Union -> [JlType] +# +# This flattens the nested union to an array of its types since the `subtypes` builtin +# function only returns declared subtypes. However, for a Union, we want to process +# each of the contained types as its subtypes. +# +subtype_union(t::Union) = begin + @debug "subtype_union of $t" + res = [] + while t isa Union + push!(res, t.a) + t = t.b + end + push!(res, t) + res +end diff --git a/test/runtests.jl b/test/runtests.jl index e5be94e9d..0559b64e5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -74,6 +74,9 @@ sum_top(v, t) = begin res end +stable_funion(x::Bool, y::Union{TwoSubtypes, Bool}) = !x +unstable_funion(x::Bool, y::Union{TwoSubtypes, Bool}) = if x 1 else "" end + # Note: generic methods # We don't handle generic methods yet (#9) # rational_plusi(a::Rational{T}, b::Rational{T}) where T <: Integer = a + b @@ -126,6 +129,17 @@ end @test isa(is_stable_method((@which add1r(1.0 + 1.0im)), SearchCfg(abstract_args=true)) , Stb) end +@testset "Unions " begin + res = is_stable_method(@which stable_funion(true, true)) + @test res isa Stb + res = is_stable_method(@which unstable_funion(true, true)) + @test res isa Uns && + length(res.fails) == 3 && + [Bool, SubtypeA] in res.fails && + [Bool, SubtypeB] in res.fails && + [Bool, Bool] in res.fails +end + @testset "Special (Any, Varargs, Generic)" begin f(x)=1 @test is_stable_method(@which f(2)) == AnyParam(Any[Any])