Skip to content

Commit

Permalink
Merge branch 'master' into better_copy_to_map
Browse files Browse the repository at this point in the history
  • Loading branch information
philip-paul-mueller committed Oct 16, 2024
2 parents ffca7e0 + 073b613 commit 4d1e0f1
Show file tree
Hide file tree
Showing 38 changed files with 1,667 additions and 513 deletions.
37 changes: 21 additions & 16 deletions dace/codegen/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,9 +275,13 @@ def as_cpp(self, codegen, symbols) -> str:
expr += elem.as_cpp(codegen, symbols)
# In a general block, emit transitions and assignments after each individual block or region.
if isinstance(elem, BasicCFBlock) or (isinstance(elem, RegionBlock) and elem.region):
cfg = elem.state.parent_graph if isinstance(elem, BasicCFBlock) else elem.region.parent_graph
if isinstance(elem, BasicCFBlock):
g_elem = elem.state
else:
g_elem = elem.region
cfg = g_elem.parent_graph
sdfg = cfg if isinstance(cfg, SDFG) else cfg.sdfg
out_edges = cfg.out_edges(elem.state) if isinstance(elem, BasicCFBlock) else cfg.out_edges(elem.region)
out_edges = cfg.out_edges(g_elem)
for j, e in enumerate(out_edges):
if e not in self.gotos_to_ignore:
# Skip gotos to immediate successors
Expand Down Expand Up @@ -532,26 +536,27 @@ def as_cpp(self, codegen, symbols) -> str:
expr = ''

if self.loop.update_statement and self.loop.init_statement and self.loop.loop_variable:
# Initialize to either "int i = 0" or "i = 0" depending on whether the type has been defined.
defined_vars = codegen.dispatcher.defined_vars
if not defined_vars.has(self.loop.loop_variable):
try:
init = f'{symbols[self.loop.loop_variable]} '
except KeyError:
init = 'auto '
symbols[self.loop.loop_variable] = None
init += unparse_interstate_edge(self.loop.init_statement.code[0], sdfg, codegen=codegen, symbols=symbols)
init = unparse_interstate_edge(self.loop.init_statement.code[0], sdfg, codegen=codegen, symbols=symbols)
init = init.strip(';')

update = unparse_interstate_edge(self.loop.update_statement.code[0], sdfg, codegen=codegen, symbols=symbols)
update = update.strip(';')

