From d0106207d33972e7c89c2d96e89ab98a182deb5b Mon Sep 17 00:00:00 2001
From: Philipp Schaad
Date: Fri, 18 Oct 2024 14:16:37 +0200
Subject: [PATCH] More fixes

---
 dace/codegen/control_flow.py                  |  4 +-
 .../analysis/schedule_tree/sdfg_to_tree.py    |  2 +-
 dace/sdfg/replace.py                          | 12 +++
 dace/sdfg/state.py                            |  4 +-
 dace/transformation/helpers.py                | 98 +++++++++++--------
 dace/transformation/pass_pipeline.py          |  6 +-
 .../passes/analysis/analysis.py               | 17 +++-
 .../passes/dead_state_elimination.py          |  2 +-
 .../simplification/control_flow_raising.py    |  1 +
 .../prune_empty_conditional_branches.py       |  9 +-
 dace/transformation/passes/simplify.py        |  5 +
 tests/fortran/array_test.py                   |  8 +-
 tests/passes/dead_code_elimination_test.py    |  4 +-
 .../prune_empty_conditional_branches_test.py  |  4 +-
 .../python_frontend/function_regions_test.py  | 23 ++---
 tests/python_frontend/named_region_test.py    | 25 ++---
 tests/schedule_tree/nesting_test.py           |  7 +-
 tests/schedule_tree/schedule_test.py          |  2 +
 18 files changed, 151 insertions(+), 82 deletions(-)

diff --git a/dace/codegen/control_flow.py b/dace/codegen/control_flow.py
index 6657d09808..fdba40526d 100644
--- a/dace/codegen/control_flow.py
+++ b/dace/codegen/control_flow.py
@@ -62,8 +62,8 @@
 import sympy as sp
 from dace import dtypes
 from dace.sdfg.analysis import cfg as cfg_analysis
-from dace.sdfg.state import (BreakBlock, ConditionalBlock, ContinueBlock, ControlFlowBlock, ControlFlowRegion, LoopRegion,
-                             ReturnBlock, SDFGState)
+from dace.sdfg.state import (BreakBlock, ConditionalBlock, ContinueBlock, ControlFlowBlock, ControlFlowRegion,
+                             LoopRegion, ReturnBlock, SDFGState)
 from dace.sdfg.sdfg import SDFG, InterstateEdge
 from dace.sdfg.graph import Edge
 from dace.properties import CodeBlock
diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py
index 9357ca3db9..84f36189b3 100644
--- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py
+++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py
@@ -652,7 +652,7 @@ def as_schedule_tree(sdfg: SDFG, in_place: bool = False, toplevel: bool = True)
     #############################
 
     # Create initial tree from CFG
-    cfg: cf.ControlFlow = cf.structured_control_flow_tree(sdfg, lambda _: '')
+    cfg: cf.ControlFlow = cf.structured_control_flow_tree_with_regions(sdfg, lambda _: '')
 
     # Traverse said tree (also into states) to create the schedule tree
     def totree(node: cf.ControlFlow, parent: cf.GeneralBlock = None) -> List[tn.ScheduleTreeNode]:
diff --git a/dace/sdfg/replace.py b/dace/sdfg/replace.py
index 9b6086098e..83c5e5c148 100644
--- a/dace/sdfg/replace.py
+++ b/dace/sdfg/replace.py
@@ -11,6 +11,7 @@
 from dace import dtypes, properties, symbolic
 from dace.codegen import cppunparse
 from dace.frontend.python.astutils import ASTFindReplace
+from dace.sdfg.state import ConditionalBlock, LoopRegion
 
 if TYPE_CHECKING:
     from dace.sdfg.state import StateSubgraphView
@@ -200,3 +201,14 @@ def replace_datadesc_names(sdfg: 'dace.SDFG', repl: Dict[str, str]):
                 for edge in block.edges():
                     if edge.data.data in repl:
                         edge.data.data = repl[edge.data.data]
+
+        # Replace in loop or branch conditions:
+        if isinstance(cf, LoopRegion):
+            replace_in_codeblock(cf.loop_condition, repl)
+            if cf.update_statement:
+                replace_in_codeblock(cf.update_statement, repl)
+            if cf.init_statement:
+                replace_in_codeblock(cf.init_statement, repl)
+        elif isinstance(cf, ConditionalBlock):
+            for c, _ in cf.branches:
+                replace_in_codeblock(c, repl)
diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py
index 083513005f..ca733258df 100644
--- a/dace/sdfg/state.py
+++ b/dace/sdfg/state.py
@@ -13,7 +13,6 @@
 
 import dace
 from dace.frontend.python import astutils
