From 61c43941f5c4a732ac8a4b9b4f81bdedebf48d5a Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Wed, 12 Jul 2023 18:16:24 +0200 Subject: [PATCH 1/3] Unpack visited condition's test. --- dace/frontend/python/newast.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 52a6862083..2283b433bd 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -2352,6 +2352,8 @@ def _visit_test(self, node: ast.Expr): # Visit test-condition if not is_test_simple: parsed_node = self.visit(node) + if isinstance(parsed_node, (list, tuple)) and len(parsed_node) == 1: + parsed_node = parsed_node[0] if isinstance(parsed_node, str) and parsed_node in self.sdfg.arrays: datadesc = self.sdfg.arrays[parsed_node] if isinstance(datadesc, data.Array): From 77896f7db2d9c253a23260fb4fe428f1657c57f4 Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Wed, 12 Jul 2023 18:21:59 +0200 Subject: [PATCH 2/3] Special case for code convering to bool. --- dace/frontend/python/replacements.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dace/frontend/python/replacements.py b/dace/frontend/python/replacements.py index 3586d40374..528aef1ec8 100644 --- a/dace/frontend/python/replacements.py +++ b/dace/frontend/python/replacements.py @@ -4370,6 +4370,8 @@ def _datatype_converter(sdfg: SDFG, state: SDFGState, arg: UfuncInput, dtype: dt 'outputs': ['__out'], 'code': "__out = dace.{}(__inp)".format(dtype.to_string()) } + if dtype in (dace.bool, dace.bool_): + impl['code'] = "__out = dace.bool_(__inp)" tasklet_params = _set_tasklet_params(impl, [arg]) # Visitor input only needed when `has_where == True`. From 346cfdee7eba3a6cb88ef4e0f870a316c6cac3b9 Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Wed, 12 Jul 2023 18:25:31 +0200 Subject: [PATCH 3/3] Added test. --- tests/python_frontend/conditionals_test.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/python_frontend/conditionals_test.py b/tests/python_frontend/conditionals_test.py index 03058c7bf8..994a45ed80 100644 --- a/tests/python_frontend/conditionals_test.py +++ b/tests/python_frontend/conditionals_test.py @@ -161,6 +161,19 @@ def if_return_chain(i: dace.int64): assert if_return_chain(15)[0] == 4 +def test_if_test_call(): + + @dace.program + def if_test_call(a, b): + if bool(a): + return a + else: + return b + + assert if_test_call(0, 2)[0] == if_test_call.f(0, 2) + assert if_test_call(1, 2)[0] == if_test_call.f(1, 2) + + if __name__ == "__main__": test_simple_if() test_call_if() @@ -169,3 +182,4 @@ def if_return_chain(i: dace.int64): test_call_while() test_if_return_both() test_if_return_chain() + test_if_test_call()