if self.loop.inverted:
expr += f'{init};\n'
expr += 'do {\n'
expr += _clean_loop_body(self.body.as_cpp(codegen, symbols))
expr += f'{update};\n'
expr += f'\n}} while({cond});\n'
if self.loop.update_before_condition:
expr += f'{init};\n'
expr += 'do {\n'
expr += _clean_loop_body(self.body.as_cpp(codegen, symbols))
expr += f'{update};\n'
expr += f'}} while({cond});\n'
else:
expr += f'{init};\n'
expr += 'while (1) {\n'
expr += _clean_loop_body(self.body.as_cpp(codegen, symbols))
expr += f'if (!({cond}))\n'
expr += 'break;\n'
expr += f'{update};\n'
expr += '}\n'
else:
expr += f'for ({init}; {cond}; {update}) {{\n'
expr += _clean_loop_body(self.body.as_cpp(codegen, symbols))
Expand Down
5 changes: 5 additions & 0 deletions dace/codegen/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,8 @@ def dispatch_copy(self, src_node: nodes.Node, dst_node: nodes.Node, edge: MultiC
cfg: ControlFlowRegion, dfg: StateSubgraphView, state_id: int, function_stream: CodeIOStream,
output_stream: CodeIOStream) -> None:
""" Dispatches a code generator for a memory copy operation. """
if edge.data.is_empty():
return
state = cfg.state(state_id)
target = self.get_copy_dispatcher(src_node, dst_node, edge, sdfg, state)
if target is None:
Expand All @@ -616,6 +618,9 @@ def dispatch_output_definition(self, src_node: nodes.Node, dst_node: nodes.Node,
"""
state = cfg.state(state_id)
target = self.get_copy_dispatcher(src_node, dst_node, edge, sdfg, state)
if target is None:
raise ValueError(
f'Could not dispatch copy code generator for {src_node} -> {dst_node} in state {state.label}')

# Dispatch
self._used_targets.add(target)
Expand Down
24 changes: 22 additions & 2 deletions dace/codegen/targets/framecode.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
from dace.codegen.prettycode import CodeIOStream
from dace.codegen.common import codeblock_to_cpp, sym2cpp
from dace.codegen.targets.target import TargetCodeGenerator
from dace.codegen.tools.type_inference import infer_expr_type
from dace.frontend.python import astutils
from dace.sdfg import SDFG, SDFGState, nodes
from dace.sdfg import scope as sdscope
from dace.sdfg import utils
from dace.sdfg.analysis import cfg as cfg_analysis
from dace.sdfg.state import ControlFlowRegion
from dace.transformation.passes.analysis import StateReachability
from dace.sdfg.state import ControlFlowRegion, LoopRegion
from dace.transformation.passes.analysis import StateReachability, loop_analysis


def _get_or_eval_sdfg_first_arg(func, sdfg):
Expand Down Expand Up @@ -916,6 +918,24 @@ def generate_code(self,
interstate_symbols.update(symbols)
global_symbols.update(symbols)

if isinstance(cfr, LoopRegion) and cfr.loop_variable is not None and cfr.init_statement is not None:
init_assignment = cfr.init_statement.code[0]
update_assignment = cfr.update_statement.code[0]
if isinstance(init_assignment, astutils.ast.Assign):
init_assignment = init_assignment.value
if isinstance(update_assignment, astutils.ast.Assign):
update_assignment = update_assignment.value
if not cfr.loop_variable in interstate_symbols:
l_end = loop_analysis.get_loop_end(cfr)
l_start = loop_analysis.get_init_assignment(cfr)
l_step = loop_analysis.get_loop_stride(cfr)
sym_type = dtypes.result_type_of(infer_expr_type(l_start, global_symbols),
infer_expr_type(l_step, global_symbols),
infer_expr_type(l_end, global_symbols))
interstate_symbols[cfr.loop_variable] = sym_type
if not cfr.loop_variable in global_symbols:
global_symbols[cfr.loop_variable] = interstate_symbols[cfr.loop_variable]

for isvarName, isvarType in interstate_symbols.items():
if isvarType is None:
raise TypeError(f'Type inference failed for symbol {isvarName}')
Expand Down
4 changes: 2 additions & 2 deletions dace/frontend/fortran/fortran_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,8 +536,8 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node,
for i, s in zip(all_indices, array.shape)])
smallsubset = subsets.Range([(0, s - 1, 1) for s in shape])

memlet = Memlet(f'{array_name}[{subset}]->{smallsubset}')
memlet2 = Memlet(f'{viewname}[{smallsubset}]->{subset}')
memlet = Memlet(f'{array_name}[{subset}]->[{smallsubset}]')
memlet2 = Memlet(f'{viewname}[{smallsubset}]->[{subset}]')
wv = None
rv = None
if local_name.name in read_names:
Expand Down
7 changes: 2 additions & 5 deletions dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -2565,8 +2565,7 @@ def visit_If(self, node: ast.If):
self._on_block_added(cond_block)

if_body = ControlFlowRegion(cond_block.label + '_body', sdfg=self.sdfg)
cond_block.branches.append((CodeBlock(cond), if_body))
if_body.parent_graph = self.cfg_target
cond_block.add_branch(CodeBlock(cond), if_body)

# Visit recursively
self._recursive_visit(node.body, 'if', node.lineno, if_body, False)
Expand All @@ -2575,9 +2574,7 @@ def visit_If(self, node: ast.If):
if len(node.orelse) > 0:
else_body = ControlFlowRegion(f'{cond_block.label}_else_{node.orelse[0].lineno}',
sdfg=self.sdfg)
#cond_block.branches.append((CodeBlock(cond_else), else_body))
cond_block.branches.append((None, else_body))
else_body.parent_graph = self.cfg_target
cond_block.add_branch(None, else_body)
# Visit recursively
self._recursive_visit(node.orelse, 'else', node.lineno, else_body, False)

Expand Down
2 changes: 2 additions & 0 deletions dace/frontend/python/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,8 @@ def _parse(self, args, kwargs, simplify=None, save=False, validate=False) -> SDF
sdutils.inline_control_flow_regions(nsdfg)
sdfg.using_experimental_blocks = self.use_experimental_cfg_blocks

sdfg.reset_cfg_list()

# Apply simplification pass automatically
if not cached and (simplify == True or
(simplify is None and Config.get_bool('optimizer', 'automatic_simplification'))):
Expand Down
6 changes: 3 additions & 3 deletions dace/frontend/python/replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ def _numpy_flip(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, arr: str, axis
# acpy, _ = sdfg.add_temp_transient(desc.shape, desc.dtype, desc.storage)
# vnode = state.add_read(view)
# anode = state.add_read(acpy)
# state.add_edge(vnode, None, anode, None, Memlet(f'{view}[{sset}] -> {dset}'))
# state.add_edge(vnode, None, anode, None, Memlet(f'{view}[{sset}] -> [{dset}]'))

arr_copy, _ = sdfg.add_temp_transient_like(desc)
inpidx = ','.join([f'__i{i}' for i in range(ndim)])
Expand Down Expand Up @@ -3934,7 +3934,7 @@ def implement_ufunc_accumulate(visitor: ProgramVisitor, ast_node: ast.Call, sdfg
init_state = nested_sdfg.add_state(label="init")
r = init_state.add_read(inpconn)
w = init_state.add_write(outconn)
init_state.add_nedge(r, w, dace.Memlet("{a}[{i}] -> {oi}".format(a=inpconn, i='0', oi='0')))
init_state.add_nedge(r, w, dace.Memlet("{a}[{i}] -> [{oi}]".format(a=inpconn, i='0', oi='0')))

body_state = nested_sdfg.add_state(label="body")
r1 = body_state.add_read(inpconn)
Expand Down Expand Up @@ -4189,7 +4189,7 @@ def view(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, arr: str, dtype, type
find_new_name=True)

# Register view with DaCe program visitor
# NOTE: We do not create here a Memlet of the form `A[subset] -> osubset`
# NOTE: We do not create here a Memlet of the form `A[subset] -> [osubset]`
# because the View can be of a different dtype. Adding `other_subset` in
# such cases will trigger validation error.
pv.views[newarr] = (arr, Memlet.from_array(arr, desc))
Expand Down
45 changes: 32 additions & 13 deletions dace/memlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ def __init__(self,
of use API. Must follow one of the following forms:
1. ``ARRAY``,
2. ``ARRAY[SUBSET]``,
3. ``ARRAY[SUBSET] -> OTHER_SUBSET``.
3. ``ARRAY[SUBSET] -> [OTHER_SUBSET]``,
4. ``[OTHER_SUBSET] -> ARRAY[SUBSET]``,
5. ``SRC_ARRAY[SRC_SUBSET] -> DST_ARRAY[DST_SUBSET]``.
:param data: Data descriptor name attached to this memlet.
:param subset: The subset to take from the data attached to the edge,
represented either as a string or a Subset object.
Expand Down Expand Up @@ -330,6 +332,10 @@ def _parse_from_subexpr(self, expr: str):
raise SyntaxError('Invalid memlet syntax "%s"' % expr)
return expr, None

# [subset] syntax
if expr.startswith('['):
return None, SubsetProperty.from_string(expr[1:-1])

# array[subset] syntax
arrname, subset_str = expr[:-1].split('[')
if not dtypes.validate_name(arrname):
Expand All @@ -342,27 +348,40 @@ def _parse_memlet_from_str(self, expr: str):
or the _data,_subset fields.
:param expr: A string expression of the this memlet, given as an ease
of use API. Must follow one of the following forms:
1. ``ARRAY``,
2. ``ARRAY[SUBSET]``,
3. ``ARRAY[SUBSET] -> OTHER_SUBSET``.
Note that modes 2 and 3 are deprecated and will leave
the memlet uninitialized until inserted into an SDFG.
of use API. Must follow one of the following forms:
1. ``ARRAY``,
2. ``ARRAY[SUBSET]``,
3. ``ARRAY[SUBSET] -> [OTHER_SUBSET]``,
4. ``[OTHER_SUBSET] -> ARRAY[SUBSET]``,
5. ``SRC_ARRAY[SRC_SUBSET] -> DST_ARRAY[DST_SUBSET]``.
Note that options 1-2 will leave the memlet uninitialized
until added into an SDFG.
"""
expr = expr.strip()
if '->' not in expr: # Options 1 and 2
self.data, self.subset = self._parse_from_subexpr(expr)
return

# Option 3
# Options 3-5
src_expr, dst_expr = expr.split('->')
src_expr = src_expr.strip()
dst_expr = dst_expr.strip()
if '[' not in src_expr and not dtypes.validate_name(src_expr):
raise SyntaxError('Expression without data name not yet allowed')

self.data, self.subset = self._parse_from_subexpr(src_expr)
self.other_subset = SubsetProperty.from_string(dst_expr)
src_data, src_subset = self._parse_from_subexpr(src_expr)
dst_data, dst_subset = self._parse_from_subexpr(dst_expr)
if src_data is None and dst_data is None:
raise SyntaxError('At least one data name needs to be given')

if src_data is not None: # Prefer src[subset] -> [other_subset]
self.data = src_data
self.subset = src_subset
self.other_subset = dst_subset
self._is_data_src = True
else:
self.data = dst_data
self.subset = dst_subset
self.other_subset = src_subset
self._is_data_src = False

def try_initialize(self, sdfg: 'dace.sdfg.SDFG', state: 'dace.sdfg.SDFGState',
edge: 'dace.sdfg.graph.MultiConnectorEdge'):
Expand Down Expand Up @@ -660,7 +679,7 @@ def _label(self, shape):

if self.other_subset is not None:
if self._is_data_src is False:
result += ' <- [%s]' % str(self.other_subset)
result = f'[{self.other_subset}] -> {result}'
else:
result += ' -> [%s]' % str(self.other_subset)
return result
Expand Down
15 changes: 11 additions & 4 deletions dace/sdfg/analysis/schedule_tree/treenodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,17 @@ def as_string(self, indent: int = 0):
loop = self.header.loop
if loop.update_statement and loop.init_statement and loop.loop_variable:
if loop.inverted:
pre_header = indent * INDENTATION + f'{loop.init_statement.as_string}\n'
header = indent * INDENTATION + 'do:\n'
pre_footer = (indent + 1) * INDENTATION + f'{loop.update_statement.as_string}\n'
footer = indent * INDENTATION + f'while {loop.loop_condition.as_string}'
if loop.update_before_condition:
pre_header = indent * INDENTATION + f'{loop.init_statement.as_string}\n'
header = indent * INDENTATION + 'do:\n'
pre_footer = (indent + 1) * INDENTATION + f'{loop.update_statement.as_string}\n'
footer = indent * INDENTATION + f'while {loop.loop_condition.as_string}'
else:
pre_header = indent * INDENTATION + f'{loop.init_statement.as_string}\n'
header = indent * INDENTATION + 'while True:\n'
pre_footer = (indent + 1) * INDENTATION + f'if (not {loop.loop_condition.as_string}):\n'
pre_footer += (indent + 2) * INDENTATION + 'break\n'
footer = (indent + 1) * INDENTATION + f'{loop.update_statement.as_string}\n'
return pre_header + header + super().as_string(indent) + '\n' + pre_footer + footer
else:
result = (indent * INDENTATION +
Expand Down
Loading

0 comments on commit 4d1e0f1

Please sign in to comment.