From 8981b815300183baba64bad11279deb9d860c907 Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Fri, 7 Jul 2023 21:38:04 +0200 Subject: [PATCH 1/3] Treat strides and tasklet-code slices the same way. --- dace/codegen/targets/cpp.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/dace/codegen/targets/cpp.py b/dace/codegen/targets/cpp.py index 295bf21310..1ec383815d 100644 --- a/dace/codegen/targets/cpp.py +++ b/dace/codegen/targets/cpp.py @@ -1082,11 +1082,17 @@ def _subscript_expr(self, slicenode: ast.AST, target: str) -> symbolic.SymbolicT ] if isinstance(visited_slice, ast.Tuple): - if len(strides) != len(visited_slice.elts): + # If slice is multi-dimensional and writes to array with more than 1 elements, then: + # - Assume this is indirection (?) + # - Soft-squeeze the slice (remove unit-modes) to match the treatment of the strides above. + desc = self.sdfg.arrays[dname] + if isinstance(desc, data.Array) and data._prod(desc.shape) != 1: + elts = [e for i, e in enumerate(visited_slice.elts) if desc.shape[i] != 1] + if len(strides) != len(elts): raise SyntaxError('Invalid number of dimensions in expression (expected %d, ' - 'got %d)' % (len(strides), len(visited_slice.elts))) + 'got %d)' % (len(strides), len(elts))) - return sum(symbolic.pystr_to_symbolic(unparse(elt)) * s for elt, s in zip(visited_slice.elts, strides)) + return sum(symbolic.pystr_to_symbolic(unparse(elt)) * s for elt, s in zip(elts, strides)) if len(strides) != 1: raise SyntaxError('Missing dimensions in expression (expected %d, got one)' % len(strides)) From a68b23ee1bf29b6733cbbb88b78a67a8a4d79228 Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Fri, 7 Jul 2023 21:38:18 +0200 Subject: [PATCH 2/3] Added test. --- tests/python_frontend/indirections_test.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/python_frontend/indirections_test.py b/tests/python_frontend/indirections_test.py index c59dffb922..fa6af21e4f 100644 --- a/tests/python_frontend/indirections_test.py +++ b/tests/python_frontend/indirections_test.py @@ -387,6 +387,27 @@ def test_spmv(): assert (np.allclose(y, ref)) +def test_indirection_size_1(): + + def compute_index(scal: dc.int32[5]): + result = 0 + with dace.tasklet: + s << scal + r >> result + r = s[1] + 1 - 1 + return result + + @dc.program + def tester(a: dc.float64[1, 2, 3], scal: dc.int32[5]): + ind = compute_index(scal) + a[0, ind, 0] = 1 + + arr = np.random.rand(1, 2, 3) + scal = np.array([1, 1, 1, 1, 1], dtype=np.int32) + tester(arr, scal) + assert arr[0, 1, 0] == 1 + + if __name__ == "__main__": test_indirection_scalar() test_indirection_scalar_assign() @@ -412,3 +433,4 @@ def test_spmv(): test_indirection_array_nested() test_indirection_array_nested_nsdfg() test_spmv() + test_indirection_size_1() From b61acccef0ac8d6ed831c9634e8e759210b766b9 Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Mon, 10 Jul 2023 13:34:58 +0200 Subject: [PATCH 3/3] Fix for the case where dname was not defined. --- dace/codegen/targets/cpp.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/dace/codegen/targets/cpp.py b/dace/codegen/targets/cpp.py index 1ec383815d..afbc6fca12 100644 --- a/dace/codegen/targets/cpp.py +++ b/dace/codegen/targets/cpp.py @@ -1085,9 +1085,12 @@ def _subscript_expr(self, slicenode: ast.AST, target: str) -> symbolic.SymbolicT # If slice is multi-dimensional and writes to array with more than 1 elements, then: # - Assume this is indirection (?) # - Soft-squeeze the slice (remove unit-modes) to match the treatment of the strides above. - desc = self.sdfg.arrays[dname] - if isinstance(desc, data.Array) and data._prod(desc.shape) != 1: - elts = [e for i, e in enumerate(visited_slice.elts) if desc.shape[i] != 1] + if target not in self.constants: + desc = self.sdfg.arrays[dname] + if isinstance(desc, data.Array) and data._prod(desc.shape) != 1: + elts = [e for i, e in enumerate(visited_slice.elts) if desc.shape[i] != 1] + else: + elts = visited_slice.elts if len(strides) != len(elts): raise SyntaxError('Invalid number of dimensions in expression (expected %d, ' 'got %d)' % (len(strides), len(elts)))