
[frontend] Add subgraph division and device type (#446)
WuXintong123 authored Jan 7, 2025
1 parent 91bbd57 commit 5093adf
Showing 7 changed files with 173 additions and 56 deletions.
2 changes: 2 additions & 0 deletions frontend/Python/frontend.py
@@ -45,6 +45,7 @@
 from .graph import Graph, TensorDType, TensorMeta
 from .graph.operation import *
 from .graph.transform import maxpool2d_simplify
+from .graph.type import *
 
 
 class DynamoCompiler:
@@ -310,6 +311,7 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]):
                 fake_params,
                 self._ops_registry,
                 self._func_name,
+                DeviceType.CPU,
                 self._verbose,
             )
             param_nodes = []
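Note: the new `from .graph.type import *` import supplies the DeviceType enum used throughout this commit. Its definition lives in frontend/Python/graph/type.py, which this diff does not show; judging only from the members referenced here (CPU and UNKNOW), it is presumably along these lines:

    # Hedged sketch of DeviceType; the real definition is in
    # frontend/Python/graph/type.py and may differ or carry more members.
    from enum import Enum

    class DeviceType(Enum):
        CPU = "cpu"        # the new default device in this commit
        GPU = "gpu"        # assumed member; not referenced in this diff
        UNKNOW = "unknow"  # spelling as it appears in the replaced code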
14 changes: 8 additions & 6 deletions frontend/Python/graph/graph.py
@@ -105,6 +105,7 @@ def __init__(
         fake_params: List[TensorMeta],
         ops_registry: dict,
         func_name: str,
+        device: DeviceType = DeviceType.CPU,
         verbose=False,
     ) -> None:
         """
@@ -124,7 +125,7 @@
         self._inputs = inputs
         self.node_table: Dict[str, Op] = {}
         self._fake_params = fake_params
-        self.device = "cpu"
+        self.device = device
         self._imported_module = None
         self._verbose = verbose
         self._ops_registry = ops_registry
@@ -244,11 +245,11 @@ def init_op_group(self):
         - None
         """
         for i, op in enumerate(self._body):
-            if isinstance(op, PlaceholderOp):
+            if isinstance(op, PlaceholderOp) or isinstance(op, OutputOp):
                 continue
             group = [op]
             subgraph_name = "subgraph{}".format(i)
-            self.group_map_device[subgraph_name] = DeviceType.UNKNOW
+            self.group_map_device[subgraph_name] = DeviceType.CPU
             self.op_groups[subgraph_name] = group

def fuse_ops(self, pattern_list: List[FunctionType]):
Expand All @@ -266,9 +267,6 @@ def fuse_ops(self, pattern_list: List[FunctionType]):
         # 1. fuse ops adapt for DSA(hardware dependent)
         # 2. common fuse strategy(hardware independent)
 
-        # Initialize operation groups
-        self.init_op_group()
-
         # Apply fusion patterns
         for pattern_func in pattern_list:
             pattern_func(self)
@@ -311,6 +309,8 @@ def lower_to_top_level_ir(self):
             self._inputs,
             self._func_name,
             self._ops_registry,
+            False,
+            self.device,
             verbose=self._verbose,
         )
         self._imported_module = fx_importer.import_graph()
@@ -424,6 +424,7 @@ def __init__(
         func_name: str,
         ops_registry: dict,
         do_param_pack: bool = False,
+        device: DeviceType = DeviceType.CPU,
         verbose=False,
     ):
         """
@@ -439,6 +440,7 @@
             ops_registry = {}
         self._symbol_table = {}
         self._body = body
+        self._device = device
         self._func_name = func_name
         self._params = params
         self._inputs = inputs
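Note: fuse_ops no longer calls init_op_group itself (the call is removed in the hunk at old line 266 above), so callers are now responsible for initializing the op groups before applying fusion patterns. A minimal calling sequence, assuming only the names visible in this diff:

    # Hedged sketch of the expected calling order after this change;
    # `graph` is a Graph instance produced by the DynamoCompiler.
    graph.init_op_group()          # one subgraph per op, each tagged DeviceType.CPU
    graph.fuse_ops([simply_fuse])  # fusion patterns may regroup ops and devices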
146 changes: 101 additions & 45 deletions frontend/Python/graph/graph_driver.py
@@ -21,6 +21,7 @@
 # ===---------------------------------------------------------------------------
 
 from mlir import ir
+from collections import deque, defaultdict
 
 from .graph import Graph, GraphImporter, TensorMeta
 from .operation import FuncOp, CallOp, PlaceholderOp, OutputOp, GetItemOp
@@ -40,6 +41,7 @@ class GraphDriver:
     - _subgraphs_outputs (dict): A dictionary mapping subgraph names to their
       output op's result.
     """
+
     def __init__(self, graph: Graph) -> None:
         """
         Initialize the GraphDriver object with a given computational graph.
@@ -52,6 +54,11 @@ def __init__(self, graph: Graph) -> None:
         - None
         """
         self._graph = graph
+        self._subgraph_dependencies = {
+            subgraph_name: set()
+            for subgraph_name in list(self._graph.op_groups.keys())
+        }
+        self._call_table = {}
         (
             self._subgraphs,
             self._subgraphs_inputs,
@@ -94,14 +101,15 @@ def build_subgraph_by_group(self):
             if isinstance(node, OutputOp):
                 for arg in node.args:
                     output_node.append(arg)
-        # Identify outputs for each subgraph
+        # Identify outputs for each subgraph and build dependencies between subgraphs
         for subgraph_name in self._graph.op_groups.keys():
             subgraphs_outputs[subgraph_name] = []
             for op in self._graph.op_groups[subgraph_name]:
                 for key in subgraphs_inputs.keys():
                     if op.name in subgraphs_inputs[key]:
                         subgraphs_outputs[subgraph_name].append(op.name)
+                        self._subgraph_dependencies[subgraph_name].add(key)
                 if (op.name in output_node) and (
                     op.name not in subgraphs_outputs[subgraph_name]
                 ):
@@ -112,6 +120,7 @@
         for subgraph_name in self._graph.op_groups.keys():
             subgraph_input = []
             subgraph_body = []
+            subgraph_device = self._graph.group_map_device[subgraph_name]
 
             # Construct input placeholder nodes
             for inp in subgraphs_inputs[subgraph_name]:
@@ -127,11 +136,11 @@
                     if inp in node._parents:
                         placeholder_node.add_children(op.name)
             subgraph_body.append(placeholder_node)
 
             # Add operations to subgraph body
             for op in self._graph.op_groups[subgraph_name]:
                 subgraph_body.append(op)
 
             # Construct output node
             output_node = OutputOp()
             output_node.name = "output"
@@ -142,7 +151,12 @@
 
             # Create subgraph and add it to the dictionary
             subgraph = Graph(
-                subgraph_input, [], self._graph._ops_registry, subgraph_name, verbose=self._graph._verbose
+                subgraph_input,
+                [],
+                self._graph._ops_registry,
+                subgraph_name,
+                subgraph_device,
+                verbose=self._graph._verbose,
             )
             subgraph.body = subgraph_body
             for op in subgraph_body:
@@ -151,6 +165,38 @@
 
         return subgraphs, subgraphs_inputs, subgraphs_outputs
 
+    def topological_sort_subgraph(self):
+        """
+        Performs topological sorting on the subgraphs based on their
+        dependencies.
+
+        Returns:
+        - list: A list of subgraph names in topological order if the
+          dependency graph is acyclic; otherwise, None.
+        """
+        # Calculate the in-degree of each subgraph
+        in_degree = {
+            subgraph_name: 0 for subgraph_name in list(self._subgraphs.keys())
+        }
+        for src, dests in self._subgraph_dependencies.items():
+            for dest in dests:
+                in_degree[dest] += 1
+        # Topological sorting (Kahn's algorithm)
+        queue = deque([node for node in in_degree if in_degree[node] == 0])
+        topo_order = []
+        while queue:
+            node = queue.popleft()
+            topo_order.append(node)
+            for child in self._subgraph_dependencies[node]:
+                in_degree[child] -= 1
+                if in_degree[child] == 0:
+                    queue.append(child)
+        # TODO: If the custom subgraph partitioning is illegal, further
+        # partition the subgraph to make it valid.
+        return (
+            topo_order
+            if len(topo_order) == len(list(self._subgraphs.keys()))
+            else None
+        )
+
     def construct_main_graph(self, do_param_pack=False):
         """
         Constructs the main computational graph by incorporating subgraphs' call
@@ -172,7 +218,7 @@
             self._graph._fake_params,
             self._graph._ops_registry,
             self._graph._func_name,
-            self._graph._verbose
+            verbose=self._graph._verbose,
         )
 
         # Adding FuncOp nodes for each subgraph
@@ -189,53 +235,63 @@
             func_node.tensor_meta["dtype"].append(
                 self._graph.node_table[output].tensor_meta["dtype"]
             )
-            main_graph.body.append(func_node)
+            main_graph.add_node(func_node)
 
         # Adding placeholder operations from the original graph
         for op in self._graph.body:
             if isinstance(op, PlaceholderOp):
-                main_graph.body.append(op)
-
-        # TODO: analysis topology order to sort subgraph call.
-        if len(self._subgraphs) == 1:
-            # Adding CallOp to invoke the single subgraph
+                main_graph.add_node(op)
+        # Analyze the topological order to sort the subgraph calls.
+        topo_order = self.topological_sort_subgraph()
+        if topo_order is None:
+            print("Error: graph partitioning is illegal!")
+            return None
+        # Adding CallOp nodes to invoke each subgraph in topological order
+        for i, subgraph_name in enumerate(topo_order):
             call_node = CallOp()
-            call_node.name = "call0"
-            call_node.call_func_name = list(self._subgraphs.keys())[0]
+            call_node.name = "call{}".format(i)
+            call_node.call_func_name = subgraph_name
             call_node.tensor_meta = {"shape": [], "dtype": []}
-            for inp in list(self._subgraphs_inputs.values())[0]:
-                call_node.add_argument(inp)
-            for output in list(self._subgraphs_outputs.values())[0]:
+            for inp in self._subgraphs_inputs[subgraph_name]:
+                if inp in main_graph.node_table:
+                    call_node.add_argument(inp)
+                    continue
+                for key, value in self._subgraphs_outputs.items():
+                    if inp in value:
+                        call_node.add_argument(
+                            arg=self._call_table[key].name,
+                            arg_index=value.index(inp),
+                        )
+                        break
+            for output in self._subgraphs_outputs[subgraph_name]:
                 call_node.tensor_meta["shape"].append(
                     self._graph.node_table[output].tensor_meta["shape"]
                 )
                 call_node.tensor_meta["dtype"].append(
                     self._graph.node_table[output].tensor_meta["dtype"]
                 )
-            main_graph.body.append(call_node)
-
-            # Adding GetItemOps to retrieve individual output tensors
-            output_node = OutputOp()
-            for i, output in enumerate(list(self._subgraphs_outputs.values())[0]):
-                getitem_node = GetItemOp()
-                getitem_node.add_argument(call_node.name)
-                getitem_node.add_argument(i)
-                getitem_node.name = "getitem{}".format(i)
-                output_node.add_argument(getitem_node.name)
-                main_graph.body.append(getitem_node)
-
-            # Marking the final output of the main graph
-            output_node.name = "output"
-            main_graph.body.append(output_node)
-
-            # Importing the main graph
-            with ir.Location.unknown(ir.Context()):
-                main_importer = GraphImporter(
-                    main_graph.body,
-                    main_graph._fake_params,
-                    main_graph._inputs,
-                    main_graph._func_name,
-                    main_graph._ops_registry,
-                    do_param_pack,
-                )
-                return main_importer.import_main_graph()
+            self._call_table[subgraph_name] = call_node
+            main_graph.add_node(call_node)
+        # Adding GetItemOps to retrieve individual output tensors
+        output_node = OutputOp()
+        for i, output in enumerate(self._subgraphs_outputs[topo_order[-1]]):
+            getitem_node = GetItemOp()
+            getitem_node.add_argument(call_node.name)
+            getitem_node.add_argument(i)
+            getitem_node.name = "getitem{}".format(i)
+            output_node.add_argument(getitem_node.name)
+            main_graph.add_node(getitem_node)
+        # Marking the final output of the main graph
+        output_node.name = "output"
+        main_graph.add_node(output_node)
+        # Importing the main graph
+        with ir.Location.unknown(ir.Context()):
+            main_importer = GraphImporter(
+                main_graph.body,
+                main_graph._fake_params,
+                main_graph._inputs,
+                main_graph._func_name,
+                main_graph._ops_registry,
+                do_param_pack,
+            )
+            return main_importer.import_main_graph()
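Note: the added topological_sort_subgraph is plain Kahn's algorithm over _subgraph_dependencies (which maps each producer subgraph to the subgraphs consuming its outputs), returning None when a cycle makes the partitioning illegal. A self-contained sketch of the same logic on a toy dependency map with hypothetical subgraph names:

    from collections import deque

    # Each producer maps to the set of subgraphs that consume its outputs.
    deps = {"subgraph0": {"subgraph1"}, "subgraph1": {"subgraph2"}, "subgraph2": set()}
    in_degree = {name: 0 for name in deps}
    for src, dests in deps.items():
        for dest in dests:
            in_degree[dest] += 1
    queue = deque(n for n in in_degree if in_degree[n] == 0)
    order = []
    while queue:
        node = queue.popleft()
        order.append(node)
        for child in deps[node]:
            in_degree[child] -= 1
            if in_degree[child] == 0:
                queue.append(child)
    assert order == ["subgraph0", "subgraph1", "subgraph2"]
    # A cycle would leave len(order) < len(deps), which the method reports as None.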
4 changes: 3 additions & 1 deletion frontend/Python/graph/operation.py
@@ -86,8 +86,9 @@ def __init__(self) -> None:
         self._op_type: OpType = None
         self._children: List[str] = []
         self._parents: List[str] = []
+        self._args_index = []
 
-    def add_argument(self, arg):
+    def add_argument(self, arg, arg_index=0):
         """
         Add an input argument to the operation node.
@@ -96,6 +97,7 @@ def add_argument(self, arg):
             The input argument to be added.
         """
         self._arguments.append(arg)
+        self._args_index.append(arg_index)
 
     def add_parent(self, parent: str):
         """
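Note: the new arg_index records which result of a multi-output producer an argument refers to, matching the (name, index) keys of the importer's symbol table (see func.py below). A small illustration with hypothetical node names:

    # The second result (index 1) of a node named "call0" feeds this op,
    # so the argument is recorded as the pair ("call0", 1).
    op = CallOp()
    op.add_argument("input0")                  # plain input; index defaults to 0
    op.add_argument(arg="call0", arg_index=1)  # second result of another call
    assert op._args_index == [0, 1]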
4 changes: 2 additions & 2 deletions frontend/Python/graph/transform/fuse_ops.py
@@ -103,7 +103,7 @@ def apply_classic_fusion(graph: Graph):
     - None: Modifies the input graph in place.
     """
     new_op_group = []
-    device = DeviceType.UNKNOW
+    device = DeviceType.CPU
     # Run the first round of op fusion
     classic_fuse_check(graph)
     for op in graph.body:
@@ -126,7 +126,7 @@ def simply_fuse(graph: Graph):
     - None: Modifies the input graph in place.
     """
     new_op_group = []
-    device = DeviceType.UNKNOW
+    device = DeviceType.CPU
     for op in graph.body:
         if isinstance(op, PlaceholderOp):
             continue
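Note: with DeviceType.UNKNOW dropped as the initial tag, every op group now starts on the CPU and only moves when a hardware-dependent pattern claims it. A hedged sketch of what such a pattern could look like, in the style of simply_fuse above (the pattern itself and the GPU member are assumptions, not part of this commit):

    def fuse_for_dsa(graph: Graph):
        # Collapse all compute ops into one group and tag it for the DSA.
        group = [
            op
            for op in graph.body
            if not isinstance(op, (PlaceholderOp, OutputOp))
        ]
        graph.op_groups = {"subgraph0": group}
        graph.group_map_device["subgraph0"] = DeviceType.GPU  # assumed member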
4 changes: 2 additions & 2 deletions frontend/Python/ops/func.py
@@ -59,8 +59,8 @@ def call_op(node: CallOp, symbol_table: Dict[Tuple[str, int], ir.Operation]):
     From Buddy CallOp to MLIR FUNC call operation.
     """
     arguments = []
-    for arg in node.args:
-        input_node = symbol_table.get((str(arg), 0))
+    for i, arg in enumerate(node.args):
+        input_node = symbol_table.get((str(arg), node._args_index[i]))
         memref_type = ir.MemRefType(input_node.type)
         stride = []
         shape = memref_type.shape
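Note: call_op now resolves each argument against the symbol table by (name, index) instead of always (name, 0), which is what lets a CallOp consume the i-th result of an upstream subgraph call. Schematically, with strings standing in for the MLIR values:

    # symbol_table maps (value name, result index) to a value, so the
    # second result of "call0" is addressed as ("call0", 1).
    symbol_table = {("call0", 0): "%0", ("call0", 1): "%1"}
    arg, index = "call0", 1
    assert symbol_table.get((str(arg), index)) == "%1"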