Add loop regions to the frontend's capabilities (#1475)
This PR lets the Python and Fortran frontends (optionally) generate
`LoopRegion`s for DaCe programs. This forms the third core element of
the [plan to make loops first class citizens of
SDFGs](https://github.com/orgs/spcl/projects/10).

This PR is fully backwards compatible. `LoopRegion`s are always
generated from new Python DaCe programs, and the legacy way of
constructing while / for loops has been removed to reduce complexity.
To provide backwards compatibility, these `LoopRegion`s are by default
inlined into a traditional single-level state machine loop as soon as
program parsing completes, before simplification and / or validation.
However, an optional boolean parameter `use_experimental_cfg_blocks` can
be set to `True` when declaring a DaCe program in Python to skip this
inlining step and keep the `LoopRegion`s.

Example use:
```Python
import dace
import numpy

N = dace.symbol('N')

@dace.program(use_experimental_cfg_blocks=True)
def mat_mult(A: dace.float64[N, N], B: dace.float64[N, N]):
    return A @ B

# OR:
mat_mult.use_experimental_cfg_blocks = True
sdfg = mat_mult.to_sdfg()
```
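
With the flag enabled, explicit loops in the program are kept as `LoopRegion` nodes in the SDFG's control flow graph instead of being lowered to the legacy guard/body/exit state machine. A minimal sketch of how to observe this, continuing the snippet above and assuming `LoopRegion` is importable from `dace.sdfg.state` (where #1407 introduced it):
```Python
from dace.sdfg.state import LoopRegion  # assumption: LoopRegion's import path

@dace.program(use_experimental_cfg_blocks=True)
def add_one(A: dace.float64[N]):
    for i in range(N):
        A[i] = A[i] + 1

sdfg = add_one.to_sdfg()
# With the flag set, the explicit loop appears as a LoopRegion node among the
# SDFG's top-level control flow blocks rather than as a set of plain states.
loops = [blk for blk in sdfg.nodes() if isinstance(blk, LoopRegion)]
print(f'{len(loops)} top-level loop region(s)')
```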

The Fortran frontend similarly only utilizes `LoopRegion`s if an
additional parameter `use_experimental_cfg_blocks` is passed to the
parser together with the program.
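
A hedged sketch of the Fortran side, assuming the string-based parser entry point `create_sdfg_from_string` and that the new flag is exposed there as a keyword argument (both are assumptions, not verified against the final API):
```Python
from dace.frontend.fortran import fortran_parser  # assumed entry point module

fortran_code = """
PROGRAM scale
  IMPLICIT NONE
  REAL :: a(16)
  INTEGER :: i
  DO i = 1, 16
    a(i) = a(i) * 2.0
  END DO
END PROGRAM scale
"""

# Assumption: the flag is forwarded as a keyword argument of the parser entry point.
sdfg = fortran_parser.create_sdfg_from_string(fortran_code, 'scale',
                                              use_experimental_cfg_blocks=True)
```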

Many passes and transformations (including in simplify) cannot yet
handle the new, hierarchical SDFGs. To avoid breaking the pipeline and
to provide backwards compatibility, a new decorator
`@single_level_sdfg_only` has been added, which can be (and has been)
placed on any pass or transformation that is not compatible with the
new-style SDFGs. Passes annotated with this decorator are skipped in all
pipelines where they occur and instead emit a warning that they were
skipped due to compatibility issues.
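
For illustration, a simplified sketch of how such a skip-and-warn decorator can be structured; this is a hypothetical stand-in, not the actual implementation, which lives in DaCe's transformation framework and hooks into the pass/transformation machinery:
```Python
import warnings
from functools import wraps


def single_level_sdfg_only_sketch(apply_method):
    """Hypothetical stand-in: skip a pass on SDFGs that use experimental blocks."""

    @wraps(apply_method)
    def wrapper(self, sdfg, *args, **kwargs):
        # Assumption: the SDFG carries the `using_experimental_blocks` property
        # added in this PR (see the dace/sdfg/sdfg.py hunk below).
        if getattr(sdfg, 'using_experimental_blocks', False):
            warnings.warn(f'Skipping {type(self).__name__}: not yet compatible with '
                          'experimental control flow blocks')
            return None
        return apply_method(self, sdfg, *args, **kwargs)

    return wrapper
```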

For more information on `LoopRegion`s, please refer to the [PR that
introduced them](#1407).

**Important Note about disabled tests:**
Certain Python frontend loop tests have been disabled. Specifically,
this concerns tests where either the loop structure (using
continue/break) or other conditional statements cause the generation of
control flow that looks irregular before the simplification pass is
applied. The reason is that the frontend generates branches with one
branch condition set to the constant `False` when translating continue,
break, and return statements, as well as while/for-else clauses. These
branches are trivially removed during simplification, but when
simplification is not run (as is the case in part of our CI), the
resulting irregular control flow is handled poorly by our codegen-time
control flow detection. This error has so far gone unnoticed in these
tests by sheer luck, but is now exposed by the slightly different state
machine generated by control flow region and loop inlining.
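
For illustration, a hypothetical loop of the shape these tests exercise (reusing the imports and symbol from the example above; this is not one of the actual disabled tests): the `break` is lowered through a branch whose untaken arm carries a constant `False` condition, which only disappears once simplification runs.
```Python
@dace.program
def zero_until_negative(A: dace.float64[N]):
    for i in range(N):
        if A[i] < 0:
            break  # lowered via a branch whose other arm has condition `False`
        A[i] = 0.0
```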

The goal is for a subsequent PR to completely adapt codegen to make use
of the control flow region constructs, thereby fixing this issue and
re-enabling the tests. For more information about the issue, see #635
and #1586.

Linked to:
https://github.com/orgs/spcl/projects/10/views/4?pane=issue&itemId=42047238
and
https://github.com/orgs/spcl/projects/10/views/4?pane=issue&itemId=42151188
phschaad authored Jun 26, 2024
1 parent 6a490ec commit ecae262
Showing 81 changed files with 2,315 additions and 1,210 deletions.
1 change: 1 addition & 0 deletions dace/codegen/codegen.py
@@ -189,6 +189,7 @@ def generate_code(sdfg, validate=True) -> List[CodeObject]:
# Convert any loop constructs with hierarchical loop regions into simple 1-level state machine loops.
# TODO (later): Adapt codegen to deal with hierarchical CFGs instead.
sdutils.inline_loop_blocks(sdfg)
sdutils.inline_control_flow_regions(sdfg)

# Before generating the code, run type inference on the SDFG connectors
infer_types.infer_connector_types(sdfg)
3 changes: 2 additions & 1 deletion dace/codegen/control_flow.py
@@ -390,7 +390,8 @@ def as_cpp(self, codegen, symbols) -> str:

update = ''
if self.update is not None:
update = f'{self.itervar} = {self.update}'
cppupdate = unparse_interstate_edge(self.update, sdfg, codegen=codegen)
update = f'{self.itervar} = {cppupdate}'

expr = f'{preinit}\nfor ({init}; {cond}; {update}) {{\n'
expr += _clean_loop_body(self.body.as_cpp(codegen, symbols))
262 changes: 151 additions & 111 deletions dace/frontend/fortran/fortran_parser.py

Large diffs are not rendered by default.

14 changes: 13 additions & 1 deletion dace/frontend/python/nested_call.py
@@ -1,6 +1,12 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
import dace
from dace.sdfg import SDFG, SDFGState
from typing import Optional, TYPE_CHECKING

if TYPE_CHECKING:
from dace.frontend.python.newast import ProgramVisitor
else:
ProgramVisitor = 'dace.frontend.python.newast.ProgramVisitor'


class NestedCall():
@@ -18,7 +24,13 @@ def _cos_then_max(pv, sdfg, state, a: str):
# return a tuple of the nest object and the result
return nest, result
"""
def __init__(self, pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState):
state: SDFGState
last_state: Optional[SDFGState]
pv: ProgramVisitor
sdfg: SDFG
count: int

def __init__(self, pv: ProgramVisitor, sdfg: SDFG, state: SDFGState):
self.pv = pv
self.sdfg = sdfg
self.state = state
433 changes: 240 additions & 193 deletions dace/frontend/python/newast.py

Large diffs are not rendered by default.

14 changes: 11 additions & 3 deletions dace/frontend/python/parser.py
@@ -13,7 +13,7 @@
from dace import data, dtypes, hooks, symbolic
from dace.config import Config
from dace.frontend.python import (newast, common as pycommon, cached_program, preprocessing)
from dace.sdfg import SDFG
from dace.sdfg import SDFG, utils as sdutils
from dace.data import create_datadescriptor, Data

try:
@@ -152,7 +152,8 @@ def __init__(self,
regenerate_code: bool = True,
recompile: bool = True,
distributed_compilation: bool = False,
method: bool = False):
method: bool = False,
use_experimental_cfg_blocks: bool = False):
from dace.codegen import compiled_sdfg # Avoid import loops

self.f = f
@@ -172,6 +173,7 @@ def __init__(self,
self.recreate_sdfg = recreate_sdfg
self.regenerate_code = regenerate_code
self.recompile = recompile
self.use_experimental_cfg_blocks = use_experimental_cfg_blocks
self.distributed_compilation = distributed_compilation

self.global_vars = _get_locals_and_globals(f)
@@ -491,6 +493,11 @@ def _parse(self, args, kwargs, simplify=None, save=False, validate=False) -> SDF
# Obtain DaCe program as SDFG
sdfg, cached = self._generate_pdp(args, kwargs, simplify=simplify)

if not self.use_experimental_cfg_blocks:
sdutils.inline_loop_blocks(sdfg)
sdutils.inline_control_flow_regions(sdfg)
sdfg.using_experimental_blocks = self.use_experimental_cfg_blocks

# Apply simplification pass automatically
if not cached and (simplify == True or
(simplify is None and Config.get_bool('optimizer', 'automatic_simplification'))):
@@ -801,7 +808,8 @@ def get_program_hash(self, *args, **kwargs) -> cached_program.ProgramCacheKey:
_, key = self._load_sdfg(None, *args, **kwargs)
return key

def _generate_pdp(self, args: Tuple[Any], kwargs: Dict[str, Any], simplify: Optional[bool] = None) -> SDFG:
def _generate_pdp(self, args: Tuple[Any], kwargs: Dict[str, Any],
simplify: Optional[bool] = None) -> Tuple[SDFG, bool]:
""" Generates the parsed AST representation of a DaCe program.
:param args: The given arguments to the program.
3 changes: 3 additions & 0 deletions dace/frontend/python/preprocessing.py
@@ -935,6 +935,9 @@ def _add_exits(self, until_loop_end: bool, only_one: bool = False) -> List[ast.A
for stmt in reversed(self.with_statements):
if until_loop_end and not isinstance(stmt, (ast.With, ast.AsyncWith)):
break
elif not until_loop_end and isinstance(stmt, (ast.For, ast.While)):
break

for mgrname, mgr in reversed(self.context_managers[stmt]):
# Call __exit__ (without exception management all three arguments are set to None)
exit_call = ast.copy_location(ast.parse(f'{mgrname}.__exit__(None, None, None)').body[0], stmt)
19 changes: 11 additions & 8 deletions dace/frontend/python/replacements.py
@@ -8,7 +8,7 @@
import warnings
from functools import reduce
from numbers import Number, Integral
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union, TYPE_CHECKING

import dace
from dace.codegen.tools import type_inference
@@ -28,7 +28,10 @@

Size = Union[int, dace.symbolic.symbol]
Shape = Sequence[Size]
ProgramVisitor = 'dace.frontend.python.newast.ProgramVisitor'
if TYPE_CHECKING:
from dace.frontend.python.newast import ProgramVisitor
else:
ProgramVisitor = 'dace.frontend.python.newast.ProgramVisitor'


def normalize_axes(axes: Tuple[int], max_dim: int) -> List[int]:
@@ -971,8 +974,8 @@ def _pymax(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, a: Union[str, Numbe
for i, b in enumerate(args):
if i > 0:
pv._add_state('__min2_%d' % i)
pv.last_state.set_default_lineinfo(pv.current_lineinfo)
current_state = pv.last_state
pv.last_block.set_default_lineinfo(pv.current_lineinfo)
current_state = pv.last_block
left_arg = _minmax2(pv, sdfg, current_state, left_arg, b, ismin=False)
return left_arg

@@ -986,8 +989,8 @@ def _pymin(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, a: Union[str, Numbe
for i, b in enumerate(args):
if i > 0:
pv._add_state('__min2_%d' % i)
pv.last_state.set_default_lineinfo(pv.current_lineinfo)
current_state = pv.last_state
pv.last_block.set_default_lineinfo(pv.current_lineinfo)
current_state = pv.last_block
left_arg = _minmax2(pv, sdfg, current_state, left_arg, b)
return left_arg

@@ -3355,7 +3358,7 @@ def _create_subgraph(visitor: ProgramVisitor,
cond_state.add_nedge(r, w, dace.Memlet("{}[0]".format(r)))
true_state = sdfg.add_state(label=cond_state.label + '_true')
state = true_state
visitor.last_state = state
visitor.last_block = state
cond = name
cond_else = 'not ({})'.format(cond)
sdfg.add_edge(cond_state, true_state, dace.InterstateEdge(cond))
@@ -3374,7 +3377,7 @@ def _create_subgraph(visitor: ProgramVisitor,
dace.Memlet.from_array(arg, sdfg.arrays[arg]))
if has_where and isinstance(where, str) and where in sdfg.arrays.keys():
visitor._add_state(label=cond_state.label + '_true')
sdfg.add_edge(cond_state, visitor.last_state, dace.InterstateEdge(cond_else))
sdfg.add_edge(cond_state, visitor.last_block, dace.InterstateEdge(cond_else))
else:
# Map needed
if has_where:
4 changes: 2 additions & 2 deletions dace/sdfg/infer_types.py
@@ -61,7 +61,7 @@ def infer_connector_types(sdfg: SDFG):
:param sdfg: The SDFG to infer.
"""
# Loop over states, and in a topological sort over each state's nodes
for state in sdfg.nodes():
for state in sdfg.states():
for node in dfs_topological_sort(state):
# Try to infer input connector type from node type or previous edges
for e in state.in_edges(node):
@@ -168,7 +168,7 @@ def set_default_schedule_and_storage_types(scope: Union[SDFG, SDFGState, nodes.E

if isinstance(scope, SDFG):
# Set device for default top-level schedules and storages
for state in scope.nodes():
for state in scope.states():
set_default_schedule_and_storage_types(state,
parent_schedules,
use_parent_schedule=use_parent_schedule,
22 changes: 20 additions & 2 deletions dace/sdfg/sdfg.py
@@ -30,7 +30,7 @@
from dace.frontend.python import astutils, wrappers
from dace.sdfg import nodes as nd
from dace.sdfg.graph import OrderedDiGraph, Edge, SubgraphView
from dace.sdfg.state import SDFGState, ControlFlowRegion
from dace.sdfg.state import ControlFlowBlock, SDFGState, ControlFlowRegion
from dace.sdfg.propagation import propagate_memlets_sdfg
from dace.distr_types import ProcessGrid, SubArray, RedistrArray
from dace.dtypes import validate_name
@@ -183,7 +183,7 @@ class InterstateEdge(object):
desc="Assignments to perform upon transition (e.g., 'x=x+1; y = 0')")
condition = CodeProperty(desc="Transition condition", default=CodeBlock("1"))

def __init__(self, condition: CodeBlock = None, assignments=None):
def __init__(self, condition: Optional[Union[CodeBlock, str, ast.AST, list]] = None, assignments=None):
if condition is None:
condition = CodeBlock("1")

@@ -452,6 +452,9 @@ class SDFG(ControlFlowRegion):
desc='Mapping between callback name and its original callback '
'(for when the same callback is used with a different signature)')

using_experimental_blocks = Property(dtype=bool, default=False,
desc="Whether the SDFG contains experimental control flow blocks")

def __init__(self,
name: str,
constants: Dict[str, Tuple[dt.Data, Any]] = None,
@@ -509,6 +512,8 @@ def __init__(self,
self._orig_name = name
self._num = 0

self._sdfg = self

def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
@@ -2220,6 +2225,7 @@ def compile(self, output_file=None, validate=True) -> 'CompiledSDFG':
# Convert any loop constructs with hierarchical loop regions into simple 1-level state machine loops.
# TODO (later): Adapt codegen to deal with hierarchical CFGs instead.
sdutils.inline_loop_blocks(sdfg)
sdutils.inline_control_flow_regions(sdfg)

# Rename SDFG to avoid runtime issues with clashing names
index = 0
@@ -2680,3 +2686,15 @@ def make_array_memlet(self, array: str):
:return: a Memlet that fully transfers array
"""
return dace.Memlet.from_array(array, self.data(array))

def recheck_using_experimental_blocks(self) -> bool:
found_experimental_block = False
for node, graph in self.root_sdfg.all_nodes_recursive():
if isinstance(graph, ControlFlowRegion) and not isinstance(graph, SDFG):
found_experimental_block = True
break
if isinstance(node, ControlFlowBlock) and not isinstance(node, SDFGState):
found_experimental_block = True
break
self.root_sdfg.using_experimental_blocks = found_experimental_block
return found_experimental_block