Skip to content

Commit

Permalink
More fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
phschaad committed Oct 18, 2024
1 parent 4084dfe commit d010620
Show file tree
Hide file tree
Showing 18 changed files with 151 additions and 82 deletions.
4 changes: 2 additions & 2 deletions dace/codegen/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@
import sympy as sp
from dace import dtypes
from dace.sdfg.analysis import cfg as cfg_analysis
from dace.sdfg.state import (BreakBlock, ConditionalBlock, ContinueBlock, ControlFlowBlock, ControlFlowRegion, LoopRegion,
ReturnBlock, SDFGState)
from dace.sdfg.state import (BreakBlock, ConditionalBlock, ContinueBlock, ControlFlowBlock, ControlFlowRegion,
LoopRegion, ReturnBlock, SDFGState)
from dace.sdfg.sdfg import SDFG, InterstateEdge
from dace.sdfg.graph import Edge
from dace.properties import CodeBlock
Expand Down
2 changes: 1 addition & 1 deletion dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,7 +652,7 @@ def as_schedule_tree(sdfg: SDFG, in_place: bool = False, toplevel: bool = True)
#############################

# Create initial tree from CFG
cfg: cf.ControlFlow = cf.structured_control_flow_tree(sdfg, lambda _: '')
cfg: cf.ControlFlow = cf.structured_control_flow_tree_with_regions(sdfg, lambda _: '')

# Traverse said tree (also into states) to create the schedule tree
def totree(node: cf.ControlFlow, parent: cf.GeneralBlock = None) -> List[tn.ScheduleTreeNode]:
Expand Down
12 changes: 12 additions & 0 deletions dace/sdfg/replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from dace import dtypes, properties, symbolic
from dace.codegen import cppunparse
from dace.frontend.python.astutils import ASTFindReplace
from dace.sdfg.state import ConditionalBlock, LoopRegion

if TYPE_CHECKING:
from dace.sdfg.state import StateSubgraphView
Expand Down Expand Up @@ -200,3 +201,14 @@ def replace_datadesc_names(sdfg: 'dace.SDFG', repl: Dict[str, str]):
for edge in block.edges():
if edge.data.data in repl:
edge.data.data = repl[edge.data.data]

# Replace in loop or branch conditions:
if isinstance(cf, LoopRegion):
replace_in_codeblock(cf.loop_condition, repl)
if cf.update_statement:
replace_in_codeblock(cf.update_statement, repl)
if cf.init_statement:
replace_in_codeblock(cf.init_statement, repl)
elif isinstance(cf, ConditionalBlock):
for c, _ in cf.branches:
replace_in_codeblock(c, repl)
4 changes: 3 additions & 1 deletion dace/sdfg/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

import dace
from dace.frontend.python import astutils
from dace.sdfg.replace import replace_in_codeblock
import dace.serialize
from dace import data as dt
from dace import dtypes
Expand Down Expand Up @@ -3319,6 +3318,9 @@ def replace_dict(self,
symrepl: Optional[Dict[symbolic.SymbolicType, symbolic.SymbolicType]] = None,
replace_in_graph: bool = True,
replace_keys: bool = True):
# Avoid circular imports
from dace.sdfg.replace import replace_in_codeblock

if replace_keys:
from dace.sdfg.replace import replace_properties_dict
replace_properties_dict(self, repl, symrepl)
Expand Down
98 changes: 59 additions & 39 deletions dace/transformation/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
import itertools
from networkx import MultiDiGraph

from dace.sdfg.state import ControlFlowBlock, ControlFlowRegion
from dace.sdfg.state import AbstractControlFlowRegion, ConditionalBlock, ControlFlowBlock, ControlFlowRegion, LoopRegion
from dace.subsets import Range, Subset, union
import dace.subsets as subsets
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Set, Union
from typing import Dict, Iterable, List, Optional, Tuple, Set, Union

