diff --git a/dace/transformation/dataflow/redundant_array.py b/dace/transformation/dataflow/redundant_array.py
index 1cffa1ed59..7b241ff9cd 100644
--- a/dace/transformation/dataflow/redundant_array.py
+++ b/dace/transformation/dataflow/redundant_array.py
@@ -230,7 +230,7 @@ def can_be_applied(self, graph: SDFGState, expr_index, sdfg, permissive=False):
 
         if not permissive:
             # Make sure the memlet covers the removed array
-            subset = copy.deepcopy(e1.data.subset)
+            subset = copy.deepcopy(a1_subset)
             subset.squeeze()
             shape = [sz for sz in in_desc.shape if sz != 1]
             if any(m != a for m, a in zip(subset.size(), shape)):
@@ -456,6 +456,49 @@ def _make_view(self, sdfg: SDFG, graph: SDFGState, in_array: nodes.AccessNode, o
         in_array.add_out_connector('views', force=True)
         e1._src_conn = 'views'
 
+
+    def _is_reshaping_memlet(
+        self,
+        graph: SDFGState,
+        edge: graph.MultiConnectorEdge,
+    ) -> bool:
+        """Test if the Memlet carried by `edge` is a reshaping Memlet.
+
+        A "reshaping Memlet" is a Memlet that changes the shape of a data container,
+        in the same way as `numpy.reshape()` does.
+
+        :param graph: The graph (SDFGState) in which the connection is.
+        :param edge: The edge, connecting two AccessNodes, whose Memlet is tested.
+        """
+        # Reshaping can not be a reduction
+        if edge.data.wcr or edge.data.wcr_nonatomic:
+            return False
+
+        # Reshaping needs an AccessNode on both ends of the edge.
+        src_node = edge.src
+        dst_node = edge.dst
+        if not all(isinstance(node, nodes.AccessNode) for node in (src_node, dst_node)):
+            return False
+
+        # Reshaping can only happen between arrays (Views are excluded).
+        sdfg = graph.sdfg
+        src_desc = sdfg.arrays[src_node.data]
+        dst_desc = sdfg.arrays[dst_node.data]
+        if not all(isinstance(desc, data.Array) and not isinstance(desc, data.View) for desc in (src_desc, dst_desc)):
+            return False
+
+        # Reshaping implies that the shape is different.
+        if dst_desc.shape == src_desc.shape:
+            return False
+
+        # A reshaping Memlet must read the whole source array and write the whole destination array.
+        src_subset, dst_subset = _validate_subsets(edge, sdfg.arrays)
+        for subset, shape in zip([dst_subset, src_subset], [dst_desc.shape, src_desc.shape]):
+            if not all(sssize == arraysize for sssize, arraysize in zip(subset.size(), shape)):
+                return False
+
+        return True
+
     def apply(self, graph, sdfg):
         in_array = self.in_array
         out_array = self.out_array
@@ -520,8 +563,23 @@ def apply(self, graph, sdfg):
         # 3. The memlet does not cover the removed array; or
         # 4. Dimensions are mismatching (all dimensions are popped);
         # create a view.
-        if reduction or len(a_dims_to_pop) == len(in_desc.shape) or any(
-                m != a for m, a in zip(a1_subset.size(), in_desc.shape)):
+        if (
+            reduction
+            or len(a_dims_to_pop) == len(in_desc.shape)
+            or any(m != a for m, a in zip(a1_subset.size(), in_desc.shape))
+        ):
+            self._make_view(sdfg, graph, in_array, out_array, e1, b_subset, b_dims_to_pop)
+            return in_array
+
+        # TODO: Fix me.
+        # As described in [issue 1595](https://github.com/spcl/dace/issues/1595) the
+        # transformation is unable to handle certain cases of reshaping Memlets
+        # correctly and fixing this case has proven rather difficult. In a first
+        # attempt the case of reshaping Memlets was forbidden (in the
+        # `can_be_applied()` method), however, this caused other (useful) cases to
+        # fail. For that reason such Memlets are transformed to Views.
+        # This is a workaround, not a proper fix, and it should be addressed.
+        if self._is_reshaping_memlet(graph=graph, edge=e1):
             self._make_view(sdfg, graph, in_array, out_array, e1, b_subset, b_dims_to_pop)
             return in_array
 
diff --git a/tests/transformations/redundant_copy_test.py b/tests/transformations/redundant_copy_test.py
index ecf25e07d4..2c753c6fc5 100644
--- a/tests/transformations/redundant_copy_test.py
+++ b/tests/transformations/redundant_copy_test.py
@@ -1,6 +1,7 @@
 # Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
 import numpy as np
 import pytest
+from typing import Tuple
 
 import dace
 from dace import nodes
@@ -9,6 +10,124 @@
                                           RedundantArrayCopyingIn)
 
 
