diff --git a/lib/elixir/lib/module/types/descr.ex b/lib/elixir/lib/module/types/descr.ex index adc17fa9e32..81a5ee79909 100644 --- a/lib/elixir/lib/module/types/descr.ex +++ b/lib/elixir/lib/module/types/descr.ex @@ -1278,8 +1278,130 @@ defmodule Module.Types.Descr do defp map_only?(descr), do: empty?(Map.delete(descr, :map)) - # Union is list concatenation - defp map_union(dnf1, dnf2), do: dnf1 ++ (dnf2 -- dnf1) + defp map_union(dnf1, dnf2) do + # Union is just concatenation, but we rely on some optimization strategies to + # avoid the list to grow when possible + + # first pass trying to identify patterns where two maps can be fused as one + with [{tag1, pos1, []}] <- dnf1, + [{tag2, pos2, []}] <- dnf2, + strategy when strategy != nil <- map_union_optimization_strategy(tag1, pos1, tag2, pos2) do + case strategy do + :all_equal -> + dnf1 + + :any_map -> + [{:open, %{}, []}] + + {:one_key_difference, key, v1, v2} -> + new_pos = Map.put(pos1, key, union(v1, v2)) + [{tag1, new_pos, []}] + + :left_subtype_of_right -> + dnf2 + + :right_subtype_of_left -> + dnf1 + end + else + # otherwise we just concatenate and remove structural duplicates + _ -> dnf1 ++ (dnf2 -- dnf1) + end + end + + defp map_union_optimization_strategy(tag1, pos1, tag2, pos2) + defp map_union_optimization_strategy(tag, pos, tag, pos), do: :all_equal + defp map_union_optimization_strategy(:open, empty, _, _) when empty == %{}, do: :any_map + defp map_union_optimization_strategy(_, _, :open, empty) when empty == %{}, do: :any_map + + defp map_union_optimization_strategy(tag, pos1, tag, pos2) + when map_size(pos1) == map_size(pos2) do + :maps.iterator(pos1) + |> :maps.next() + |> do_map_union_optimization_strategy(pos2, :all_equal) + end + + defp map_union_optimization_strategy(:open, pos1, _, pos2) + when map_size(pos1) <= map_size(pos2) do + :maps.iterator(pos1) + |> :maps.next() + |> do_map_union_optimization_strategy(pos2, :right_subtype_of_left) + end + + defp map_union_optimization_strategy(_, pos1, :open, pos2) + when map_size(pos1) >= map_size(pos2) do + :maps.iterator(pos2) + |> :maps.next() + |> do_map_union_optimization_strategy(pos1, :right_subtype_of_left) + |> case do + :right_subtype_of_left -> :left_subtype_of_right + nil -> nil + end + end + + defp map_union_optimization_strategy(_, _, _, _), do: nil + + defp do_map_union_optimization_strategy(:none, _, status), do: status + + defp do_map_union_optimization_strategy({key, v1, iterator}, pos2, status) do + with %{^key => v2} <- pos2, + next_status when next_status != nil <- map_union_next_strategy(key, v1, v2, status) do + do_map_union_optimization_strategy(:maps.next(iterator), pos2, next_status) + else + _ -> nil + end + end + + defp map_union_next_strategy(key, v1, v2, status) + + # structurally equal values do not impact the ongoing strategy + defp map_union_next_strategy(_key, same, same, status), do: status + + defp map_union_next_strategy(key, v1, v2, :all_equal) do + if key != :__struct__, do: {:one_key_difference, key, v1, v2} + end + + defp map_union_next_strategy(_key, v1, v2, {:one_key_difference, _, d1, d2}) do + # we have at least two key differences now, we switch strategy + # if both are subtypes in one direction, keep checking + cond do + trivial_subtype?(d1, d2) and trivial_subtype?(v1, v2) -> :left_subtype_of_right + trivial_subtype?(d2, d1) and trivial_subtype?(v2, v1) -> :right_subtype_of_left + true -> nil + end + end + + defp map_union_next_strategy(_key, v1, v2, :left_subtype_of_right) do + if trivial_subtype?(v1, v2), do: :left_subtype_of_right + end + + defp map_union_next_strategy(_key, v1, v2, :right_subtype_of_left) do + if trivial_subtype?(v2, v1), do: :right_subtype_of_left + end + + # cheap to compute sub-typing + # a trivial subtype is always a subtype, but not all subtypes are subtypes + defp trivial_subtype?(_, :term), do: true + defp trivial_subtype?(same, same), do: true + + defp trivial_subtype?(%{} = left, %{} = right) + when map_size(left) == 1 and map_size(right) == 1 do + case {left, right} do + {%{atom: _}, %{atom: {:negation, neg}}} when neg == %{} -> + true + + {%{map: _}, %{map: [{:open, pos, []}]}} when pos == %{} -> + true + + {%{bitmap: bitmap1}, %{bitmap: bitmap2}} -> + (bitmap1 &&& bitmap2) === bitmap2 + + _ -> + false + end + end + + defp trivial_subtype?(_, _), do: false # Given two unions of maps, intersects each pair of maps. defp map_intersection(dnf1, dnf2) do diff --git a/lib/elixir/test/elixir/module/types/descr_test.exs b/lib/elixir/test/elixir/module/types/descr_test.exs index 2ffcb977869..69cd66f30d9 100644 --- a/lib/elixir/test/elixir/module/types/descr_test.exs +++ b/lib/elixir/test/elixir/module/types/descr_test.exs @@ -105,6 +105,47 @@ defmodule Module.Types.DescrTest do assert union(difference(list(term()), list(integer())), list(integer())) |> equal?(list(term())) end + + test "optimizations" do + # The tests are checking the actual implementation, not the semantics. + # This is why we are using structural comparisons. + # It's fine to remove these if the implementation changes, but breaking + # these might have an important impact on compile times. + + # Optimization one: same tags, all but one key are structurally equal + assert union( + open_map(a: float(), b: atom()), + open_map(a: integer(), b: atom()) + ) == open_map(a: union(float(), integer()), b: atom()) + + assert union( + closed_map(a: float(), b: atom()), + closed_map(a: integer(), b: atom()) + ) == closed_map(a: union(float(), integer()), b: atom()) + + # Optimization two: we can tell that one map is a trivial subtype of the other: + + assert union( + closed_map(a: term(), b: term()), + closed_map(a: float(), b: binary()) + ) == closed_map(a: term(), b: term()) + + assert union( + open_map(a: term()), + closed_map(a: float(), b: binary()) + ) == open_map(a: term()) + + assert union( + closed_map(a: float(), b: binary()), + open_map(a: term()) + ) == open_map(a: term()) + + # Do we want this want to pass or keep shallow checks only? + # assert union( + # closed_map(a: term(), b: tuple([term(), term()])), + # closed_map(a: float(), b: tuple([atom(), binary()])) + # ) == closed_map(a: term(), b: term()) + end end describe "intersection" do