Skip to content

Commit

Permalink
Merge pull request #1313 from spcl/fix-condition-test-list
Browse files Browse the repository at this point in the history
Fix for #1308
  • Loading branch information
tbennun authored Jul 12, 2023
2 parents 3b738f5 + 346cfde commit 221f980
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 0 deletions.
2 changes: 2 additions & 0 deletions dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -2352,6 +2352,8 @@ def _visit_test(self, node: ast.Expr):
# Visit test-condition
if not is_test_simple:
parsed_node = self.visit(node)
if isinstance(parsed_node, (list, tuple)) and len(parsed_node) == 1:
parsed_node = parsed_node[0]
if isinstance(parsed_node, str) and parsed_node in self.sdfg.arrays:
datadesc = self.sdfg.arrays[parsed_node]
if isinstance(datadesc, data.Array):
Expand Down
2 changes: 2 additions & 0 deletions dace/frontend/python/replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -4370,6 +4370,8 @@ def _datatype_converter(sdfg: SDFG, state: SDFGState, arg: UfuncInput, dtype: dt
'outputs': ['__out'],
'code': "__out = dace.{}(__inp)".format(dtype.to_string())
}
if dtype in (dace.bool, dace.bool_):
impl['code'] = "__out = dace.bool_(__inp)"
tasklet_params = _set_tasklet_params(impl, [arg])

# Visitor input only needed when `has_where == True`.
Expand Down
14 changes: 14 additions & 0 deletions tests/python_frontend/conditionals_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,19 @@ def if_return_chain(i: dace.int64):
assert if_return_chain(15)[0] == 4


def test_if_test_call():

@dace.program
def if_test_call(a, b):
if bool(a):
return a
else:
return b

assert if_test_call(0, 2)[0] == if_test_call.f(0, 2)
assert if_test_call(1, 2)[0] == if_test_call.f(1, 2)


if __name__ == "__main__":
test_simple_if()
test_call_if()
Expand All @@ -169,3 +182,4 @@ def if_return_chain(i: dace.int64):
test_call_while()
test_if_return_both()
test_if_return_chain()
test_if_test_call()

0 comments on commit 221f980

Please sign in to comment.