+def test_reshaping_with_redundant_arrays():
+    def make_sdfg() -> Tuple[dace.SDFG, dace.nodes.AccessNode, dace.nodes.AccessNode, dace.nodes.AccessNode]:
+        sdfg = dace.SDFG("slicing_sdfg")
+        _, input_desc = sdfg.add_array(
+            "input",
+            shape=(6, 6, 6),
+            transient=False,
+            strides=None,
+            dtype=dace.float64,
+        )
+        _, a_desc = sdfg.add_array(
+            "a",
+            shape=(6, 6, 6),
+            transient=True,
+            strides=None,
+            dtype=dace.float64,
+        )
+        _, b_desc = sdfg.add_array(
+            "b",
+            shape=(36, 1, 6),
+            transient=True,
+            strides=None,
+            dtype=dace.float64,
+        )
+        _, output_desc = sdfg.add_array(
+            "output",
+            shape=(36, 1, 6),
+            transient=False,
+            strides=None,
+            dtype=dace.float64,
+        )
+        state = sdfg.add_state("state", is_start_block=True)
+        input_an = state.add_access("input")
+        a_an = state.add_access("a")
+        b_an = state.add_access("b")
+        output_an = state.add_access("output")
+
+        state.add_edge(
+            input_an,
+            None,
+            a_an,
+            None,
+            dace.Memlet.from_array("input", input_desc),
+        )
+        state.add_edge(
+            a_an,
+            None,
+            b_an,
+            None,
+            dace.Memlet.simple(
+                "a",
+                subset_str="0:6, 0:6, 0:6",
+                other_subset_str="0:36, 0, 0:6",
+            )
+        )
+        state.add_edge(
+            b_an,
+            None,
+            output_an,
+            None,
+            dace.Memlet.from_array("b", b_desc),
+        )
+        sdfg.validate()
+        assert state.number_of_nodes() == 4
+        assert len(sdfg.arrays) == 4
+        return sdfg, a_an, b_an, output_an
+
+    def apply_trafo(
+        sdfg: dace.SDFG,
+        in_array: dace.nodes.AccessNode,
+        out_array: dace.nodes.AccessNode,
+        will_not_apply: bool = False,
+        will_create_view: bool = False,
+    ) -> dace.SDFG:
+        trafo = RedundantArray()
+
+        candidate = {type(trafo).in_array: in_array, type(trafo).out_array: out_array}
+        state = sdfg.start_block
+        state_id = sdfg.node_id(state)
+        initial_arrays = len(sdfg.arrays)
+        initial_access_nodes = state.number_of_nodes()
+
+        trafo.setup_match(sdfg, sdfg.cfg_id, state_id, candidate, 0, override=True)
+        if trafo.can_be_applied(state, 0, sdfg):
+            ret = trafo.apply(state, sdfg)
+            if ret is not None:  # A view was created
+                if will_create_view:
+                    return sdfg
+                assert False, f"A view was created instead removing '{in_array.data}'."
+            sdfg.validate()
+            assert state.number_of_nodes() == initial_access_nodes - 1
+            assert len(sdfg.arrays) == initial_arrays - 1
+            assert in_array.data not in sdfg.arrays
+            return sdfg
+
+        if will_not_apply:
+            return sdfg
+        assert False, "Could not apply the transformation."
+
+    input_array = np.array(np.random.rand(6, 6, 6), dtype=np.float64, order='C')
+    ref = input_array.reshape((36, 1, 6)).copy()
+    output_step1 = np.zeros_like(ref)
+    output_step2 = np.zeros_like(ref)
+
+    # The Memlet between `a` and `b` is a reshaping Memlet, which is not handled, so a View is created.
+    sdfg, a_an, b_an, output_an = make_sdfg()
+    sdfg = apply_trafo(sdfg, in_array=a_an, out_array=b_an, will_create_view=True)
+
+    sdfg(input=input_array, output=output_step1)
+    assert np.all(ref == output_step1)
+
+    # The Memlet between `b` and `output` is not reshaping, and thus `b` should be removed.
+    sdfg = apply_trafo(sdfg, in_array=b_an, out_array=output_an)
+
+    sdfg(input=input_array, output=output_step2)
+    assert np.all(ref == output_step2)
+
+
 def test_out():
     sdfg = dace.SDFG("test_redundant_copy_out")
     state = sdfg.add_state()
@@ -331,6 +450,7 @@ def flip_and_flatten(a, b):
 
 
 if __name__ == '__main__':
+    test_reshaping_with_redundant_arrays()
     test_in()
     test_out()
     test_out_success()
diff --git a/tests/trivial_map_elimination_test.py b/tests/trivial_map_elimination_test.py
index 9600dad640..52ab4c1557 100644
--- a/tests/trivial_map_elimination_test.py
+++ b/tests/trivial_map_elimination_test.py
@@ -160,7 +160,6 @@ def test_can_be_applied(self):
 
         count = graph.apply_transformations(TrivialMapElimination, validate=False, validate_all=False)
         graph.validate()
-        graph.view()
 
         self.assertGreater(count, 0)
 