Skip to content

Commit

Permalink
Extend TrivialTaskletElimination for map scope (#1650)
Browse files Browse the repository at this point in the history
Extend the transformation `TrivialTaskletElimination` for the case where
the input or output of the copy-tasklet is a map node.

The following SDFG:
<img width="266" alt="image"
src="https://github.com/user-attachments/assets/6e231bbf-d736-4dcf-b132-2e7d59c26ad5">

is transformed to this SDFG:
<img width="343" alt="image"
src="https://github.com/user-attachments/assets/82ec07b1-6b3d-421f-bca7-5c4b3bd1f320">
  • Loading branch information
edopao authored Oct 18, 2024
1 parent 4fbeba4 commit 975a065
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 17 deletions.
48 changes: 31 additions & 17 deletions dace/transformation/dataflow/trivial_tasklet_elimination.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,48 +17,62 @@ class TrivialTaskletElimination(transformation.SingleStateTransformation):
"""

read = transformation.PatternNode(nodes.AccessNode)
read_map = transformation.PatternNode(nodes.MapEntry)
tasklet = transformation.PatternNode(nodes.Tasklet)
write = transformation.PatternNode(nodes.AccessNode)
write_map = transformation.PatternNode(nodes.MapExit)

@classmethod
def expressions(cls):
return [sdutil.node_path_graph(cls.read, cls.tasklet, cls.write)]
return [
sdutil.node_path_graph(cls.read, cls.tasklet, cls.write),
sdutil.node_path_graph(cls.read_map, cls.tasklet, cls.write),
sdutil.node_path_graph(cls.read, cls.tasklet, cls.write_map),
]

def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
read = self.read
read = self.read_map if expr_index == 1 else self.read
tasklet = self.tasklet
write = self.write
# Do not apply on Streams
if isinstance(sdfg.arrays[read.data], data.Stream):
return False
if isinstance(sdfg.arrays[write.data], data.Stream):
write = self.write_map if expr_index == 2 else self.write
if len(tasklet.in_connectors) != 1:
return False
if len(graph.in_edges(tasklet)) != 1:
return False
if len(graph.out_edges(tasklet)) != 1:
return False
if graph.edges_between(tasklet, write)[0].data.wcr:
return False
if len(tasklet.in_connectors) != 1:
return False
if len(tasklet.out_connectors) != 1:
return False
if len(graph.out_edges(tasklet)) != 1:
return False
in_conn = list(tasklet.in_connectors.keys())[0]
out_conn = list(tasklet.out_connectors.keys())[0]
if tasklet.code.as_string != f'{out_conn} = {in_conn}':
return False

read_memlet = graph.edges_between(read, tasklet)[0].data
read_desc = sdfg.arrays[read_memlet.data]
write_memlet = graph.edges_between(tasklet, write)[0].data
if write_memlet.wcr:
return False
write_desc = sdfg.arrays[write_memlet.data]
# Do not apply on streams
if isinstance(read_desc, data.Stream):
return False
if isinstance(write_desc, data.Stream):
return False
# Keep copy-tasklet connected to map node if source and destination nodes
# have different data type (implicit type cast)
if expr_index != 0 and read_desc.dtype != write_desc.dtype:
return False

return True

def apply(self, graph, sdfg):
read = self.read
read = self.read_map if self.expr_index == 1 else self.read
tasklet = self.tasklet
write = self.write
write = self.write_map if self.expr_index == 2 else self.write

in_edge = graph.edges_between(read, tasklet)[0]
out_edge = graph.edges_between(tasklet, write)[0]
graph.remove_edge(in_edge)
graph.remove_edge(out_edge)
out_edge.data.other_subset = in_edge.data.subset
graph.add_nedge(read, write, out_edge.data)
graph.add_edge(read, in_edge.src_conn, write, out_edge.dst_conn, out_edge.data)
graph.remove_node(tasklet)
129 changes: 129 additions & 0 deletions tests/transformations/trivial_tasklet_elimination_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
import dace
from dace.transformation.dataflow.trivial_tasklet_elimination import TrivialTaskletElimination


N = 10


def test_trivial_tasklet():
ty_ = dace.int32
sdfg = dace.SDFG("trivial_tasklet")
sdfg.add_symbol("s", ty_)
sdfg.add_array("v", (N,), ty_)
st = sdfg.add_state()

tmp1_name, _ = sdfg.add_scalar(sdfg.temp_data_name(), ty_, transient=True)
tmp1_node = st.add_access(tmp1_name)
init_tasklet = st.add_tasklet("init", {}, {"out"}, "out = s")
st.add_edge(init_tasklet, "out", tmp1_node, None, dace.Memlet(tmp1_node.data))

tmp2_name, _ = sdfg.add_scalar(sdfg.temp_data_name(), ty_, transient=True)
tmp2_node = st.add_access(tmp2_name)
copy_tasklet = st.add_tasklet("copy", {"inp"}, {"out"}, "out = inp")
st.add_edge(tmp1_node, None, copy_tasklet, "inp", dace.Memlet(tmp1_node.data))
st.add_edge(copy_tasklet, "out", tmp2_node, None, dace.Memlet(tmp2_node.data))

bcast_tasklet, _, _ = st.add_mapped_tasklet(
"bcast",
dict(i=f"0:{N}"),
inputs={"inp": dace.Memlet(f"{tmp2_node.data}[0]")},
input_nodes={tmp2_node.data: tmp2_node},
code="out = inp",
outputs={"out": dace.Memlet("v[i]")},
external_edges=True,
)

sdfg.validate()
tasklet_nodes = {x for x in st.nodes() if isinstance(x, dace.nodes.Tasklet)}
assert tasklet_nodes == {init_tasklet, copy_tasklet, bcast_tasklet}

count = sdfg.apply_transformations_repeated(TrivialTaskletElimination)
assert count == 1

assert len(st.out_edges(tmp1_node)) == 1
assert st.out_edges(tmp1_node)[0].dst == tmp2_node

tasklet_nodes = {x for x in st.nodes() if isinstance(x, dace.nodes.Tasklet)}
assert tasklet_nodes == {init_tasklet, bcast_tasklet}


def test_trivial_tasklet_with_map():
ty_ = dace.int32
sdfg = dace.SDFG("trivial_tasklet_with_map")
sdfg.add_symbol("s", ty_)
sdfg.add_array("v", (N,), ty_)
st = sdfg.add_state()

tmp1_name, _ = sdfg.add_scalar(sdfg.temp_data_name(), ty_, transient=True)
tmp1_node = st.add_access(tmp1_name)
init_tasklet = st.add_tasklet("init", {}, {"out"}, "out = s")
st.add_edge(init_tasklet, "out", tmp1_node, None, dace.Memlet(tmp1_node.data))

me, mx = st.add_map("bcast", dict(i=f"0:{N}"))

copy_tasklet = st.add_tasklet("copy", {"inp"}, {"out"}, "out = inp")
st.add_memlet_path(tmp1_node, me, copy_tasklet, dst_conn="inp", memlet=dace.Memlet(f"{tmp1_node.data}[0]"))
tmp2_name, _ = sdfg.add_scalar(sdfg.temp_data_name(), ty_, transient=True)
tmp2_node = st.add_access(tmp2_name)
st.add_edge(copy_tasklet, "out", tmp2_node, None, dace.Memlet(tmp2_node.data))

bcast_tasklet = st.add_tasklet("bcast", {"inp"}, {"out"}, "out = inp")
st.add_edge(tmp2_node, None, bcast_tasklet, "inp", dace.Memlet(tmp2_node.data))
st.add_memlet_path(bcast_tasklet, mx, st.add_access("v"), src_conn="out", memlet=dace.Memlet("v[i]"))

sdfg.validate()
tasklet_nodes = {x for x in st.nodes() if isinstance(x, dace.nodes.Tasklet)}
assert tasklet_nodes == {init_tasklet, copy_tasklet, bcast_tasklet}

count = sdfg.apply_transformations_repeated(TrivialTaskletElimination)
assert count == 2

tasklet_nodes = {x for x in st.nodes() if isinstance(x, dace.nodes.Tasklet)}
assert tasklet_nodes == {init_tasklet}

assert len(st.in_edges(tmp2_node)) == 1
assert st.in_edges(tmp2_node)[0].src == me

assert len(st.out_edges(tmp2_node)) == 1
assert st.out_edges(tmp2_node)[0].dst == mx


def test_trivial_tasklet_with_implicit_cast():
ty32_ = dace.int32
ty64_ = dace.int64
sdfg = dace.SDFG("trivial_tasklet_with_implicit_cast")
sdfg.add_symbol("s", ty32_)
sdfg.add_array("v", (N,), ty32_)
st = sdfg.add_state()

tmp1_name, _ = sdfg.add_scalar(sdfg.temp_data_name(), ty32_, transient=True)
tmp1_node = st.add_access(tmp1_name)
init_tasklet = st.add_tasklet("init", {}, {"out"}, "out = s")
st.add_edge(init_tasklet, "out", tmp1_node, None, dace.Memlet(tmp1_node.data))

me, mx = st.add_map("bcast", dict(i=f"0:{N}"))

copy_tasklet = st.add_tasklet("copy", {"inp"}, {"out"}, "out = inp")
st.add_memlet_path(tmp1_node, me, copy_tasklet, dst_conn="inp", memlet=dace.Memlet(f"{tmp1_node.data}[0]"))
tmp2_name, _ = sdfg.add_scalar(sdfg.temp_data_name(), ty64_, transient=True)
tmp2_node = st.add_access(tmp2_name)
st.add_edge(copy_tasklet, "out", tmp2_node, None, dace.Memlet(tmp2_node.data))

bcast_tasklet = st.add_tasklet("bcast", {"inp"}, {"out"}, "out = inp")
st.add_edge(tmp2_node, None, bcast_tasklet, "inp", dace.Memlet(tmp2_node.data))
st.add_memlet_path(bcast_tasklet, mx, st.add_access("v"), src_conn="out", memlet=dace.Memlet("v[i]"))

sdfg.validate()
tasklet_nodes = {x for x in st.nodes() if isinstance(x, dace.nodes.Tasklet)}
assert tasklet_nodes == {init_tasklet, copy_tasklet, bcast_tasklet}

# not applied because of data types mismatch on read/write nodes
count = sdfg.apply_transformations_repeated(TrivialTaskletElimination)
assert count == 0


if __name__ == '__main__':
test_trivial_tasklet()
test_trivial_tasklet_with_map()
test_trivial_tasklet_with_implicit_cast()

0 comments on commit 975a065

Please sign in to comment.