Skip to content

Commit

Permalink
Fixed RedundantArray's handling of "reshaping" Memlets (#1603)
Browse files Browse the repository at this point in the history
This PR fixes the problem reported as [issue
1595](#1595), which was traced back to
`RedundantArray`.

The commit adds a deterministic test (unlike the reproducer in the issue) that
fails without the fix.
The underlying problem is that the transformation does not correctly
handle a Memlet that performs a reshaping.
This commit does not really solve the underlying issue; instead, it adds a
special case that detects such a Memlet and turns the array into a View, as I
was unable to modify the general code path to handle it correctly.
It is not a nice solution, but it works.
  • Loading branch information
philip-paul-mueller authored Jul 4, 2024
1 parent e3d980a commit b5f5624
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 4 deletions.
65 changes: 62 additions & 3 deletions dace/transformation/dataflow/redundant_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def can_be_applied(self, graph: SDFGState, expr_index, sdfg, permissive=False):

if not permissive:
# Make sure the memlet covers the removed array
subset = copy.deepcopy(e1.data.subset)
subset = copy.deepcopy(a1_subset)
subset.squeeze()
shape = [sz for sz in in_desc.shape if sz != 1]
if any(m != a for m, a in zip(subset.size(), shape)):
Expand Down Expand Up @@ -456,6 +456,49 @@ def _make_view(self, sdfg: SDFG, graph: SDFGState, in_array: nodes.AccessNode, o
in_array.add_out_connector('views', force=True)
e1._src_conn = 'views'


def _is_reshaping_memlet(
    self,
    graph: SDFGState,
    edge: graph.MultiConnectorEdge,
) -> bool:
    """Tell whether the Memlet carried by `edge` is a "reshaping Memlet".

    A "reshaping Memlet" is a Memlet that changes the shape of a data container,
    in the same way as `numpy.reshape()` does. The edge qualifies only if all of
    the following hold:

    - the Memlet is not a reduction (neither `wcr` nor `wcr_nonatomic` is set);
    - both endpoints of the edge are access nodes;
    - both endpoints refer to `data.Array` descriptors that are not `data.View`;
    - the two descriptors have different shapes;
    - the size of the source subset matches the source shape and the size of
      the destination subset matches the destination shape.

    :param graph: The graph (SDFGState) that contains `edge`.
    :param edge: The edge whose Memlet is inspected.
    :return: `True` if the Memlet is reshaping, `False` otherwise.
    """
    memlet = edge.data

    # A reduction is never a plain reshape.
    if memlet.wcr or memlet.wcr_nonatomic:
        return False

    # Both ends of the edge have to be access nodes.
    producer, consumer = edge.src, edge.dst
    if not (isinstance(producer, nodes.AccessNode) and isinstance(consumer, nodes.AccessNode)):
        return False

    # Only genuine arrays take part in a reshape; Views are excluded.
    containers = graph.sdfg.arrays
    producer_desc = containers[producer.data]
    consumer_desc = containers[consumer.data]
    for desc in (producer_desc, consumer_desc):
        if isinstance(desc, data.View) or not isinstance(desc, data.Array):
            return False

    # Identical shapes mean that nothing is reshaped.
    if consumer_desc.shape == producer_desc.shape:
        return False

    # The whole source array has to be read and the whole destination array
    # has to be written, i.e. each subset has to be as large as its array.
    producer_subset, consumer_subset = _validate_subsets(edge, containers)
    for subset, shape in ((consumer_subset, consumer_desc.shape), (producer_subset, producer_desc.shape)):
        for extent, dim in zip(subset.size(), shape):
            if not (extent == dim):
                return False

    return True

def apply(self, graph, sdfg):
in_array = self.in_array
out_array = self.out_array
Expand Down Expand Up @@ -520,8 +563,23 @@ def apply(self, graph, sdfg):
# 3. The memlet does not cover the removed array; or
# 4. Dimensions are mismatching (all dimensions are popped);
# create a view.
if reduction or len(a_dims_to_pop) == len(in_desc.shape) or any(
m != a for m, a in zip(a1_subset.size(), in_desc.shape)):
if (
reduction
or len(a_dims_to_pop) == len(in_desc.shape)
or any(m != a for m, a in zip(a1_subset.size(), in_desc.shape))
):
self._make_view(sdfg, graph, in_array, out_array, e1, b_subset, b_dims_to_pop)
return in_array

# TODO: Fix me.
# As described in [issue 1595](https://github.com/spcl/dace/issues/1595) the
# transformation is unable to handle certain cases of reshaping Memlets
# correctly and fixing this case has proven rather difficult. In a first
# attempt the case of reshaping Memlets was forbidden (in the
# `can_be_applied()` method), however, this caused other (useful) cases to
# fail. For that reason such Memlets are transformed to Views.
# This is a workaround rather than a proper fix, and it should be addressed.
if self._is_reshaping_memlet(graph=graph, edge=e1):
self._make_view(sdfg, graph, in_array, out_array, e1, b_subset, b_dims_to_pop)
return in_array

Expand All @@ -547,6 +605,7 @@ def apply(self, graph, sdfg):
compose_and_push_back(bset, aset, b_dims_to_pop, popped)
except (ValueError, NotImplementedError):
self._make_view(sdfg, graph, in_array, out_array, e1, b_subset, b_dims_to_pop)
print(f"CREATED VIEW(2): {in_array}")
return in_array

# 2. Iterate over the e2 edges and traverse the memlet tree
Expand Down
120 changes: 120 additions & 0 deletions tests/transformations/redundant_copy_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
import numpy as np
import pytest
from typing import Tuple

import dace
from dace import nodes
Expand All @@ -9,6 +10,124 @@
RedundantArrayCopyingIn)


def test_reshaping_with_redundant_arrays():
    """Check `RedundantArray` on a chain that contains a reshaping Memlet.

    The SDFG copies `input -> a -> b -> output`, where the `a -> b` copy maps
    a `(6, 6, 6)` array onto a `(36, 1, 6)` array. Applying the transformation
    to `(a, b)` is expected to produce a View, while applying it to
    `(b, output)` is expected to remove `b`. After each step the SDFG is run
    and compared against `numpy`'s `reshape()`.
    """

    def make_sdfg() -> Tuple[dace.SDFG, dace.nodes.AccessNode, dace.nodes.AccessNode, dace.nodes.AccessNode]:
        """Build the SDFG and return it with the access nodes of `a`, `b` and `output`."""
        sdfg = dace.SDFG("slicing_sdfg")

        # (name, shape, transient) of every data container, in creation order.
        array_specs = (
            ("input", (6, 6, 6), False),
            ("a", (6, 6, 6), True),
            ("b", (36, 1, 6), True),
            ("output", (36, 1, 6), False),
        )
        descs = {}
        for name, shape, transient in array_specs:
            _, descs[name] = sdfg.add_array(
                name,
                shape=shape,
                transient=transient,
                strides=None,
                dtype=dace.float64,
            )

        state = sdfg.add_state("state", is_start_block=True)
        input_an, a_an, b_an, output_an = [state.add_access(name) for name, _, _ in array_specs]

        # `input -> a`: plain full copy.
        state.add_edge(input_an, None, a_an, None, dace.Memlet.from_array("input", descs["input"]))

        # `a -> b`: the reshaping copy, (6, 6, 6) onto (36, 1, 6).
        reshaping_memlet = dace.Memlet.simple(
            "a",
            subset_str="0:6, 0:6, 0:6",
            other_subset_str="0:36, 0, 0:6",
        )
        state.add_edge(a_an, None, b_an, None, reshaping_memlet)

        # `b -> output`: plain full copy.
        state.add_edge(b_an, None, output_an, None, dace.Memlet.from_array("b", descs["b"]))

        sdfg.validate()
        assert state.number_of_nodes() == 4
        assert len(sdfg.arrays) == 4
        return sdfg, a_an, b_an, output_an

    def apply_trafo(
        sdfg: dace.SDFG,
        in_array: dace.nodes.AccessNode,
        out_array: dace.nodes.AccessNode,
        will_not_apply: bool = False,
        will_create_view: bool = False,
    ) -> dace.SDFG:
        """Apply `RedundantArray` to `(in_array, out_array)` and check the outcome.

        :param sdfg: The SDFG to transform; its start block is used as the state.
        :param in_array: Access node bound to the `in_array` pattern node.
        :param out_array: Access node bound to the `out_array` pattern node.
        :param will_not_apply: Accept that the transformation is not applicable.
        :param will_create_view: Accept that `apply()` returns something other
            than `None`, which this test interprets as "a view was created".
        :return: The SDFG that was passed in.
        """
        trafo = RedundantArray()
        pattern = type(trafo)
        candidate = {pattern.in_array: in_array, pattern.out_array: out_array}

        state = sdfg.start_block
        state_id = sdfg.node_id(state)
        # Counts taken before the transformation, to verify the removal afterwards.
        arrays_before = len(sdfg.arrays)
        nodes_before = state.number_of_nodes()

        trafo.setup_match(sdfg, sdfg.cfg_id, state_id, candidate, 0, override=True)
        if trafo.can_be_applied(state, 0, sdfg):
            outcome = trafo.apply(state, sdfg)
            if outcome is not None:  # A view was created
                if will_create_view:
                    return sdfg
                assert False, f"A view was created instead removing '{in_array.data}'."
            # The array was removed: one node and one data container fewer.
            sdfg.validate()
            assert state.number_of_nodes() == nodes_before - 1
            assert len(sdfg.arrays) == arrays_before - 1
            assert in_array.data not in sdfg.arrays
            return sdfg

        if will_not_apply:
            return sdfg
        assert False, "Could not apply the transformation."

    input_array = np.array(np.random.rand(6, 6, 6), dtype=np.float64, order='C')
    ref = input_array.reshape((36, 1, 6)).copy()
    output_step1 = np.zeros_like(ref)
    output_step2 = np.zeros_like(ref)

    # The Memlet between `a` and `b` is a reshaping Memlet; instead of removing
    # `a`, the transformation is expected to create a view.
    sdfg, a_an, b_an, output_an = make_sdfg()
    sdfg = apply_trafo(sdfg, in_array=a_an, out_array=b_an, will_create_view=True)

    sdfg(input=input_array, output=output_step1)
    assert np.all(ref == output_step1)

    # The Memlet between `b` and `output` is not reshaping, and thus `b` should be removed.
    sdfg = apply_trafo(sdfg, in_array=b_an, out_array=output_an)

    sdfg(input=input_array, output=output_step2)
    assert np.all(ref == output_step2)


def test_out():
sdfg = dace.SDFG("test_redundant_copy_out")
state = sdfg.add_state()
Expand Down Expand Up @@ -331,6 +450,7 @@ def flip_and_flatten(a, b):


# Allow running the tests directly as a script, without pytest.
if __name__ == '__main__':
    # The test defined in this file is `test_reshaping_with_redundant_arrays`;
    # calling the non-existent `test_slicing_with_redundant_arrays` raised a NameError.
    test_reshaping_with_redundant_arrays()
    test_in()
    test_out()
    test_out_success()
Expand Down
2 changes: 1 addition & 1 deletion tests/trivial_map_elimination_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def test_can_be_applied(self):

count = graph.apply_transformations(TrivialMapElimination, validate=False, validate_all=False)
graph.validate()
graph.view()
#graph.view()

self.assertGreater(count, 0)

Expand Down

0 comments on commit b5f5624

Please sign in to comment.