-from dace.sdfg.replace import replace_in_codeblock
 import dace.serialize
 from dace import data as dt
 from dace import dtypes
@@ -3319,6 +3318,9 @@ def replace_dict(self,
                      symrepl: Optional[Dict[symbolic.SymbolicType, symbolic.SymbolicType]] = None,
                      replace_in_graph: bool = True,
                      replace_keys: bool = True):
+        # Avoid circular imports
+        from dace.sdfg.replace import replace_in_codeblock
+
         if replace_keys:
             from dace.sdfg.replace import replace_properties_dict
             replace_properties_dict(self, repl, symrepl)
diff --git a/dace/transformation/helpers.py b/dace/transformation/helpers.py
index dc84bd4478..c6a701bd48 100644
--- a/dace/transformation/helpers.py
+++ b/dace/transformation/helpers.py
@@ -4,10 +4,10 @@
 import itertools
 from networkx import MultiDiGraph
 
-from dace.sdfg.state import ControlFlowBlock, ControlFlowRegion
+from dace.sdfg.state import AbstractControlFlowRegion, ConditionalBlock, ControlFlowBlock, ControlFlowRegion, LoopRegion
 from dace.subsets import Range, Subset, union
 import dace.subsets as subsets
-from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Set, Union
+from typing import Dict, Iterable, List, Optional, Tuple, Set, Union
 
 from dace import data, dtypes, symbolic
 from dace.codegen import control_flow as cf
