Skip to content

Commit

Permalink
Modified SDFGState.unordered_arglist() (#1708)
Browse files Browse the repository at this point in the history
This PR fixes the way how arguments are detected in scopes.

Technically this only affects GPU code generation, but it is a side
effect of how the code is generated.
In GPU mode a `Map` is translated into one kernel, for this reason a
signature must be computed (this is the reason why CPU code generation
is not affected, no function call is produced).
To compute this signature the `unsorted_arglist()` function scans what
is needed.
However, this was implemented not correctly.
Assume that AccessNode for array `A` is outside the map and inside the
map a temporary scalar, `tmp_in` is defined and initialized to `tmp_in =
A[__i0, __i1]`, see also this image:


![argliost_situation](https://github.com/user-attachments/assets/fdf54dea-4ef5-49be-8ce2-33b78ce5962d)

If the `data` property of the Memlet that connects the MapEntry with the
AccessNode for `tmp_in` is referencing `A` then the (old) function would
find that `A` is needed inside, although there is no AccessNode for `A`
inside the map.
If however, this Memlet referrers `tmp_in` (which is not super standard,
but should be allowed), then the old version would not pick up.
This would then lead to a code generation error.

This PR modifies the function such that such cases are handled.
This is done by following all edges that are adjacent to the MapEntry
(from the inside) to where the are actually originate.
  • Loading branch information
philip-paul-mueller authored Oct 25, 2024
1 parent 057a680 commit 813a2f4
Show file tree
Hide file tree
Showing 2 changed files with 247 additions and 10 deletions.
60 changes: 50 additions & 10 deletions dace/sdfg/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -849,6 +849,8 @@ def unordered_arglist(self,
for node in self.nodes():
if isinstance(node, nd.AccessNode):
descs[node.data] = node.desc(sdfg)
# NOTE: In case of multiple nodes of the same data this will
# override previously found nodes.
descs_with_nodes[node.data] = node
if isinstance(node.desc(sdfg), dt.Scalar):
scalars_with_nodes.add(node.data)
Expand All @@ -865,19 +867,57 @@ def unordered_arglist(self,
else:
data_args[node.data] = desc

# Add data arguments from memlets, if do not appear in any of the nodes
# (i.e., originate externally)
# Add data arguments from memlets, if do not appear in any of the nodes (i.e., originate externally)
# TODO: Investigate is scanning the adjacent edges of the input and output connectors is better.
for edge in self.edges():
if edge.data.data is not None and edge.data.data not in descs:
desc = sdfg.arrays[edge.data.data]
if isinstance(desc, dt.Scalar):
# Ignore code->code edges.
if (isinstance(edge.src, nd.CodeNode) and isinstance(edge.dst, nd.CodeNode)):
continue
if edge.data.is_empty():
continue

elif edge.data.data not in descs:
# The edge reads data from the outside, and the Memlet is directly indicating what is read.
if (isinstance(edge.src, nd.CodeNode) and isinstance(edge.dst, nd.CodeNode)):
continue # Ignore code->code edges.
additional_descs = {edge.data.data: sdfg.arrays[edge.data.data]}

elif isinstance(edge.dst, (nd.AccessNode, nd.CodeNode)) and isinstance(edge.src, nd.EntryNode):
# Special case from the above; An AccessNode reads data from the Outside, but
# the Memlet references the data on the inside. Thus we have to follow the data
# to where it originates from.
# NOTE: We have to use a memlet path, because we have to go "against the flow"
# Furthermore, in a valid SDFG the data will only come from one source anyway.
top_source_edge = self.graph.memlet_path(edge)[0]
if not isinstance(top_source_edge.src, nd.AccessNode):
continue
additional_descs = (
{top_source_edge.src.data: top_source_edge.src.desc(sdfg)}
if top_source_edge.src.data not in descs
else {}
)

elif isinstance(edge.dst, nd.ExitNode) and isinstance(edge.src, (nd.AccessNode, nd.CodeNode)):
# Same case as above, but for outgoing Memlets.
# NOTE: We have to use a memlet tree here, because the data could potentially
# go to multiple sources. We have to do it this way, because if we would call
# `memlet_tree()` here, then we would just get the edge back.
additional_descs = {}
connector_to_look = "OUT_" + edge.dst_conn[3:]
for oedge in self.graph.out_edges_by_connector(edge.dst, connector_to_look):
if (
(not oedge.data.is_empty()) and (oedge.data.data not in descs)
and (oedge.data.data not in additional_descs)
):
additional_descs[oedge.data.data] = sdfg.arrays[oedge.data.data]

else:
# Case is ignored.
continue

scalar_args[edge.data.data] = desc
# Now processing the list of newly found data.
for aname, additional_desc in additional_descs.items():
if isinstance(additional_desc, dt.Scalar):
scalar_args[aname] = additional_desc
else:
data_args[edge.data.data] = desc
data_args[aname] = additional_desc

# Loop over locally-used data descriptors
for name, desc in descs.items():
Expand Down
197 changes: 197 additions & 0 deletions tests/codegen/argumet_signature_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
import dace
import copy

def test_argument_signature_test():
"""Tests if the argument signature is computed correctly.
The test is focused on if data dependencies are picked up if they are only
referenced indirectly. This effect is only directly visible for GPU.
The test also runs on GPU, but will only compile for GPU.
"""

def make_sdfg() -> dace.SDFG:
sdfg = dace.SDFG("Repr")
state = sdfg.add_state(is_start_block=True)
N = dace.symbol(sdfg.add_symbol("N", dace.int32))
for name in "BC":
sdfg.add_array(
name=name,
dtype=dace.float64,
shape=(N, N),
strides=(N, 1),
transient=False,
)

# `A` uses a stride that is not used by any of the other arrays.
# However, the stride is used if we want to index array `A`.
second_stride_A = dace.symbol(sdfg.add_symbol("second_stride_A", dace.int32))
sdfg.add_array(
name="A",
dtype=dace.float64,
shape=(N,),
strides=(second_stride_A,),
transient=False,

)

# Also array `D` uses a stride that is not used by any other array.
second_stride_D = dace.symbol(sdfg.add_symbol("second_stride_D", dace.int32))
sdfg.add_array(
name="D",
dtype=dace.float64,
shape=(N, N),
strides=(second_stride_D, 1),
transient=False,

)

# Simplest way to generate a mapped Tasklet, we will later modify it.
state.add_mapped_tasklet(
"computation",
map_ranges={"__i0": "0:N", "__i1": "0:N"},
inputs={
"__in0": dace.Memlet("A[__i1]"),
"__in1": dace.Memlet("B[__i0, __i1]"),
},
code="__out = __in0 + __in1",
outputs={"__out": dace.Memlet("C[__i0, __i1]")},
external_edges=True,
)

# Instead of going from the MapEntry to the Tasklet we will go through
# an temporary AccessNode that is only used inside the map scope.
# Thus there is no direct reference to `A` inside the map scope, that would
# need `second_stride_A`.
sdfg.add_scalar("tmp_in", transient=True, dtype=dace.float64)
tmp_in = state.add_access("tmp_in")
for e in state.edges():
if e.dst_conn == "__in0":
iedge = e
break
state.add_edge(
iedge.src,
iedge.src_conn,
tmp_in,
None,
# The important thing is that the Memlet, that connects the MapEntry with the
# AccessNode, does not refers to the memory outside (its source) but to the transient
# inside (its destination)
dace.Memlet(data="tmp_in", subset="0", other_subset="__i1"), # This does not work!
#dace.Memlet(data="A", subset="__i1", other_subset="0"), # This would work!
)
state.add_edge(
tmp_in,
None,
iedge.dst,
iedge.dst_conn,
dace.Memlet(f"{tmp_in.data}[0]"),
)
state.remove_edge(iedge)

# Here we are doing something similar as for `A`, but this time for the output.
# The output of the Tasklet is stored inside a temporary scalar.
# From that scalar we then go to `C`, here the Memlet on the inside is still
# referring to `C`, thus it is referenced directly.
# We also add a second output that goes to `D` , but the inner Memlet does
# not refer to `D` but to the temporary. Thus there is no direct mention of
# `D` inside the map scope.
sdfg.add_scalar("tmp_out", transient=True, dtype=dace.float64)
tmp_out = state.add_access("tmp_out")
for e in state.edges():
if e.src_conn == "__out":
oedge = e
assert oedge.data.data == "C"
break

state.add_edge(
oedge.src,
oedge.src_conn,
tmp_out,
None,
dace.Memlet(data="tmp_out", subset="0"),
)
state.add_edge(
tmp_out,
None,
oedge.dst,
oedge.dst_conn,
dace.Memlet(data="C", subset="__i0, __i1"),
)

# Now we create a new output that uses `tmp_out` but goes into `D`.
# The memlet on the inside will not use `D` but `tmp_out`.
state.add_edge(
tmp_out,
None,
oedge.dst,
"IN_D",
dace.Memlet(data=tmp_out.data, subset="0", other_subset="__i1, __i0"),
)
state.add_edge(
oedge.dst,
"OUT_D",
state.add_access("D"),
None,
dace.Memlet(data="D", subset="__i0, __i1", other_subset="0"),
)
oedge.dst.add_in_connector("IN_D", force=True)
oedge.dst.add_out_connector("OUT_D", force=True)
state.remove_edge(oedge)

# Without this the test does not work properly
# It is related to [Issue#1703](https://github.com/spcl/dace/issues/1703)
sdfg.validate()
for edge in state.edges():
edge.data.try_initialize(edge=edge, sdfg=sdfg, state=state)

for array in sdfg.arrays.values():
if isinstance(array, dace.data.Array):
array.storage = dace.StorageType.GPU_Global
else:
array.storage = dace.StorageType.Register
sdfg.apply_gpu_transformations(simplify=False)
sdfg.validate()

return sdfg

# Build the SDFG
sdfg = make_sdfg()

map_entry = None
for state in sdfg.states():
for node in state.nodes():
if isinstance(node, dace.nodes.MapEntry):
map_entry = node
break
if map_entry is not None:
break

# Now get the argument list of the map.
res_arglist = { k:v for k, v in state.scope_subgraph(map_entry).arglist().items()}

ref_arglist = {
'A': dace.data.Array,
'B': dace.data.Array,
'C': dace.data.Array,
'D': dace.data.Array,
'N': dace.data.Scalar,
'second_stride_A': dace.data.Scalar,
'second_stride_D': dace.data.Scalar,
}

assert len(ref_arglist) == len(res_arglist), f"Expected {len(ref_arglist)} but got {len(res_arglist)}"
for aname in ref_arglist.keys():
atype_ref = ref_arglist[aname]
atype_res = res_arglist[aname]
assert isinstance(atype_res, atype_ref), f"Expected '{aname}' to have type {atype_ref}, but it had {type(atype_res)}."

# If we have cupy we will also compile it.
try:
import cupy as cp
except ImportError:
return

csdfg = sdfg.compile()

if __name__ == "__main__":
test_argument_signature_test()

0 comments on commit 813a2f4

Please sign in to comment.