diff --git a/dace/transformation/dataflow/trivial_tasklet_elimination.py b/dace/transformation/dataflow/trivial_tasklet_elimination.py index b4c23524e2..6a84959f7d 100644 --- a/dace/transformation/dataflow/trivial_tasklet_elimination.py +++ b/dace/transformation/dataflow/trivial_tasklet_elimination.py @@ -17,48 +17,62 @@ class TrivialTaskletElimination(transformation.SingleStateTransformation): """ read = transformation.PatternNode(nodes.AccessNode) + read_map = transformation.PatternNode(nodes.MapEntry) tasklet = transformation.PatternNode(nodes.Tasklet) write = transformation.PatternNode(nodes.AccessNode) + write_map = transformation.PatternNode(nodes.MapExit) @classmethod def expressions(cls): - return [sdutil.node_path_graph(cls.read, cls.tasklet, cls.write)] + return [ + sdutil.node_path_graph(cls.read, cls.tasklet, cls.write), + sdutil.node_path_graph(cls.read_map, cls.tasklet, cls.write), + sdutil.node_path_graph(cls.read, cls.tasklet, cls.write_map), + ] def can_be_applied(self, graph, expr_index, sdfg, permissive=False): - read = self.read + read = self.read_map if expr_index == 1 else self.read tasklet = self.tasklet - write = self.write - # Do not apply on Streams - if isinstance(sdfg.arrays[read.data], data.Stream): - return False - if isinstance(sdfg.arrays[write.data], data.Stream): + write = self.write_map if expr_index == 2 else self.write + if len(tasklet.in_connectors) != 1: return False if len(graph.in_edges(tasklet)) != 1: return False - if len(graph.out_edges(tasklet)) != 1: - return False - if graph.edges_between(tasklet, write)[0].data.wcr: - return False - if len(tasklet.in_connectors) != 1: - return False if len(tasklet.out_connectors) != 1: return False + if len(graph.out_edges(tasklet)) != 1: + return False in_conn = list(tasklet.in_connectors.keys())[0] out_conn = list(tasklet.out_connectors.keys())[0] if tasklet.code.as_string != f'{out_conn} = {in_conn}': return False - + read_memlet = graph.edges_between(read, tasklet)[0].data + read_desc = sdfg.arrays[read_memlet.data] + write_memlet = graph.edges_between(tasklet, write)[0].data + if write_memlet.wcr: + return False + write_desc = sdfg.arrays[write_memlet.data] + # Do not apply on streams + if isinstance(read_desc, data.Stream): + return False + if isinstance(write_desc, data.Stream): + return False + # Keep copy-tasklet connected to map node if source and destination nodes + # have different data type (implicit type cast) + if expr_index != 0 and read_desc.dtype != write_desc.dtype: + return False + return True def apply(self, graph, sdfg): - read = self.read + read = self.read_map if self.expr_index == 1 else self.read tasklet = self.tasklet - write = self.write + write = self.write_map if self.expr_index == 2 else self.write in_edge = graph.edges_between(read, tasklet)[0] out_edge = graph.edges_between(tasklet, write)[0] graph.remove_edge(in_edge) graph.remove_edge(out_edge) out_edge.data.other_subset = in_edge.data.subset - graph.add_nedge(read, write, out_edge.data) + graph.add_edge(read, in_edge.src_conn, write, out_edge.dst_conn, out_edge.data) graph.remove_node(tasklet) diff --git a/tests/transformations/trivial_tasklet_elimination_test.py b/tests/transformations/trivial_tasklet_elimination_test.py new file mode 100644 index 0000000000..8f97b51b7e --- /dev/null +++ b/tests/transformations/trivial_tasklet_elimination_test.py @@ -0,0 +1,129 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +import dace +from dace.transformation.dataflow.trivial_tasklet_elimination import TrivialTaskletElimination + + +N = 10 + + +def test_trivial_tasklet(): + ty_ = dace.int32 + sdfg = dace.SDFG("trivial_tasklet") + sdfg.add_symbol("s", ty_) + sdfg.add_array("v", (N,), ty_) + st = sdfg.add_state() + + tmp1_name, _ = sdfg.add_scalar(sdfg.temp_data_name(), ty_, transient=True) + tmp1_node = st.add_access(tmp1_name) + init_tasklet = st.add_tasklet("init", {}, {"out"}, "out = s") + st.add_edge(init_tasklet, "out", tmp1_node, None, dace.Memlet(tmp1_node.data)) + + tmp2_name, _ = sdfg.add_scalar(sdfg.temp_data_name(), ty_, transient=True) + tmp2_node = st.add_access(tmp2_name) + copy_tasklet = st.add_tasklet("copy", {"inp"}, {"out"}, "out = inp") + st.add_edge(tmp1_node, None, copy_tasklet, "inp", dace.Memlet(tmp1_node.data)) + st.add_edge(copy_tasklet, "out", tmp2_node, None, dace.Memlet(tmp2_node.data)) + + bcast_tasklet, _, _ = st.add_mapped_tasklet( + "bcast", + dict(i=f"0:{N}"), + inputs={"inp": dace.Memlet(f"{tmp2_node.data}[0]")}, + input_nodes={tmp2_node.data: tmp2_node}, + code="out = inp", + outputs={"out": dace.Memlet("v[i]")}, + external_edges=True, + ) + + sdfg.validate() + tasklet_nodes = {x for x in st.nodes() if isinstance(x, dace.nodes.Tasklet)} + assert tasklet_nodes == {init_tasklet, copy_tasklet, bcast_tasklet} + + count = sdfg.apply_transformations_repeated(TrivialTaskletElimination) + assert count == 1 + + assert len(st.out_edges(tmp1_node)) == 1 + assert st.out_edges(tmp1_node)[0].dst == tmp2_node + + tasklet_nodes = {x for x in st.nodes() if isinstance(x, dace.nodes.Tasklet)} + assert tasklet_nodes == {init_tasklet, bcast_tasklet} + + +def test_trivial_tasklet_with_map(): + ty_ = dace.int32 + sdfg = dace.SDFG("trivial_tasklet_with_map") + sdfg.add_symbol("s", ty_) + sdfg.add_array("v", (N,), ty_) + st = sdfg.add_state() + + tmp1_name, _ = sdfg.add_scalar(sdfg.temp_data_name(), ty_, transient=True) + tmp1_node = st.add_access(tmp1_name) + init_tasklet = st.add_tasklet("init", {}, {"out"}, "out = s") + st.add_edge(init_tasklet, "out", tmp1_node, None, dace.Memlet(tmp1_node.data)) + + me, mx = st.add_map("bcast", dict(i=f"0:{N}")) + + copy_tasklet = st.add_tasklet("copy", {"inp"}, {"out"}, "out = inp") + st.add_memlet_path(tmp1_node, me, copy_tasklet, dst_conn="inp", memlet=dace.Memlet(f"{tmp1_node.data}[0]")) + tmp2_name, _ = sdfg.add_scalar(sdfg.temp_data_name(), ty_, transient=True) + tmp2_node = st.add_access(tmp2_name) + st.add_edge(copy_tasklet, "out", tmp2_node, None, dace.Memlet(tmp2_node.data)) + + bcast_tasklet = st.add_tasklet("bcast", {"inp"}, {"out"}, "out = inp") + st.add_edge(tmp2_node, None, bcast_tasklet, "inp", dace.Memlet(tmp2_node.data)) + st.add_memlet_path(bcast_tasklet, mx, st.add_access("v"), src_conn="out", memlet=dace.Memlet("v[i]")) + + sdfg.validate() + tasklet_nodes = {x for x in st.nodes() if isinstance(x, dace.nodes.Tasklet)} + assert tasklet_nodes == {init_tasklet, copy_tasklet, bcast_tasklet} + + count = sdfg.apply_transformations_repeated(TrivialTaskletElimination) + assert count == 2 + + tasklet_nodes = {x for x in st.nodes() if isinstance(x, dace.nodes.Tasklet)} + assert tasklet_nodes == {init_tasklet} + + assert len(st.in_edges(tmp2_node)) == 1 + assert st.in_edges(tmp2_node)[0].src == me + + assert len(st.out_edges(tmp2_node)) == 1 + assert st.out_edges(tmp2_node)[0].dst == mx + + +def test_trivial_tasklet_with_implicit_cast(): + ty32_ = dace.int32 + ty64_ = dace.int64 + sdfg = dace.SDFG("trivial_tasklet_with_implicit_cast") + sdfg.add_symbol("s", ty32_) + sdfg.add_array("v", (N,), ty32_) + st = sdfg.add_state() + + tmp1_name, _ = sdfg.add_scalar(sdfg.temp_data_name(), ty32_, transient=True) + tmp1_node = st.add_access(tmp1_name) + init_tasklet = st.add_tasklet("init", {}, {"out"}, "out = s") + st.add_edge(init_tasklet, "out", tmp1_node, None, dace.Memlet(tmp1_node.data)) + + me, mx = st.add_map("bcast", dict(i=f"0:{N}")) + + copy_tasklet = st.add_tasklet("copy", {"inp"}, {"out"}, "out = inp") + st.add_memlet_path(tmp1_node, me, copy_tasklet, dst_conn="inp", memlet=dace.Memlet(f"{tmp1_node.data}[0]")) + tmp2_name, _ = sdfg.add_scalar(sdfg.temp_data_name(), ty64_, transient=True) + tmp2_node = st.add_access(tmp2_name) + st.add_edge(copy_tasklet, "out", tmp2_node, None, dace.Memlet(tmp2_node.data)) + + bcast_tasklet = st.add_tasklet("bcast", {"inp"}, {"out"}, "out = inp") + st.add_edge(tmp2_node, None, bcast_tasklet, "inp", dace.Memlet(tmp2_node.data)) + st.add_memlet_path(bcast_tasklet, mx, st.add_access("v"), src_conn="out", memlet=dace.Memlet("v[i]")) + + sdfg.validate() + tasklet_nodes = {x for x in st.nodes() if isinstance(x, dace.nodes.Tasklet)} + assert tasklet_nodes == {init_tasklet, copy_tasklet, bcast_tasklet} + + # not applied because of data types mismatch on read/write nodes + count = sdfg.apply_transformations_repeated(TrivialTaskletElimination) + assert count == 0 + + +if __name__ == '__main__': + test_trivial_tasklet() + test_trivial_tasklet_with_map() + test_trivial_tasklet_with_implicit_cast()