From 5093adf95d5e8d3a0597975a7e735af01806c6d8 Mon Sep 17 00:00:00 2001 From: Wu Xintong <13683168028@163.com> Date: Tue, 7 Jan 2025 20:20:38 +0800 Subject: [PATCH] [frontend] Add subgraph division and device type (#446) --- frontend/Python/frontend.py | 2 + frontend/Python/graph/graph.py | 14 +- frontend/Python/graph/graph_driver.py | 146 ++++++++++++++------ frontend/Python/graph/operation.py | 4 +- frontend/Python/graph/transform/fuse_ops.py | 4 +- frontend/Python/ops/func.py | 4 +- tests/Python/test_subgraph_division.py | 55 ++++++++ 7 files changed, 173 insertions(+), 56 deletions(-) create mode 100644 tests/Python/test_subgraph_division.py diff --git a/frontend/Python/frontend.py b/frontend/Python/frontend.py index c11843eab7..bc9fadc12c 100644 --- a/frontend/Python/frontend.py +++ b/frontend/Python/frontend.py @@ -45,6 +45,7 @@ from .graph import Graph, TensorDType, TensorMeta from .graph.operation import * from .graph.transform import maxpool2d_simplify +from .graph.type import * class DynamoCompiler: @@ -310,6 +311,7 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]): fake_params, self._ops_registry, self._func_name, + DeviceType.CPU, self._verbose, ) param_nodes = [] diff --git a/frontend/Python/graph/graph.py b/frontend/Python/graph/graph.py index 751ddb0066..6491c36e51 100644 --- a/frontend/Python/graph/graph.py +++ b/frontend/Python/graph/graph.py @@ -105,6 +105,7 @@ def __init__( fake_params: List[TensorMeta], ops_registry: dict, func_name: str, + device: DeviceType = DeviceType.CPU, verbose=False, ) -> None: """ @@ -124,7 +125,7 @@ def __init__( self._inputs = inputs self.node_table: Dict[str, Op] = {} self._fake_params = fake_params - self.device = "cpu" + self.device = device self._imported_module = None self._verbose = verbose self._ops_registry = ops_registry @@ -244,11 +245,11 @@ def init_op_group(self): - None """ for i, op in enumerate(self._body): - if isinstance(op, PlaceholderOp): + if isinstance(op, PlaceholderOp) or isinstance(op, OutputOp): continue group = [op] subgraph_name = "subgraph{}".format(i) - self.group_map_device[subgraph_name] = DeviceType.UNKNOW + self.group_map_device[subgraph_name] = DeviceType.CPU self.op_groups[subgraph_name] = group def fuse_ops(self, pattern_list: List[FunctionType]): @@ -266,9 +267,6 @@ def fuse_ops(self, pattern_list: List[FunctionType]): # 1. fuse ops adapt for DSA(hardware dependent) # 2. common fuse strategy(hardware independent) - # Initialize operation groups - self.init_op_group() - # Apply fusion patterns for pattern_func in pattern_list: pattern_func(self) @@ -311,6 +309,8 @@ def lower_to_top_level_ir(self): self._inputs, self._func_name, self._ops_registry, + False, + self.device, verbose=self._verbose, ) self._imported_module = fx_importer.import_graph() @@ -424,6 +424,7 @@ def __init__( func_name: str, ops_registry: dict, do_param_pack: bool = False, + device: DeviceType = DeviceType.CPU, verbose=False, ): """ @@ -439,6 +440,7 @@ def __init__( ops_registry = {} self._symbol_table = {} self._body = body + self._device = device self._func_name = func_name self._params = params self._inputs = inputs diff --git a/frontend/Python/graph/graph_driver.py b/frontend/Python/graph/graph_driver.py index dd37aa12bd..545de8f4a3 100644 --- a/frontend/Python/graph/graph_driver.py +++ b/frontend/Python/graph/graph_driver.py @@ -21,6 +21,7 @@ # ===--------------------------------------------------------------------------- from mlir import ir +from collections import deque, defaultdict from .graph import Graph, GraphImporter, TensorMeta from .operation import FuncOp, CallOp, PlaceholderOp, OutputOp, GetItemOp @@ -40,6 +41,7 @@ class GraphDriver: - _subgraphs_outputs (dict): A dictionary mapping subgraph names to their output op's result. """ + def __init__(self, graph: Graph) -> None: """ Initialize the GraphDriver object with a given computational graph. @@ -52,6 +54,11 @@ def __init__(self, graph: Graph) -> None: - None """ self._graph = graph + self._subgraph_dependencies = { + subgraph_name: set() + for subgraph_name in list(self._graph.op_groups.keys()) + } + self._call_table = {} ( self._subgraphs, self._subgraphs_inputs, @@ -94,14 +101,15 @@ def build_subgraph_by_group(self): if isinstance(node, OutputOp): for arg in node.args: output_node.append(arg) - - # Identify outputs for each subgraph + + # Identify outputs for each subgraph and build dependencies between subgraphs for subgraph_name in self._graph.op_groups.keys(): subgraphs_outputs[subgraph_name] = [] for op in self._graph.op_groups[subgraph_name]: for key in subgraphs_inputs.keys(): if op.name in subgraphs_inputs[key]: subgraphs_outputs[subgraph_name].append(op.name) + self._subgraph_dependencies[subgraph_name].add(key) if (op.name in output_node) and ( op.name not in subgraphs_outputs[subgraph_name] ): @@ -112,6 +120,7 @@ def build_subgraph_by_group(self): for subgraph_name in self._graph.op_groups.keys(): subgraph_input = [] subgraph_body = [] + subgraph_device = self._graph.group_map_device[subgraph_name] # Construct input placeholder nodes for inp in subgraphs_inputs[subgraph_name]: @@ -127,11 +136,11 @@ def build_subgraph_by_group(self): if inp in node._parents: placeholder_node.add_children(op.name) subgraph_body.append(placeholder_node) - + # Add operations to subgraph body for op in self._graph.op_groups[subgraph_name]: subgraph_body.append(op) - + # Construct output node output_node = OutputOp() output_node.name = "output" @@ -142,7 +151,12 @@ def build_subgraph_by_group(self): # Create subgraph and add it to the dictionary subgraph = Graph( - subgraph_input, [], self._graph._ops_registry, subgraph_name, verbose=self._graph._verbose + subgraph_input, + [], + self._graph._ops_registry, + subgraph_name, + subgraph_device, + verbose=self._graph._verbose, ) subgraph.body = subgraph_body for op in subgraph_body: @@ -151,6 +165,38 @@ def build_subgraph_by_group(self): return subgraphs, subgraphs_inputs, subgraphs_outputs + def topological_sort_subgraph(self): + """ + Performs topological sorting on the subgraphs based on their dependencies. + Args: + - graph (Graph): The graph from which subgraphs are constructed. + Returns: + - list: A list of subgraph names in topological order if the graph is acyclic; otherwise, None. + """ + # Calculate in degree of each subgraph + in_degree = { + subgraph_name: 0 for subgraph_name in list(self._subgraphs.keys()) + } + for src, dests in self._subgraph_dependencies.items(): + for dest in dests: + in_degree[dest] += 1 + # Topological sorting + queue = deque([node for node in in_degree if in_degree[node] == 0]) + topo_order = [] + while queue: + node = queue.popleft() + topo_order.append(node) + for child in self._subgraph_dependencies[node]: + in_degree[child] -= 1 + if in_degree[child] == 0: + queue.append(child) + # TODO: If the custom subgraph partitioning is illegal, further partition the subgraph to make it valid. + return ( + topo_order + if len(topo_order) == len(list(self._subgraphs.keys())) + else None + ) + def construct_main_graph(self, do_param_pack=False): """ Constructs the main computational graph by incorporating subgraphs' call @@ -172,7 +218,7 @@ def construct_main_graph(self, do_param_pack=False): self._graph._fake_params, self._graph._ops_registry, self._graph._func_name, - self._graph._verbose + self._graph._verbose, ) # Adding FuncOp nodes for each subgraph @@ -189,53 +235,63 @@ def construct_main_graph(self, do_param_pack=False): func_node.tensor_meta["dtype"].append( self._graph.node_table[output].tensor_meta["dtype"] ) - main_graph.body.append(func_node) - + main_graph.add_node(func_node) + # Adding placeholder operations from the original graph for op in self._graph.body: if isinstance(op, PlaceholderOp): - main_graph.body.append(op) - - # TODO: analysis topology order to sort subgraph call. - if len(self._subgraphs) == 1: - # Adding CallOp to invoke the single subgraph + main_graph.add_node(op) + # Analysis topology order to sort subgraph call. + topo_order = self.topological_sort_subgraph() + if topo_order == None: + print("Error : Graph Partitioning is illegal!") + return None + # Adding CallOp to invoke the single subgraph + for i, subgraph_name in enumerate(topo_order): call_node = CallOp() - call_node.name = "call0" - call_node.call_func_name = list(self._subgraphs.keys())[0] + call_node.name = "call{}".format(i) + call_node.call_func_name = subgraph_name call_node.tensor_meta = {"shape": [], "dtype": []} - for inp in list(self._subgraphs_inputs.values())[0]: - call_node.add_argument(inp) - for output in list(self._subgraphs_outputs.values())[0]: + for inp in self._subgraphs_inputs[subgraph_name]: + if inp in main_graph.node_table: + call_node.add_argument(inp) + continue + for key, value in self._subgraphs_outputs.items(): + if inp in value: + call_node.add_argument( + arg=self._call_table[key].name, + arg_index=value.index(inp), + ) + break + for output in self._subgraphs_outputs[subgraph_name]: call_node.tensor_meta["shape"].append( self._graph.node_table[output].tensor_meta["shape"] ) call_node.tensor_meta["dtype"].append( self._graph.node_table[output].tensor_meta["dtype"] ) - main_graph.body.append(call_node) - - # Adding GetItemOps to retrieve individual output tensors - output_node = OutputOp() - for i, output in enumerate(list(self._subgraphs_outputs.values())[0]): - getitem_node = GetItemOp() - getitem_node.add_argument(call_node.name) - getitem_node.add_argument(i) - getitem_node.name = "getitem{}".format(i) - output_node.add_argument(getitem_node.name) - main_graph.body.append(getitem_node) - - # Marking the final output of the main graph - output_node.name = "output" - main_graph.body.append(output_node) - - # Importing the main graph - with ir.Location.unknown(ir.Context()): - main_importer = GraphImporter( - main_graph.body, - main_graph._fake_params, - main_graph._inputs, - main_graph._func_name, - main_graph._ops_registry, - do_param_pack, - ) - return main_importer.import_main_graph() + self._call_table[subgraph_name] = call_node + main_graph.add_node(call_node) + # Adding GetItemOps to retrieve individual output tensors + output_node = OutputOp() + for i, output in enumerate(self._subgraphs_outputs[topo_order[-1]]): + getitem_node = GetItemOp() + getitem_node.add_argument(call_node.name) + getitem_node.add_argument(i) + getitem_node.name = "getitem{}".format(i) + output_node.add_argument(getitem_node.name) + main_graph.add_node(getitem_node) + # Marking the final output of the main graph + output_node.name = "output" + main_graph.add_node(output_node) + # Importing the main graph + with ir.Location.unknown(ir.Context()): + main_importer = GraphImporter( + main_graph.body, + main_graph._fake_params, + main_graph._inputs, + main_graph._func_name, + main_graph._ops_registry, + do_param_pack, + ) + return main_importer.import_main_graph() diff --git a/frontend/Python/graph/operation.py b/frontend/Python/graph/operation.py index 218752abc0..1bb1ea0ade 100644 --- a/frontend/Python/graph/operation.py +++ b/frontend/Python/graph/operation.py @@ -86,8 +86,9 @@ def __init__(self) -> None: self._op_type: OpType = None self._children: List[str] = [] self._parents: List[str] = [] + self._args_index = [] - def add_argument(self, arg): + def add_argument(self, arg, arg_index=0): """ Add an input argument to the operation node. @@ -96,6 +97,7 @@ def add_argument(self, arg): The input argument to be added. """ self._arguments.append(arg) + self._args_index.append(arg_index) def add_parent(self, parent: str): """ diff --git a/frontend/Python/graph/transform/fuse_ops.py b/frontend/Python/graph/transform/fuse_ops.py index 992168aecc..2d22a1f543 100644 --- a/frontend/Python/graph/transform/fuse_ops.py +++ b/frontend/Python/graph/transform/fuse_ops.py @@ -103,7 +103,7 @@ def apply_classic_fusion(graph: Graph): - None: Modifies the input graph in place. """ new_op_group = [] - device = DeviceType.UNKNOW + device = DeviceType.CPU # Run the first round of op fusion classic_fuse_check(graph) for op in graph.body: @@ -126,7 +126,7 @@ def simply_fuse(graph: Graph): - None: Modifies the input graph in place. """ new_op_group = [] - device = DeviceType.UNKNOW + device = DeviceType.CPU for op in graph.body: if isinstance(op, PlaceholderOp): continue diff --git a/frontend/Python/ops/func.py b/frontend/Python/ops/func.py index a7dcc5e11b..e885809d82 100644 --- a/frontend/Python/ops/func.py +++ b/frontend/Python/ops/func.py @@ -59,8 +59,8 @@ def call_op(node: CallOp, symbol_table: Dict[Tuple[str, int], ir.Operation]): From Buddy CallOp to MLIR FUNC call operation. """ arguments = [] - for arg in node.args: - input_node = symbol_table.get((str(arg), 0)) + for i, arg in enumerate(node.args): + input_node = symbol_table.get((str(arg), node._args_index[i])) memref_type = ir.MemRefType(input_node.type) stride = [] shape = memref_type.shape diff --git a/tests/Python/test_subgraph_division.py b/tests/Python/test_subgraph_division.py new file mode 100644 index 0000000000..e3e282fcf5 --- /dev/null +++ b/tests/Python/test_subgraph_division.py @@ -0,0 +1,55 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s +import torch +import os + +from buddy.compiler.graph import GraphDriver +from buddy.compiler.frontend import DynamoCompiler +from buddy.compiler.graph.type import DeviceType +from buddy.compiler.graph.operation import * + + +# Define the target function or model. +def foo(x, y): + return x * y + x + + +# Define the input data. +float32_in1 = torch.randn(10).to(torch.float32) +float32_in2 = torch.randn(10).to(torch.float32) + +dynamo_compiler = DynamoCompiler() +graphs = dynamo_compiler.importer(foo, *(float32_in1, float32_in2)) +graph = graphs[0] +graphs[0].lower_to_top_level_ir() +params = dynamo_compiler.imported_params[graph] + +#Divide the subgraphs +group = [graph._body[2]] +subgraph_name = "subgraph0" +graph.group_map_device[subgraph_name] = DeviceType.CPU +graph.op_groups[subgraph_name] = group + +new_group = [graph._body[3]] +subgraph_name = "subgraph1" +graph.group_map_device[subgraph_name] = DeviceType.CPU +graph.op_groups[subgraph_name] = new_group + +path_prefix = os.path.dirname(os.path.abspath(__file__)) +driver = GraphDriver(graph) +driver.subgraphs[0].lower_to_top_level_ir() +driver.subgraphs[1].lower_to_top_level_ir() +print(driver.construct_main_graph(True)) +# CHECK: module { +# CHECK-NEXT: func.func private @subgraph0(memref<10xf32, strided<[1], offset: ?>>, memref<10xf32, strided<[1], offset: ?>>) -> memref<10xf32> +# CHECK-NEXT: func.func private @subgraph1(memref<10xf32, strided<[1], offset: ?>>, memref<10xf32, strided<[1], offset: ?>>) -> memref<10xf32> +# CHECK-NEXT: func.func @forward(%arg0: memref<10xf32>, %arg1: memref<10xf32>) -> memref<10xf32> { +# CHECK-NEXT: %cast = memref.cast %arg0 : memref<10xf32> to memref<10xf32, strided<[1], offset: ?>> +# CHECK-NEXT: %cast_0 = memref.cast %arg1 : memref<10xf32> to memref<10xf32, strided<[1], offset: ?>> +# CHECK-NEXT: %0 = call @subgraph0(%cast, %cast_0) : (memref<10xf32, strided<[1], offset: ?>>, memref<10xf32, strided<[1], offset: ?>>) -> memref<10xf32> +# CHECK-NEXT: %cast_1 = memref.cast %0 : memref<10xf32> to memref<10xf32, strided<[1], offset: ?>> +# CHECK-NEXT: %cast_2 = memref.cast %arg0 : memref<10xf32> to memref<10xf32, strided<[1], offset: ?>> +# CHECK-NEXT: %1 = call @subgraph1(%cast_1, %cast_2) : (memref<10xf32, strided<[1], offset: ?>>, memref<10xf32, strided<[1], offset: ?>>) -> memref<10xf32> +# CHECK-NEXT: return %1 : memref<10xf32> +# CHECK-NEXT: } +# CHECK-NEXT: } +