diff --git a/lib/beaver/slang.ex b/lib/beaver/slang.ex index af63dceeb..e2aa0fbf9 100644 --- a/lib/beaver/slang.ex +++ b/lib/beaver/slang.ex @@ -148,7 +148,7 @@ defmodule Beaver.Slang do op_applier slang_target_op: op, sym_name: "\"#{name}\"" do region do block b_op() do - {args, ret} = constrain_f.(block: Beaver.Env.block(), ctx: ctx) + {args, ret} = constrain_f.(Beaver.Env.block(), ctx) case strip_variadicity(args) do [] -> @@ -235,6 +235,26 @@ defmodule Beaver.Slang do for {v, i} <- Enum.with_index(args), do: transform_arg(v, i, :variable) end + defp do_gen_region(:any, block, ctx) do + mlir ctx: ctx, block: block do + IRDL.region() >>> ~t{!irdl.region} + end + end + + defp do_gen_region({:sized, size}, block, ctx) when is_integer(size) do + mlir ctx: ctx, block: block do + IRDL.region(numberOfBlocks: MLIR.Attribute.integer(MLIR.Type.i32(), size)) >>> + ~t{!irdl.region} + end + end + + @doc false + def gen_region(regions, block, ctx) do + regions + |> List.wrap() + |> Enum.map(&do_gen_region(&1, block, ctx)) + end + # This function generates the AST for a creator function for an IRDL operation (like `irdl.operation`, `irdl.type`). It uses the transform_defop_pins/1 function to transform the pins, generates the MLIR code for the operation and its arguments, and applies the operation using op_applier/1. defp gen_creator(op, args_op, call, do_block, opts \\ []) do {name, args} = call |> Macro.decompose_call() @@ -259,11 +279,20 @@ defmodule Beaver.Slang do unquote(name), unquote(op), unquote(args_op), - fn opts -> + fn block, ctx -> use Beaver - mlir block: opts[:block], ctx: opts[:ctx] do + mlir block: block, ctx: ctx do unquote_splicing(input_constrains) + + unquote_splicing( + for {ast, {mod, func}} <- opts[:extra_constrains] || [] do + quote do + apply(unquote(mod), unquote(func), [unquote(ast), block, ctx]) + end + end + ) + {[unquote_splicing(args_var_ast)], unquote(do_block)} end end, @@ -341,7 +370,10 @@ defmodule Beaver.Slang do defmacro defop(call, block \\ nil) do gen_creator(:operation, :operands, call, block[:do], return_op: :results, - need_variadicity: true + need_variadicity: true, + extra_constrains: [ + {block[:regions], {Beaver.Slang, :gen_region}} + ] ) end diff --git a/test/irdl_test.exs b/test/irdl_test.exs index 0008f12e8..8eeb1ad20 100644 --- a/test/irdl_test.exs +++ b/test/irdl_test.exs @@ -67,7 +67,11 @@ defmodule IRDLTest do test "var dialect", test_context do - use Beaver TestVariadic.__slang_dialect__(test_context[:ctx]) |> MLIR.Operation.verify!() end + + test "region dialect", + test_context do + TestRegion.__slang_dialect__(test_context[:ctx]) |> MLIR.Operation.verify!() + end end diff --git a/test/support/region_dialect.ex b/test/support/region_dialect.ex new file mode 100644 index 000000000..282fbd979 --- /dev/null +++ b/test/support/region_dialect.ex @@ -0,0 +1,7 @@ +defmodule TestRegion do + @moduledoc "An example to showcase the a region dialect in IRDL test." + use Beaver.Slang, name: "test_region" + defop any_region(i = {:single, Type.i32()}), do: [], regions: [:any] + defop any_region2(i = {:single, Type.i32()}), do: [], regions: [:any, :any] + defop sized_region(), regions: {:sized, 11} +end