from dace import data, dtypes, symbolic
from dace.codegen import control_flow as cf
Expand All @@ -30,10 +30,13 @@ def nest_sdfg_subgraph(sdfg: SDFG, subgraph: SubgraphView, start: Optional[SDFGS
"""

# Nest states
states = subgraph.nodes()
blocks: List[ControlFlowBlock] = subgraph.nodes()
return_state = None
if len(states) > 1:
if len(blocks) > 1:
# Avoid cyclic imports
from dace.transformation.passes.analysis import loop_analysis

graph: ControlFlowRegion = blocks[0].parent_graph
if start is not None:
source_node = start
else:
Expand All @@ -48,6 +51,22 @@ def nest_sdfg_subgraph(sdfg: SDFG, subgraph: SubgraphView, start: Optional[SDFGS
raise NotImplementedError
sink_node = sink_nodes[0]

all_blocks: List[ControlFlowBlock] = []
is_edges: List[Edge[InterstateEdge]] = []
for b in blocks:
if isinstance(b, AbstractControlFlowRegion):
for nb in b.all_control_flow_blocks():
all_blocks.append(nb)
for e in b.all_interstate_edges():
is_edges.append(e)
else:
all_blocks.append(b)
states: List[SDFGState] = [b for b in all_blocks if isinstance(b, SDFGState)]
for src in blocks:
for dst in blocks:
for edge in graph.edges_between(src, dst):
is_edges.append(edge)

# Find read/write sets
read_set, write_set = set(), set()
for state in states:
Expand All @@ -67,12 +86,10 @@ def nest_sdfg_subgraph(sdfg: SDFG, subgraph: SubgraphView, start: Optional[SDFGS
if e.data.data and e.data.data in sdfg.arrays:
write_set.add(e.data.data)
# Add data from edges
for src in states:
for dst in states:
for edge in sdfg.edges_between(src, dst):
for s in edge.data.free_symbols:
if s in sdfg.arrays:
read_set.add(s)
for edge in is_edges:
for s in edge.data.free_symbols:
if s in sdfg.arrays:
read_set.add(s)

# Find NestedSDFG's unique data
rw_set = read_set | write_set
Expand All @@ -82,7 +99,7 @@ def nest_sdfg_subgraph(sdfg: SDFG, subgraph: SubgraphView, start: Optional[SDFGS
continue
found = False
for state in sdfg.states():
if state in states:
if state in blocks:
continue
for node in state.nodes():
if (isinstance(node, nodes.AccessNode) and node.data == name):
Expand All @@ -98,7 +115,7 @@ def nest_sdfg_subgraph(sdfg: SDFG, subgraph: SubgraphView, start: Optional[SDFGS
# Find defined subgraph symbols
defined_symbols = set()
strictly_defined_symbols = set()
for e in subgraph.edges():
for e in is_edges:
defined_symbols.update(set(e.data.assignments.keys()))
for k, v in e.data.assignments.items():
try:
Expand All @@ -107,22 +124,30 @@ def nest_sdfg_subgraph(sdfg: SDFG, subgraph: SubgraphView, start: Optional[SDFGS
except AttributeError:
# `symbolic.pystr_to_symbolic` may return bool, which doesn't have attribute `args`
pass

return_state = new_state = sdfg.add_state('nested_sdfg_parent')
for b in all_blocks:
if isinstance(b, LoopRegion) and b.loop_variable is not None and b.loop_variable != '':
defined_symbols.update(b.loop_variable)
if b.loop_variable not in sdfg.symbols:
if b.init_statement:
init_assignment = loop_analysis.get_init_assignment(b)
if b.loop_variable not in {str(s) for s in symbolic.pystr_to_symbolic(init_assignment).args}:
strictly_defined_symbols.add(b.loop_variable)
else:
strictly_defined_symbols.add(b.loop_variable)

return_state = new_state = graph.add_state('nested_sdfg_parent')
nsdfg = SDFG("nested_sdfg", constants=sdfg.constants_prop, parent=new_state)
nsdfg.add_node(source_node, is_start_state=True)
nsdfg.add_nodes_from([s for s in states if s is not source_node])
for s in states:
s.parent = nsdfg
nsdfg.add_nodes_from([s for s in blocks if s is not source_node])
for e in subgraph.edges():
nsdfg.add_edge(e.src, e.dst, e.data)

for e in sdfg.in_edges(source_node):
sdfg.add_edge(e.src, new_state, e.data)
for e in sdfg.out_edges(sink_node):
sdfg.add_edge(new_state, e.dst, e.data)
for e in graph.in_edges(source_node):
graph.add_edge(e.src, new_state, e.data)
for e in graph.out_edges(sink_node):
graph.add_edge(new_state, e.dst, e.data)

sdfg.remove_nodes_from(states)
graph.remove_nodes_from(blocks)

# Add NestedSDFG arrays
for name in read_set | write_set:
Expand Down Expand Up @@ -177,15 +202,15 @@ def nest_sdfg_subgraph(sdfg: SDFG, subgraph: SubgraphView, start: Optional[SDFGS

# Part (2)
if out_state is not None:
extra_state = sdfg.add_state('symbolic_output')
for e in sdfg.out_edges(new_state):
sdfg.add_edge(extra_state, e.dst, e.data)
sdfg.remove_edge(e)
sdfg.add_edge(new_state, extra_state, InterstateEdge(assignments=out_mapping))
extra_state = graph.add_state('symbolic_output')
for e in graph.out_edges(new_state):
graph.add_edge(extra_state, e.dst, e.data)
graph.remove_edge(e)
graph.add_edge(new_state, extra_state, InterstateEdge(assignments=out_mapping))
new_state = extra_state

else:
return_state = states[0]
return_state = blocks[0]

return return_state

Expand Down Expand Up @@ -244,7 +269,8 @@ def _copy_state(sdfg: SDFG,
return state_copy


def find_sdfg_control_flow(cfg: ControlFlowRegion) -> Dict[ControlFlowBlock, Set[ControlFlowBlock]]:
def find_sdfg_control_flow(cfg: ControlFlowRegion) -> Dict[ControlFlowBlock,
Tuple[Set[ControlFlowBlock], ControlFlowBlock]]:
"""
Partitions a CFG to subgraphs that can be nested independently of each other. The method does not nest the
subgraphs but alters the graph; (1) interstate edges are split, (2) scope source/sink nodes that belong to multiple
Expand Down Expand Up @@ -352,16 +378,10 @@ def nest_sdfg_control_flow(sdfg: SDFG, components=None):
:param sdfg: The SDFG to be partitioned.
:param components: An existing partition of the SDFG.
"""

components = components or find_sdfg_control_flow(sdfg)

num_components = len(components)

if num_components < 2:
return

for i, (start, (component, _)) in enumerate(components.items()):
nest_sdfg_subgraph(sdfg, graph.SubgraphView(sdfg, component), start)
regions = list(sdfg.all_control_flow_regions())
for region in regions:
nest_sdfg_subgraph(region.sdfg, SubgraphView(region.sdfg, [region]), region)
sdfg.reset_cfg_list()


def nest_state_subgraph(sdfg: SDFG,
Expand Down
6 changes: 5 additions & 1 deletion dace/transformation/pass_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,10 @@ class ControlFlowRegionPass(Pass):

CATEGORY: str = 'Helper'

apply_to_conditionals = properties.Property(dtype=bool, default=False,
desc='Whether or not to apply to conditional blocks. If false, do ' +
'not apply to conditional blocks, but only their children.')

def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[Dict[int, Optional[Any]]]:
"""
Applies the pass to control flow regions of the given SDFG by calling ``apply`` on each region.
Expand All @@ -287,7 +291,7 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[D
"""
result = {}
for region in sdfg.all_control_flow_regions(recursive=True, parent_first=False):
if isinstance(region, ConditionalBlock):
if isinstance(region, ConditionalBlock) and not self.apply_to_conditionals:
continue
retval = self.apply(region, pipeline_results)
if retval is not None:
Expand Down
17 changes: 12 additions & 5 deletions dace/transformation/passes/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,12 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[ControlFlowBlock, Set[
# The implementation below is faster
# tc: nx.DiGraph = nx.transitive_closure(sdfg.nx)
for n, v in reachable_nodes(cfg.nx):
single_level_reachable[cfg.cfg_id][n] = set(v)
reach = set()
for nd in v:
reach.add(nd)
if isinstance(nd, AbstractControlFlowRegion):
reach.update(nd.all_control_flow_blocks())
single_level_reachable[cfg.cfg_id][n] = reach
if isinstance(cfg, LoopRegion):
single_level_reachable[cfg.cfg_id][n].update(cfg.nodes())

Expand All @@ -166,7 +171,7 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[ControlFlowBlock, Set[
result: Dict[ControlFlowBlock, Set[ControlFlowBlock]] = defaultdict(set)
for block in cfg.nodes():
for reached in single_level_reachable[block.parent_graph.cfg_id][block]:
if isinstance(reached, ControlFlowRegion):
if isinstance(reached, AbstractControlFlowRegion):
result[block].update(reached.all_control_flow_blocks())
result[block].add(reached)
if block.parent_graph is not sdfg:
Expand Down Expand Up @@ -516,7 +521,7 @@ def should_reapply(self, modified: ppl.Modifies) -> bool:
return modified & ppl.Modifies.States

def depends_on(self):
return {AccessSets, FindAccessNodes, StateReachability}
return {AccessSets, FindAccessNodes, ControlFlowBlockReachability}

def _find_dominating_write(self,
desc: str,
Expand Down Expand Up @@ -615,7 +620,9 @@ def apply_pass(self, top_sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[i
access_nodes: Dict[str, Dict[SDFGState, Tuple[Set[nd.AccessNode], Set[nd.AccessNode]]]] = pipeline_results[
FindAccessNodes.__name__][sdfg.cfg_id]

state_reach: Dict[SDFGState, Set[SDFGState]] = pipeline_results[StateReachability.__name__][sdfg.cfg_id]
block_reach: Dict[ControlFlowBlock, Set[ControlFlowBlock]] = pipeline_results[
ControlFlowBlockReachability.__name__
]

anames = sdfg.arrays.keys()
for desc in sdfg.arrays:
Expand Down Expand Up @@ -657,7 +664,7 @@ def apply_pass(self, top_sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[i
continue
write_state, write_node = write
dominators = all_doms_transitive[write_state]
reach = state_reach[write_state]
reach = block_reach[write_state.parent_graph.cfg_id][write_state]
for other_write, other_accesses in result[desc].items():
if other_write is not None and other_write[1] is write_node and other_write[0] is write_state:
continue
Expand Down
2 changes: 1 addition & 1 deletion dace/transformation/passes/dead_state_elimination.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def apply_pass(self, sdfg: SDFG, _) -> Optional[Set[Union[SDFGState, Edge[Inters
cfg.remove_node(node)
else:
result.add(node)
cfg.remove_node(block)
cfg.remove_node(node)

if not annotated:
return result or None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,4 +97,5 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Optional[Tuple[int, int]]:
lifted_branches += self._lift_conditionals(sdfg)
if lifted_branches == 0 and lifted_loops == 0:
return None
top_sdfg.reset_cfg_list()
return lifted_loops, lifted_branches
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ class PruneEmptyConditionalBranches(ppl.ControlFlowRegionPass):

CATEGORY: str = 'Simplification'

def __init__(self):
super().__init__()
self.apply_to_conditionals = True

def modifies(self) -> ppl.Modifies:
return ppl.Modifies.CFG

Expand Down Expand Up @@ -58,5 +62,8 @@ def apply(self, region: ControlFlowRegion, _) -> Optional[int]:
region.parent_graph.add_edge(replacement_node_before, replacement_node_after, InterstateEdge())
region.parent_graph.remove_node(region)

return removed_branches if removed_branches > 0 else None
if removed_branches > 0:
region.reset_cfg_list()
return removed_branches
return None

5 changes: 5 additions & 0 deletions dace/transformation/passes/simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ def __init__(self,
validate_all: bool = False,
skip: Optional[Set[str]] = None,
verbose: bool = False,
no_inline_function_call_regions: bool = False,
no_inline_named_regions: bool = False,
pass_options: Optional[Dict[str, Any]] = None):
if skip:
passes: List[ppl.Pass] = [p() for p in SIMPLIFY_PASSES if p.__name__ not in skip]
Expand All @@ -88,6 +90,9 @@ def __init__(self,
else:
self.verbose = verbose

self.no_inline_function_call_regions = no_inline_function_call_regions
self.no_inline_named_regions = no_inline_named_regions

pass_opts = {
'no_inline_function_call_regions': self.no_inline_function_call_regions,
'no_inline_named_regions': self.no_inline_named_regions,
Expand Down
8 changes: 5 additions & 3 deletions tests/fortran/array_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import dace.frontend.fortran.ast_transforms as ast_transforms
import dace.frontend.fortran.ast_utils as ast_utils
import dace.frontend.fortran.ast_internal_classes as ast_internal_classes
from dace.sdfg.state import LoopRegion


def test_fortran_frontend_array_access():
Expand Down Expand Up @@ -199,9 +200,10 @@ def test_fortran_frontend_memlet_in_map_test():
"""
sdfg = fortran_parser.create_sdfg_from_string(test_string, "memlet_range_test")
sdfg.simplify()
# Expect that start is begin of for loop -> only one out edge to guard defining iterator variable
assert len(sdfg.out_edges(sdfg.start_state)) == 1
iter_var = symbolic.symbol(list(sdfg.out_edges(sdfg.start_state)[0].data.assignments.keys())[0])
# Expect that the start block is a loop
loop = sdfg.nodes()[0]
assert isinstance(loop, LoopRegion)
iter_var = symbolic.pystr_to_symbolic(loop.loop_variable)

for state in sdfg.states():
if len(state.nodes()) > 1:
Expand Down
Loading

0 comments on commit d010620

Please sign in to comment.