From 23757a27d42021e6b0a4c9ec8fafbbddfebad195 Mon Sep 17 00:00:00 2001 From: tsai Date: Wed, 6 Dec 2023 15:38:43 +0800 Subject: [PATCH 1/9] add more constraints --- dev-requirements.txt | 2 +- lib/beaver/slang.ex | 17 ++++++++++++++++- test/elixir_ast_dialect_test.exs | 1 + test/support/elixir_ast_dialect.ex | 10 +++++----- 4 files changed, 23 insertions(+), 7 deletions(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index 6977ccc7d..4c311c6f3 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,2 +1,2 @@ -mlir==18.0.0.2023110801+c86d35a +mlir==18.0.0.2023120601+8f9aac44 --find-links https://github.com/makslevental/mlir-wheels/releases/expanded_assets/latest diff --git a/lib/beaver/slang.ex b/lib/beaver/slang.ex index 5c9c65005..1bc6f9de6 100644 --- a/lib/beaver/slang.ex +++ b/lib/beaver/slang.ex @@ -235,6 +235,13 @@ defmodule Beaver.Slang do for {v, i} <- Enum.with_index(args), do: transform_arg(v, i, :variable) end + @doc false + def gen_region(:any, block, ctx) do + mlir ctx: ctx, block: block do + IRDL.region() >>> ~t{!irdl.region} + end + end + # This function generates the AST for a creator function for an IRDL operation (like `irdl.operation`, `irdl.type`). It uses the transform_defop_pins/1 function to transform the pins, generates the MLIR code for the operation and its arguments, and applies the operation using op_applier/1. defp gen_creator(op, args_op, call, do_block, opts \\ []) do {name, args} = call |> Macro.decompose_call() @@ -264,6 +271,13 @@ defmodule Beaver.Slang do mlir block: opts[:block], ctx: opts[:ctx] do unquote_splicing(input_constrains) + + if unquote(opts[:regions_block]) do + unquote(opts[:regions_block]) + |> List.wrap() + |> Enum.map(&Beaver.Slang.gen_region(&1, opts[:block], opts[:ctx])) + end + {[unquote_splicing(args_var_ast)], unquote(do_block)} end end, @@ -341,7 +355,8 @@ defmodule Beaver.Slang do defmacro defop(call, block \\ nil) do gen_creator(:operation, :operands, call, block[:do], return_op: :results, - need_variadicity: true + need_variadicity: true, + regions_block: block[:regions] ) end diff --git a/test/elixir_ast_dialect_test.exs b/test/elixir_ast_dialect_test.exs index 31640f5b3..dcf5287c5 100644 --- a/test/elixir_ast_dialect_test.exs +++ b/test/elixir_ast_dialect_test.exs @@ -27,5 +27,6 @@ defmodule ELXDialectTest do end |> ElixirAST.from_ast(ctx: test_context[:ctx]) |> MLIR.dump!() + |> MLIR.Operation.verify!() end end diff --git a/test/support/elixir_ast_dialect.ex b/test/support/elixir_ast_dialect.ex index b5a6a9878..a72dd77d8 100644 --- a/test/support/elixir_ast_dialect.ex +++ b/test/support/elixir_ast_dialect.ex @@ -5,13 +5,13 @@ defmodule ElixirAST do deftype dyn deftype bound deftype unbound - defop mod(), do: [] - defop func(), do: [] + defop mod(), do: [], regions: [:any] + defop func(), do: [], regions: [:any] defop lit_int(), do: [Type.i64()] - defop var(), do: [] - defop bind(), do: [] + defop var(), do: [any()] + defop bind(any(), any()), do: [any()] defop call(), do: [] - defop add(), do: [] + defop add(any(), any()), do: [any()] alias Beaver.MLIR.Type alias Beaver.MLIR.Attribute use Beaver From b734492f06321b43bdfb2dc308b66e9d2fb7b0da Mon Sep 17 00:00:00 2001 From: tsai Date: Fri, 8 Dec 2023 10:57:12 +0800 Subject: [PATCH 2/9] Update elixir_ast_dialect.ex Update elixir_ast_dialect.ex --- test/support/elixir_ast_dialect.ex | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/test/support/elixir_ast_dialect.ex b/test/support/elixir_ast_dialect.ex index a72dd77d8..3cb2b2665 100644 --- a/test/support/elixir_ast_dialect.ex +++ b/test/support/elixir_ast_dialect.ex @@ -10,7 +10,7 @@ defmodule ElixirAST do defop lit_int(), do: [Type.i64()] defop var(), do: [any()] defop bind(any(), any()), do: [any()] - defop call(), do: [] + defop call(), do: [any()] defop add(any(), any()), do: [any()] alias Beaver.MLIR.Type alias Beaver.MLIR.Attribute @@ -28,14 +28,12 @@ defmodule ElixirAST do block ) do mlir ctx: ctx, block: block do - __MODULE__.mod name: "\"#{name}\"" do - region do - block functions do - Macro.prewalk(do_body, &gen_mlir(&1, ctx, Beaver.Env.block())) - end + m = + module do + Macro.prewalk(do_body, &gen_mlir(&1, ctx, Beaver.Env.block())) end - end >>> - [] + + MLIR.CAPI.mlirBlockAppendOwnedOperation(block, MLIR.Operation.from_module(m)) end end @@ -51,10 +49,21 @@ defmodule ElixirAST do block ) do mlir ctx: ctx, block: block do - __MODULE__.func name: "\"#{name}\"" do + alias Beaver.MLIR.Dialect.Func + + Func.func sym_name: "\"#{name}\"", function_type: Type.function([], []) do region do block func_body do - Macro.prewalk(do_body, &gen_mlir(&1, ctx, Beaver.Env.block())) + ret_value = + case Macro.prewalk(do_body, &gen_mlir(&1, ctx, Beaver.Env.block())) do + {:__block__, [], values} -> + values |> List.last() + + %Beaver.MLIR.Value{} = v -> + v + end + + Func.return() >>> [] end end end >>> From b627343ce4bf88b8be41db3165101fd6c39b0cc5 Mon Sep 17 00:00:00 2001 From: tsai Date: Fri, 8 Dec 2023 11:09:12 +0800 Subject: [PATCH 3/9] use builtin mod and func --- test/support/elixir_ast_dialect.ex | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/test/support/elixir_ast_dialect.ex b/test/support/elixir_ast_dialect.ex index 3cb2b2665..3b272a09a 100644 --- a/test/support/elixir_ast_dialect.ex +++ b/test/support/elixir_ast_dialect.ex @@ -5,8 +5,6 @@ defmodule ElixirAST do deftype dyn deftype bound deftype unbound - defop mod(), do: [], regions: [:any] - defop func(), do: [], regions: [:any] defop lit_int(), do: [Type.i64()] defop var(), do: [any()] defop bind(any(), any()), do: [any()] @@ -29,7 +27,7 @@ defmodule ElixirAST do ) do mlir ctx: ctx, block: block do m = - module do + module sym_name: ~a{"#{name}"} do Macro.prewalk(do_body, &gen_mlir(&1, ctx, Beaver.Env.block())) end @@ -51,10 +49,11 @@ defmodule ElixirAST do mlir ctx: ctx, block: block do alias Beaver.MLIR.Dialect.Func - Func.func sym_name: "\"#{name}\"", function_type: Type.function([], []) do + Func.func sym_name: "\"#{name}\"", function_type: Type.function([], [dyn()]) do region do block func_body do - ret_value = + %MLIR.Value{} = + ret_value = case Macro.prewalk(do_body, &gen_mlir(&1, ctx, Beaver.Env.block())) do {:__block__, [], values} -> values |> List.last() @@ -63,7 +62,7 @@ defmodule ElixirAST do v end - Func.return() >>> [] + Func.return(ret_value) >>> [] end end end >>> From f62cd353c14929df399f9f49f2733bca0d48c67c Mon Sep 17 00:00:00 2001 From: tsai Date: Wed, 13 Dec 2023 14:47:07 +0800 Subject: [PATCH 4/9] support sized region --- lib/beaver/slang.ex | 7 +++++++ test/irdl_test.exs | 6 +++++- test/support/region_dialect.ex | 6 ++++++ 3 files changed, 18 insertions(+), 1 deletion(-) create mode 100644 test/support/region_dialect.ex diff --git a/lib/beaver/slang.ex b/lib/beaver/slang.ex index 1bc6f9de6..7579353c8 100644 --- a/lib/beaver/slang.ex +++ b/lib/beaver/slang.ex @@ -242,6 +242,13 @@ defmodule Beaver.Slang do end end + def gen_region({:sized, size}, block, ctx) when is_integer(size) do + mlir ctx: ctx, block: block do + IRDL.region(numberOfBlocks: MLIR.Attribute.integer(MLIR.Type.i32(), size)) >>> + ~t{!irdl.region} + end + end + # This function generates the AST for a creator function for an IRDL operation (like `irdl.operation`, `irdl.type`). It uses the transform_defop_pins/1 function to transform the pins, generates the MLIR code for the operation and its arguments, and applies the operation using op_applier/1. defp gen_creator(op, args_op, call, do_block, opts \\ []) do {name, args} = call |> Macro.decompose_call() diff --git a/test/irdl_test.exs b/test/irdl_test.exs index 0008f12e8..8eeb1ad20 100644 --- a/test/irdl_test.exs +++ b/test/irdl_test.exs @@ -67,7 +67,11 @@ defmodule IRDLTest do test "var dialect", test_context do - use Beaver TestVariadic.__slang_dialect__(test_context[:ctx]) |> MLIR.Operation.verify!() end + + test "region dialect", + test_context do + TestRegion.__slang_dialect__(test_context[:ctx]) |> MLIR.Operation.verify!() + end end diff --git a/test/support/region_dialect.ex b/test/support/region_dialect.ex new file mode 100644 index 000000000..faeaaca12 --- /dev/null +++ b/test/support/region_dialect.ex @@ -0,0 +1,6 @@ +defmodule TestRegion do + @moduledoc "An example to showcase the a region dialect in IRDL test." + use Beaver.Slang, name: "test_region" + defop any_region(i = {:single, Type.i32()}), do: [], regions: [:any] + defop sized_region(), regions: [{:sized, 11}] +end From 78159b87ef77b6e46013b671f7ee33c874a14dba Mon Sep 17 00:00:00 2001 From: tsai Date: Wed, 13 Dec 2023 14:52:30 +0800 Subject: [PATCH 5/9] add more cases --- test/support/region_dialect.ex | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/support/region_dialect.ex b/test/support/region_dialect.ex index faeaaca12..282fbd979 100644 --- a/test/support/region_dialect.ex +++ b/test/support/region_dialect.ex @@ -2,5 +2,6 @@ defmodule TestRegion do @moduledoc "An example to showcase the a region dialect in IRDL test." use Beaver.Slang, name: "test_region" defop any_region(i = {:single, Type.i32()}), do: [], regions: [:any] - defop sized_region(), regions: [{:sized, 11}] + defop any_region2(i = {:single, Type.i32()}), do: [], regions: [:any, :any] + defop sized_region(), regions: {:sized, 11} end From 622e5e884ea6ba2be9de08002bf3a9ef8d72de84 Mon Sep 17 00:00:00 2001 From: tsai Date: Wed, 13 Dec 2023 15:19:49 +0800 Subject: [PATCH 6/9] minor refactor --- lib/beaver/slang.ex | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/lib/beaver/slang.ex b/lib/beaver/slang.ex index 7579353c8..a12888b6e 100644 --- a/lib/beaver/slang.ex +++ b/lib/beaver/slang.ex @@ -265,6 +265,15 @@ defmodule Beaver.Slang do transform_arg(ast, i, :constrain) end) + region_constrains = + if opts[:regions_block] do + quote do + unquote(opts[:regions_block]) + |> List.wrap() + |> Enum.map(&Beaver.Slang.gen_region(&1, opts[:block], opts[:ctx])) + end + end + quote do Module.put_attribute(__MODULE__, unquote(attr_name), unquote(name)) @__slang__creator__ {unquote(op), __MODULE__, unquote(creator)} @@ -278,13 +287,7 @@ defmodule Beaver.Slang do mlir block: opts[:block], ctx: opts[:ctx] do unquote_splicing(input_constrains) - - if unquote(opts[:regions_block]) do - unquote(opts[:regions_block]) - |> List.wrap() - |> Enum.map(&Beaver.Slang.gen_region(&1, opts[:block], opts[:ctx])) - end - + unquote(region_constrains) {[unquote_splicing(args_var_ast)], unquote(do_block)} end end, From 082a31b124da434855ac526f3d1178e915efb510 Mon Sep 17 00:00:00 2001 From: tsai Date: Wed, 13 Dec 2023 15:27:47 +0800 Subject: [PATCH 7/9] Revert "minor refactor" This reverts commit 622e5e884ea6ba2be9de08002bf3a9ef8d72de84. --- lib/beaver/slang.ex | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/lib/beaver/slang.ex b/lib/beaver/slang.ex index a12888b6e..7579353c8 100644 --- a/lib/beaver/slang.ex +++ b/lib/beaver/slang.ex @@ -265,15 +265,6 @@ defmodule Beaver.Slang do transform_arg(ast, i, :constrain) end) - region_constrains = - if opts[:regions_block] do - quote do - unquote(opts[:regions_block]) - |> List.wrap() - |> Enum.map(&Beaver.Slang.gen_region(&1, opts[:block], opts[:ctx])) - end - end - quote do Module.put_attribute(__MODULE__, unquote(attr_name), unquote(name)) @__slang__creator__ {unquote(op), __MODULE__, unquote(creator)} @@ -287,7 +278,13 @@ defmodule Beaver.Slang do mlir block: opts[:block], ctx: opts[:ctx] do unquote_splicing(input_constrains) - unquote(region_constrains) + + if unquote(opts[:regions_block]) do + unquote(opts[:regions_block]) + |> List.wrap() + |> Enum.map(&Beaver.Slang.gen_region(&1, opts[:block], opts[:ctx])) + end + {[unquote_splicing(args_var_ast)], unquote(do_block)} end end, From 167af880eb8b5dead32d4e7b6a9cd330120984cf Mon Sep 17 00:00:00 2001 From: tsai Date: Wed, 13 Dec 2023 15:32:23 +0800 Subject: [PATCH 8/9] Update slang.ex --- lib/beaver/slang.ex | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/beaver/slang.ex b/lib/beaver/slang.ex index 7579353c8..82528f1a5 100644 --- a/lib/beaver/slang.ex +++ b/lib/beaver/slang.ex @@ -148,7 +148,7 @@ defmodule Beaver.Slang do op_applier slang_target_op: op, sym_name: "\"#{name}\"" do region do block b_op() do - {args, ret} = constrain_f.(block: Beaver.Env.block(), ctx: ctx) + {args, ret} = constrain_f.(Beaver.Env.block(), ctx) case strip_variadicity(args) do [] -> @@ -273,16 +273,16 @@ defmodule Beaver.Slang do unquote(name), unquote(op), unquote(args_op), - fn opts -> + fn block, ctx -> use Beaver - mlir block: opts[:block], ctx: opts[:ctx] do + mlir block: block, ctx: ctx do unquote_splicing(input_constrains) if unquote(opts[:regions_block]) do unquote(opts[:regions_block]) |> List.wrap() - |> Enum.map(&Beaver.Slang.gen_region(&1, opts[:block], opts[:ctx])) + |> Enum.map(&Beaver.Slang.gen_region(&1, block, ctx)) end {[unquote_splicing(args_var_ast)], unquote(do_block)} From 5ab3a9ce4b9dd7ddfdc85e397d2afe09f667d240 Mon Sep 17 00:00:00 2001 From: tsai Date: Wed, 13 Dec 2023 15:56:27 +0800 Subject: [PATCH 9/9] Update slang.ex --- lib/beaver/slang.ex | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/lib/beaver/slang.ex b/lib/beaver/slang.ex index 82528f1a5..425027832 100644 --- a/lib/beaver/slang.ex +++ b/lib/beaver/slang.ex @@ -235,20 +235,26 @@ defmodule Beaver.Slang do for {v, i} <- Enum.with_index(args), do: transform_arg(v, i, :variable) end - @doc false - def gen_region(:any, block, ctx) do + defp do_gen_region(:any, block, ctx) do mlir ctx: ctx, block: block do IRDL.region() >>> ~t{!irdl.region} end end - def gen_region({:sized, size}, block, ctx) when is_integer(size) do + defp do_gen_region({:sized, size}, block, ctx) when is_integer(size) do mlir ctx: ctx, block: block do IRDL.region(numberOfBlocks: MLIR.Attribute.integer(MLIR.Type.i32(), size)) >>> ~t{!irdl.region} end end + @doc false + def gen_region(regions, block, ctx) do + regions + |> List.wrap() + |> Enum.map(&do_gen_region(&1, block, ctx)) + end + # This function generates the AST for a creator function for an IRDL operation (like `irdl.operation`, `irdl.type`). It uses the transform_defop_pins/1 function to transform the pins, generates the MLIR code for the operation and its arguments, and applies the operation using op_applier/1. defp gen_creator(op, args_op, call, do_block, opts \\ []) do {name, args} = call |> Macro.decompose_call() @@ -279,11 +285,13 @@ defmodule Beaver.Slang do mlir block: block, ctx: ctx do unquote_splicing(input_constrains) - if unquote(opts[:regions_block]) do - unquote(opts[:regions_block]) - |> List.wrap() - |> Enum.map(&Beaver.Slang.gen_region(&1, block, ctx)) - end + unquote_splicing( + for {ast, {mod, func}} <- opts[:extra_constrains] || [] do + quote do + apply(unquote(mod), unquote(func), [unquote(ast), block, ctx]) + end + end + ) {[unquote_splicing(args_var_ast)], unquote(do_block)} end @@ -363,7 +371,9 @@ defmodule Beaver.Slang do gen_creator(:operation, :operands, call, block[:do], return_op: :results, need_variadicity: true, - regions_block: block[:regions] + extra_constrains: [ + {block[:regions], {Beaver.Slang, :gen_region}} + ] ) end