Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve IRDL region constraints #234

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 36 additions & 4 deletions lib/beaver/slang.ex
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ defmodule Beaver.Slang do
op_applier slang_target_op: op, sym_name: "\"#{name}\"" do
region do
block b_op() do
{args, ret} = constrain_f.(block: Beaver.Env.block(), ctx: ctx)
{args, ret} = constrain_f.(Beaver.Env.block(), ctx)

case strip_variadicity(args) do
[] ->
Expand Down Expand Up @@ -235,6 +235,26 @@ defmodule Beaver.Slang do
for {v, i} <- Enum.with_index(args), do: transform_arg(v, i, :variable)
end

defp do_gen_region(:any, block, ctx) do
mlir ctx: ctx, block: block do
IRDL.region() >>> ~t{!irdl.region}
end
end

defp do_gen_region({:sized, size}, block, ctx) when is_integer(size) do
mlir ctx: ctx, block: block do
IRDL.region(numberOfBlocks: MLIR.Attribute.integer(MLIR.Type.i32(), size)) >>>
~t{!irdl.region}
end
end

@doc false
def gen_region(regions, block, ctx) do
regions
|> List.wrap()
|> Enum.map(&do_gen_region(&1, block, ctx))
end

# This function generates the AST for a creator function for an IRDL operation (like `irdl.operation`, `irdl.type`). It uses the transform_defop_pins/1 function to transform the pins, generates the MLIR code for the operation and its arguments, and applies the operation using op_applier/1.
defp gen_creator(op, args_op, call, do_block, opts \\ []) do
{name, args} = call |> Macro.decompose_call()
Expand All @@ -259,11 +279,20 @@ defmodule Beaver.Slang do
unquote(name),
unquote(op),
unquote(args_op),
fn opts ->
fn block, ctx ->
use Beaver

mlir block: opts[:block], ctx: opts[:ctx] do
mlir block: block, ctx: ctx do
unquote_splicing(input_constrains)

unquote_splicing(
for {ast, {mod, func}} <- opts[:extra_constrains] || [] do
quote do
apply(unquote(mod), unquote(func), [unquote(ast), block, ctx])
end
end
)

{[unquote_splicing(args_var_ast)], unquote(do_block)}
end
end,
Expand Down Expand Up @@ -341,7 +370,10 @@ defmodule Beaver.Slang do
defmacro defop(call, block \\ nil) do
gen_creator(:operation, :operands, call, block[:do],
return_op: :results,
need_variadicity: true
need_variadicity: true,
extra_constrains: [
{block[:regions], {Beaver.Slang, :gen_region}}
]
)
end

Expand Down
6 changes: 5 additions & 1 deletion test/irdl_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,11 @@ defmodule IRDLTest do

test "var dialect",
test_context do
use Beaver
TestVariadic.__slang_dialect__(test_context[:ctx]) |> MLIR.Operation.verify!()
end

test "region dialect",
test_context do
TestRegion.__slang_dialect__(test_context[:ctx]) |> MLIR.Operation.verify!()
end
end
7 changes: 7 additions & 0 deletions test/support/region_dialect.ex
Copy link
Contributor Author

@jackalcooper jackalcooper Dec 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dedicated regions block won't work well when we want to add a constrain like this:

  • entry block argument#1 has the same type of a result (because it can't access variable in the do block)

Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
defmodule TestRegion do
@moduledoc "An example to showcase the a region dialect in IRDL test."
use Beaver.Slang, name: "test_region"
defop any_region(i = {:single, Type.i32()}), do: [], regions: [:any]
defop any_region2(i = {:single, Type.i32()}), do: [], regions: [:any, :any]
defop sized_region(), regions: {:sized, 11}
end
Loading