@@ -30,10 +30,13 @@ def nest_sdfg_subgraph(sdfg: SDFG, subgraph: SubgraphView, start: Optional[SDFGS
     """
 
     # Nest states
-    states = subgraph.nodes()
+    blocks: List[ControlFlowBlock] = subgraph.nodes()
     return_state = None
 
-    if len(states) > 1:
+    if len(blocks) > 1:
+        # Avoid cyclic imports
+        from dace.transformation.passes.analysis import loop_analysis
+        graph: ControlFlowRegion = blocks[0].parent_graph
         if start is not None:
             source_node = start
         else:
@@ -48,6 +51,22 @@ def nest_sdfg_subgraph(sdfg: SDFG, subgraph: SubgraphView, start: Optional[SDFGS
             raise NotImplementedError
         sink_node = sink_nodes[0]
 
+        all_blocks: List[ControlFlowBlock] = []
+        is_edges: List[Edge[InterstateEdge]] = []
+        for b in blocks:
+            if isinstance(b, AbstractControlFlowRegion):
+                for nb in b.all_control_flow_blocks():
+                    all_blocks.append(nb)
+                for e in b.all_interstate_edges():
+                    is_edges.append(e)
+            else:
+                all_blocks.append(b)
+        states: List[SDFGState] = [b for b in all_blocks if isinstance(b, SDFGState)]
+        for src in blocks:
+            for dst in blocks:
+                for edge in graph.edges_between(src, dst):
+                    is_edges.append(edge)
+
         # Find read/write sets
         read_set, write_set = set(), set()
         for state in states:
@@ -67,12 +86,10 @@ def nest_sdfg_subgraph(sdfg: SDFG, subgraph: SubgraphView, start: Optional[SDFGS
                 if e.data.data and e.data.data in sdfg.arrays:
                     write_set.add(e.data.data)
         # Add data from edges
-        for src in states:
-            for dst in states:
-                for edge in sdfg.edges_between(src, dst):
-                    for s in edge.data.free_symbols:
-                        if s in sdfg.arrays:
-                            read_set.add(s)
+        for edge in is_edges:
+            for s in edge.data.free_symbols:
+                if s in sdfg.arrays:
+                    read_set.add(s)
 
         # Find NestedSDFG's unique data
         rw_set = read_set | write_set
@@ -82,7 +99,7 @@ def nest_sdfg_subgraph(sdfg: SDFG, subgraph: SubgraphView, start: Optional[SDFGS
                 continue
             found = False
             for state in sdfg.states():
-                if state in states:
+                if state in blocks:
                     continue
                 for node in state.nodes():
                     if (isinstance(node, nodes.AccessNode) and node.data == name):
@@ -98,7 +115,7 @@ def nest_sdfg_subgraph(sdfg: SDFG, subgraph: SubgraphView, start: Optional[SDFGS
         # Find defined subgraph symbols
        defined_symbols = set()
         strictly_defined_symbols = set()
-        for e in subgraph.edges():
+        for e in is_edges:
             defined_symbols.update(set(e.data.assignments.keys()))
             for k, v in e.data.assignments.items():
                 try:
@@ -107,22 +124,30 @@ def nest_sdfg_subgraph(sdfg: SDFG, subgraph: SubgraphView, start: Optional[SDFGS
             except AttributeError:
                 # `symbolic.pystr_to_symbolic` may return bool, which doesn't have attribute `args`
                 pass
-
-        return_state = new_state = sdfg.add_state('nested_sdfg_parent')
+        for b in all_blocks:
+            if isinstance(b, LoopRegion) and b.loop_variable is not None and b.loop_variable != '':
+                defined_symbols.update(b.loop_variable)
+                if b.loop_variable not in sdfg.symbols:
+                    if b.init_statement:
+                        init_assignment = loop_analysis.get_init_assignment(b)
+                        if b.loop_variable not in {str(s) for s in symbolic.pystr_to_symbolic(init_assignment).args}:
+                            strictly_defined_symbols.add(b.loop_variable)
+                    else:
+                        strictly_defined_symbols.add(b.loop_variable)
+
+        return_state = new_state = graph.add_state('nested_sdfg_parent')
 
         nsdfg = SDFG("nested_sdfg", constants=sdfg.constants_prop, parent=new_state)
         nsdfg.add_node(source_node, is_start_state=True)
-        nsdfg.add_nodes_from([s for s in states if s is not source_node])
-        for s in states:
-            s.parent = nsdfg
+        nsdfg.add_nodes_from([s for s in blocks if s is not source_node])
 
         for e in subgraph.edges():
             nsdfg.add_edge(e.src, e.dst, e.data)
-        for e in sdfg.in_edges(source_node):
-            sdfg.add_edge(e.src, new_state, e.data)
-        for e in sdfg.out_edges(sink_node):
-            sdfg.add_edge(new_state, e.dst, e.data)
+        for e in graph.in_edges(source_node):
+            graph.add_edge(e.src, new_state, e.data)
+        for e in graph.out_edges(sink_node):
+            graph.add_edge(new_state, e.dst, e.data)
 
-        sdfg.remove_nodes_from(states)
+        graph.remove_nodes_from(blocks)
 
         # Add NestedSDFG arrays
         for name in read_set | write_set:
@@ -177,15 +202,15 @@ def nest_sdfg_subgraph(sdfg: SDFG, subgraph: SubgraphView, start: Optional[SDFGS
 
         # Part (2)
         if out_state is not None:
-            extra_state = sdfg.add_state('symbolic_output')
-            for e in sdfg.out_edges(new_state):
-                sdfg.add_edge(extra_state, e.dst, e.data)
-                sdfg.remove_edge(e)
-            sdfg.add_edge(new_state, extra_state, InterstateEdge(assignments=out_mapping))
+            extra_state = graph.add_state('symbolic_output')
+            for e in graph.out_edges(new_state):
+                graph.add_edge(extra_state, e.dst, e.data)
+                graph.remove_edge(e)
+            graph.add_edge(new_state, extra_state, InterstateEdge(assignments=out_mapping))
             new_state = extra_state
 
     else:
-        return_state = states[0]
+        return_state = blocks[0]
 
     return return_state
 
@@ -244,7 +269,8 @@ def _copy_state(sdfg: SDFG,
     return state_copy
 
 
-def find_sdfg_control_flow(cfg: ControlFlowRegion) -> Dict[ControlFlowBlock, Set[ControlFlowBlock]]:
+def find_sdfg_control_flow(cfg: ControlFlowRegion) -> Dict[ControlFlowBlock,
+                                                           Tuple[Set[ControlFlowBlock], ControlFlowBlock]]:
     """
     Partitions a CFG to subgraphs that can be nested independently of each other. The method does not nest the
     subgraphs but alters the graph; (1) interstate edges are split, (2) scope source/sink nodes that belong to multiple
@@ -352,16 +378,10 @@ def nest_sdfg_control_flow(sdfg: SDFG, components=None):
     :param sdfg: The SDFG to be partitioned.
     :param components: An existing partition of the SDFG.
     """
-
-    components = components or find_sdfg_control_flow(sdfg)
-
-    num_components = len(components)
-
-    if num_components < 2:
-        return
-
-    for i, (start, (component, _)) in enumerate(components.items()):
-        nest_sdfg_subgraph(sdfg, graph.SubgraphView(sdfg, component), start)
+    regions = list(sdfg.all_control_flow_regions())
+    for region in regions:
+        nest_sdfg_subgraph(region.sdfg, SubgraphView(region.sdfg, [region]), region)
+    sdfg.reset_cfg_list()
 
 
 def nest_state_subgraph(sdfg: SDFG,
diff --git a/dace/transformation/pass_pipeline.py b/dace/transformation/pass_pipeline.py
index d8bd8745ff..e558ab0b20 100644
--- a/dace/transformation/pass_pipeline.py
+++ b/dace/transformation/pass_pipeline.py
@@ -274,6 +274,10 @@ class ControlFlowRegionPass(Pass):
 
     CATEGORY: str = 'Helper'
 
+    apply_to_conditionals = properties.Property(dtype=bool, default=False,
+                                                desc='Whether or not to apply to conditional blocks. If false, do ' +
+                                                     'not apply to conditional blocks, but only their children.')
+
     def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[Dict[int, Optional[Any]]]:
         """
         Applies the pass to control flow regions of the given SDFG by calling ``apply`` on each region.
@@ -287,7 +291,7 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[D
         """
         result = {}
         for region in sdfg.all_control_flow_regions(recursive=True, parent_first=False):
-            if isinstance(region, ConditionalBlock):
+            if isinstance(region, ConditionalBlock) and not self.apply_to_conditionals:
                 continue
             retval = self.apply(region, pipeline_results)
             if retval is not None:
diff --git a/dace/transformation/passes/analysis/analysis.py b/dace/transformation/passes/analysis/analysis.py
index bc1bac4640..720b0b8b5b 100644
--- a/dace/transformation/passes/analysis/analysis.py
+++ b/dace/transformation/passes/analysis/analysis.py
@@ -153,7 +153,12 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[ControlFlowBlock, Set[
                 # The implementation below is faster
                 # tc: nx.DiGraph = nx.transitive_closure(sdfg.nx)
                 for n, v in reachable_nodes(cfg.nx):
-                    single_level_reachable[cfg.cfg_id][n] = set(v)
+                    reach = set()
+                    for nd in v:
+                        reach.add(nd)
+                        if isinstance(nd, AbstractControlFlowRegion):
+                            reach.update(nd.all_control_flow_blocks())
+                    single_level_reachable[cfg.cfg_id][n] = reach
                 if isinstance(cfg, LoopRegion):
                     single_level_reachable[cfg.cfg_id][n].update(cfg.nodes())
@@ -166,7 +171,7 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[ControlFlowBlock, Set[
             result: Dict[ControlFlowBlock, Set[ControlFlowBlock]] = defaultdict(set)
             for block in cfg.nodes():
                 for reached in single_level_reachable[block.parent_graph.cfg_id][block]:
-                    if isinstance(reached, ControlFlowRegion):
+                    if isinstance(reached, AbstractControlFlowRegion):
                         result[block].update(reached.all_control_flow_blocks())
                     result[block].add(reached)
                 if block.parent_graph is not sdfg:
@@ -516,7 +521,7 @@ def should_reapply(self, modified: ppl.Modifies) -> bool:
         return modified & ppl.Modifies.States
 
     def depends_on(self):
-        return {AccessSets, FindAccessNodes, StateReachability}
+        return {AccessSets, FindAccessNodes, ControlFlowBlockReachability}
 
     def _find_dominating_write(self,
                                desc: str,
@@ -615,7 +620,9 @@ def apply_pass(self, top_sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[i
             access_nodes: Dict[str, Dict[SDFGState, Tuple[Set[nd.AccessNode], Set[nd.AccessNode]]]] = pipeline_results[
                 FindAccessNodes.__name__][sdfg.cfg_id]
 
-            state_reach: Dict[SDFGState, Set[SDFGState]] = pipeline_results[StateReachability.__name__][sdfg.cfg_id]
+            block_reach: Dict[ControlFlowBlock, Set[ControlFlowBlock]] = pipeline_results[
+                ControlFlowBlockReachability.__name__
+            ]
 
             anames = sdfg.arrays.keys()
             for desc in sdfg.arrays:
@@ -657,7 +664,7 @@ def apply_pass(self, top_sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[i
                     continue
                 write_state, write_node = write
                 dominators = all_doms_transitive[write_state]
-                reach = state_reach[write_state]
+                reach = block_reach[write_state.parent_graph.cfg_id][write_state]
                 for other_write, other_accesses in result[desc].items():
                     if other_write is not None and other_write[1] is write_node and other_write[0] is write_state:
                         continue
diff --git a/dace/transformation/passes/dead_state_elimination.py b/dace/transformation/passes/dead_state_elimination.py
index cda193f43a..23f2a785f5 100644
--- a/dace/transformation/passes/dead_state_elimination.py
+++ b/dace/transformation/passes/dead_state_elimination.py
@@ -77,7 +77,7 @@ def apply_pass(self, sdfg: SDFG, _) -> Optional[Set[Union[SDFGState, Edge[Inters
                     cfg.remove_node(node)
             else:
                 result.add(node)
-                cfg.remove_node(block)
+                cfg.remove_node(node)
 
         if not annotated:
             return result or None
diff --git a/dace/transformation/passes/simplification/control_flow_raising.py b/dace/transformation/passes/simplification/control_flow_raising.py
index fa1a3c6f97..b852b798b1 100644
--- a/dace/transformation/passes/simplification/control_flow_raising.py
+++ b/dace/transformation/passes/simplification/control_flow_raising.py
@@ -97,4 +97,5 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Optional[Tuple[int, int]]:
             lifted_branches += self._lift_conditionals(sdfg)
         if lifted_branches == 0 and lifted_loops == 0:
             return None
+        top_sdfg.reset_cfg_list()
         return lifted_loops, lifted_branches
diff --git a/dace/transformation/passes/simplification/prune_empty_conditional_branches.py b/dace/transformation/passes/simplification/prune_empty_conditional_branches.py
index 111944614c..29a400d5d1 100644
--- a/dace/transformation/passes/simplification/prune_empty_conditional_branches.py
+++ b/dace/transformation/passes/simplification/prune_empty_conditional_branches.py
@@ -14,6 +14,10 @@ class PruneEmptyConditionalBranches(ppl.ControlFlowRegionPass):
 
     CATEGORY: str = 'Simplification'
 
+    def __init__(self):
+        super().__init__()
+        self.apply_to_conditionals = True
+
     def modifies(self) -> ppl.Modifies:
         return ppl.Modifies.CFG
 
@@ -58,5 +62,8 @@ def apply(self, region: ControlFlowRegion, _) -> Optional[int]:
                 region.parent_graph.add_edge(replacement_node_before, replacement_node_after, InterstateEdge())
             region.parent_graph.remove_node(region)
 
-        return removed_branches if removed_branches > 0 else None
+        if removed_branches > 0:
+            region.reset_cfg_list()
+            return removed_branches
+        return None
diff --git a/dace/transformation/passes/simplify.py b/dace/transformation/passes/simplify.py
index d3e8b580da..97eb383764 100644
--- a/dace/transformation/passes/simplify.py
+++ b/dace/transformation/passes/simplify.py
@@ -73,6 +73,8 @@ def __init__(self,
                  validate_all: bool = False,
                  skip: Optional[Set[str]] = None,
                  verbose: bool = False,
+                 no_inline_function_call_regions: bool = False,
+                 no_inline_named_regions: bool = False,
                  pass_options: Optional[Dict[str, Any]] = None):
         if skip:
             passes: List[ppl.Pass] = [p() for p in SIMPLIFY_PASSES if p.__name__ not in skip]
@@ -88,6 +90,9 @@ def __init__(self,
         else:
             self.verbose = verbose
 
+        self.no_inline_function_call_regions = no_inline_function_call_regions
+        self.no_inline_named_regions = no_inline_named_regions
+
         pass_opts = {
             'no_inline_function_call_regions': self.no_inline_function_call_regions,
             'no_inline_named_regions': self.no_inline_named_regions,
diff --git a/tests/fortran/array_test.py b/tests/fortran/array_test.py
index a8ece680a6..3283d2e37f 100644
--- a/tests/fortran/array_test.py
+++ b/tests/fortran/array_test.py
@@ -17,6 +17,7 @@
 import dace.frontend.fortran.ast_transforms as ast_transforms
 import dace.frontend.fortran.ast_utils as ast_utils
 import dace.frontend.fortran.ast_internal_classes as ast_internal_classes
+from dace.sdfg.state import LoopRegion
 
 
 def test_fortran_frontend_array_access():
@@ -199,9 +200,10 @@ def test_fortran_frontend_memlet_in_map_test():
     """
     sdfg = fortran_parser.create_sdfg_from_string(test_string, "memlet_range_test")
     sdfg.simplify()
-    # Expect that start is begin of for loop -> only one out edge to guard defining iterator variable
-    assert len(sdfg.out_edges(sdfg.start_state)) == 1
-    iter_var = symbolic.symbol(list(sdfg.out_edges(sdfg.start_state)[0].data.assignments.keys())[0])
+    # Expect that the start block is a loop
+    loop = sdfg.nodes()[0]
+    assert isinstance(loop, LoopRegion)
+    iter_var = symbolic.pystr_to_symbolic(loop.loop_variable)
 
     for state in sdfg.states():
         if len(state.nodes()) > 1:
diff --git a/tests/passes/dead_code_elimination_test.py b/tests/passes/dead_code_elimination_test.py
index 14d380c463..bf40ff4409 100644
--- a/tests/passes/dead_code_elimination_test.py
+++ b/tests/passes/dead_code_elimination_test.py
@@ -265,12 +265,12 @@ def dce_tester(a: dace.float64[20], b: dace.float64[20]):
     sdfg = dce_tester.to_sdfg(simplify=False)
     result = Pipeline([DeadDataflowElimination(), DeadStateElimination()]).apply_pass(sdfg, {})
     sdfg.simplify()
-    assert sdfg.number_of_nodes() <= 6
+    assert sdfg.number_of_nodes() <= 4
 
     # Check that arrays were removed
     assert all('c' not in [n.data for n in state.data_nodes()] for state in sdfg.nodes())
     assert any('f' in [n.data for n in rstate if isinstance(n, dace.nodes.AccessNode)]
-               for rstate in result['DeadDataflowElimination'].values())
+               for rstate in result[DeadDataflowElimination.__name__][0].values())
 
 
 def test_dce_callback():
diff --git a/tests/passes/simplification/prune_empty_conditional_branches_test.py b/tests/passes/simplification/prune_empty_conditional_branches_test.py
index 65463ad3a7..dc25cdc670 100644
--- a/tests/passes/simplification/prune_empty_conditional_branches_test.py
+++ b/tests/passes/simplification/prune_empty_conditional_branches_test.py
@@ -36,7 +36,7 @@ def prune_empty_else(A: dace.int32[N]):
 
     res = PruneEmptyConditionalBranches().apply_pass(sdfg, {})
 
-    assert res[conditional] == 1
+    assert res[conditional.cfg_id] == 1
     assert len(conditional.branches) == 1
 
     N1 = 32
@@ -82,7 +82,7 @@ def prune_empty_if_with_else(A: dace.int32[N]):
 
     res = PruneEmptyConditionalBranches().apply_pass(sdfg, {})
 
-    assert res[conditional] == 1
+    assert res[conditional.cfg_id] == 1
     assert len(conditional.branches) == 1
     assert conditional.branches[0][0] is not None
diff --git a/tests/python_frontend/function_regions_test.py b/tests/python_frontend/function_regions_test.py
index c5c9b4ac6f..5d5082a92e 100644
--- a/tests/python_frontend/function_regions_test.py
+++ b/tests/python_frontend/function_regions_test.py
@@ -3,6 +3,7 @@
 import numpy as np
 import dace
 from dace.sdfg.state import FunctionCallRegion
+from dace.transformation.passes.simplify import SimplifyPass
 
 
 def test_function_call():
@@ -11,9 +12,9 @@ def func(A: dace.float64[N]):
     @dace.program
     def prog(I: dace.float64[N]):
         return func(I)
-    prog.use_experimental_cfg_blocks = True
-    sdfg = prog.to_sdfg()
-    call_region: FunctionCallRegion = sdfg.nodes()[1]
+    sdfg = prog.to_sdfg(simplify=False)
+    SimplifyPass(no_inline_function_call_regions=True, no_inline_named_regions=True).apply_pass(sdfg, {})
+    call_region: FunctionCallRegion = sdfg.nodes()[0]
     assert call_region.arguments == {'A': 'I'}
     assert sdfg(np.array([+1], dtype=np.float64), N=1) == 15
     assert sdfg(np.array([-1], dtype=np.float64), N=1) == 5
@@ -26,13 +27,13 @@ def func(A: dace.float64[N], B: dace.float64[N], C: dace.float64[N]):
     def prog(E: dace.float64[N], F: dace.float64[N], G: dace.float64[N]):
         func(A=E, B=F, C=G)
         func(A=G, B=E, C=E)
-    prog.use_experimental_cfg_blocks = True
     E = np.array([1])
     F = np.array([2])
     G = np.array([3])
-    sdfg = prog.to_sdfg(E=E, F=F, G=G, N=1)
-    call1: FunctionCallRegion = sdfg.nodes()[1]
-    call2: FunctionCallRegion = sdfg.nodes()[2]
+    sdfg = prog.to_sdfg(E=E, F=F, G=G, N=1, simplify=False)
+    SimplifyPass(no_inline_function_call_regions=True, no_inline_named_regions=True).apply_pass(sdfg, {})
+    call1: FunctionCallRegion = sdfg.nodes()[0]
+    call2: FunctionCallRegion = sdfg.nodes()[1]
     assert call1.arguments == {'A': 'E', 'B': 'F', 'C': 'G'}
     assert call2.arguments == {'A': 'G', 'B': 'E', 'C': 'E'}
 
@@ -44,10 +45,10 @@ def func(A: dace.float64[N], B: dace.float64[N], C: dace.float64[N]):
     def prog():
         func(A=np.array([1]), B=np.array([2]), C=np.array([3]))
         func(A=np.array([3]), B=np.array([1]), C=np.array([1]))
-    prog.use_experimental_cfg_blocks = True
-    sdfg = prog.to_sdfg(N=1)
-    call1: FunctionCallRegion = sdfg.nodes()[1]
-    call2: FunctionCallRegion = sdfg.nodes()[2]
+    sdfg = prog.to_sdfg(N=1, simplify=False)
+    SimplifyPass(no_inline_function_call_regions=True, no_inline_named_regions=True).apply_pass(sdfg, {})
+    call1: FunctionCallRegion = sdfg.nodes()[0]
+    call2: FunctionCallRegion = sdfg.nodes()[1]
     assert call1.arguments == {'A': '__tmp0', 'B': '__tmp1', 'C': '__tmp2'}
     assert call2.arguments == {'A': '__tmp4', 'B': '__tmp5', 'C': '__tmp6'}
 
diff --git a/tests/python_frontend/named_region_test.py b/tests/python_frontend/named_region_test.py
index f9be206bca..593fde5c0f 100644
--- a/tests/python_frontend/named_region_test.py
+++ b/tests/python_frontend/named_region_test.py
@@ -3,6 +3,7 @@
 import numpy as np
 import dace
 from dace.sdfg.state import NamedRegion
+from dace.transformation.passes.simplify import SimplifyPass
 
 
 def test_named_region_no_name():
@@ -11,21 +12,21 @@ def func(A: dace.float64[1]):
         with dace.named:
             A[0] = 20
         return A
-    func.use_experimental_cfg_blocks = True
-    sdfg = func.to_sdfg()
-    named_region = sdfg.reset_cfg_list()[1]
+    sdfg = func.to_sdfg(simplify=False)
+    SimplifyPass(no_inline_function_call_regions=True, no_inline_named_regions=True).apply_pass(sdfg, {})
+    named_region = sdfg.nodes()[0]
     assert isinstance(named_region, NamedRegion)
     A = np.zeros(shape=(1,))
-    assert func(A) == 20
+    assert sdfg(A) == 20
 
 
 def test_named_region_with_name():
     @dace.program
     def func():
         with dace.named("my named region"):
             pass
-    func.use_experimental_cfg_blocks = True
-    sdfg = func.to_sdfg()
-    named_region: NamedRegion = sdfg.reset_cfg_list()[1]
+    sdfg = func.to_sdfg(simplify=False)
+    SimplifyPass(no_inline_function_call_regions=True, no_inline_named_regions=True).apply_pass(sdfg, {})
+    named_region: NamedRegion = sdfg.nodes()[0]
     assert named_region.label == "my named region"
 
@@ -35,13 +36,13 @@ def func():
         with dace.named("outer region"):
             with dace.named("middle region"):
                 with dace.named("inner region"):
                     pass
-    func.use_experimental_cfg_blocks = True
-    sdfg = func.to_sdfg()
-    outer: NamedRegion = sdfg.nodes()[1]
+    sdfg = func.to_sdfg(simplify=False)
+    SimplifyPass(no_inline_function_call_regions=True, no_inline_named_regions=True).apply_pass(sdfg, {})
+    outer: NamedRegion = sdfg.nodes()[0]
     assert outer.label == "outer region"
-    middle: NamedRegion = outer.nodes()[1]
+    middle: NamedRegion = outer.nodes()[0]
     assert middle.label == "middle region"
-    inner: NamedRegion = middle.nodes()[1]
+    inner: NamedRegion = middle.nodes()[0]
     assert inner.label == "inner region"
 
 if __name__ == "__main__":
diff --git a/tests/schedule_tree/nesting_test.py b/tests/schedule_tree/nesting_test.py
index 161f15d6c1..8361ecb149 100644
--- a/tests/schedule_tree/nesting_test.py
+++ b/tests/schedule_tree/nesting_test.py
@@ -5,6 +5,7 @@
 import dace
 from dace.sdfg.analysis.schedule_tree import treenodes as tn
 from dace.sdfg.analysis.schedule_tree.sdfg_to_tree import as_schedule_tree
+from dace.sdfg.utils import inline_control_flow_regions
 from dace.transformation.dataflow import RemoveSliceView
 
 import pytest
@@ -63,7 +64,8 @@ def tester(A: dace.float64[N, N]):
 
     if simplified:
         assert [type(n)
-                for n in stree.preorder_traversal()][1:] == [tn.MapScope, tn.MapScope, tn.ForScope, tn.TaskletNode]
+                for n in stree.preorder_traversal()][1:] == [tn.MapScope, tn.MapScope, tn.GeneralLoopScope,
+                                                             tn.TaskletNode]
 
     tasklet: tn.TaskletNode = list(stree.preorder_traversal())[-1]
@@ -127,6 +129,7 @@ def tester(a: dace.float64[40], b: dace.float64[40]):
         nester(b[1:21], a[10:30])
 
     sdfg = tester.to_sdfg(simplify=False)
+    inline_control_flow_regions(sdfg)
     sdfg.apply_transformations_repeated(RemoveSliceView)
     stree = as_schedule_tree(sdfg)
@@ -150,6 +153,7 @@ def tester(a: dace.float64[40]):
         nester(a[1:21], a[10:30])
 
     sdfg = tester.to_sdfg(simplify=False)
+    inline_control_flow_regions(sdfg)
     sdfg.apply_transformations_repeated(RemoveSliceView)
     stree = as_schedule_tree(sdfg)
@@ -176,6 +180,7 @@ def tester(a: dace.float64[N, N]):
         nester1(a[:, 1])
 
     sdfg = tester.to_sdfg(simplify=simplify)
+    inline_control_flow_regions(sdfg)
     stree = as_schedule_tree(sdfg)
 
     # Simplifying yields a different SDFG due to views, so testing is slightly different
diff --git a/tests/schedule_tree/schedule_test.py b/tests/schedule_tree/schedule_test.py
index 1bf2962cb3..542ff425dc 100644
--- a/tests/schedule_tree/schedule_test.py
+++ b/tests/schedule_tree/schedule_test.py
@@ -5,6 +5,8 @@
 from dace.sdfg.analysis.schedule_tree.sdfg_to_tree import as_schedule_tree
 import numpy as np
 
+from dace.sdfg.utils import inline_control_flow_regions
+
 
 def test_for_in_map_in_for():