diff --git a/mars/core/entity/core.py b/mars/core/entity/core.py
index 6a27ac65d2..d3234a59d5 100644
--- a/mars/core/entity/core.py
+++ b/mars/core/entity/core.py
@@ -42,6 +42,10 @@ def __init__(self, *args, **kwargs):
     def op(self):
         return self._op
 
+    @property
+    def outputs(self):
+        return self._op.outputs
+
     @property
     def inputs(self):
         return self.op.inputs
diff --git a/mars/optimization/logical/core.py b/mars/optimization/logical/core.py
index ba49f825f0..7809ade083 100644
--- a/mars/optimization/logical/core.py
+++ b/mars/optimization/logical/core.py
@@ -12,14 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import functools
-import weakref
+import itertools
 from abc import ABC, abstractmethod
 from collections import defaultdict
 from dataclasses import dataclass
 from enum import Enum
 from typing import Dict, List, Optional, Type, Set
 
-from ...core import OperandType, EntityType, enter_mode
+from ...core import OperandType, EntityType, enter_mode, Entity
 from ...core.graph import EntityGraph
 from ...utils import implements
 
@@ -91,8 +91,6 @@ def get_original_entity(
 
 
 class OptimizationRule(ABC):
-    _preds_to_remove = weakref.WeakKeyDictionary()
-
     def __init__(
         self,
         graph: EntityGraph,
@@ -130,34 +128,91 @@ def _replace_node(self, original_node: EntityType, new_node: EntityType):
         for succ in successors:
             self._graph.add_edge(new_node, succ)
 
-    def _add_collapsable_predecessor(self, node: EntityType, predecessor: EntityType):
-        pred_original = self._records.get_original_entity(predecessor, predecessor)
-        if predecessor not in self._preds_to_remove:
-            self._preds_to_remove[pred_original] = {node}
-        else:
-            self._preds_to_remove[pred_original].add(node)
-
-    def _remove_collapsable_predecessors(self, node: EntityType):
-        node = self._records.get_optimization_result(node) or node
-        preds_opt_to_remove = []
-        for pred in self._graph.predecessors(node):
-            pred_original = self._records.get_original_entity(pred, pred)
-            pred_opt = self._records.get_optimization_result(pred, pred)
-
-            if pred_opt in self._graph.results or pred_original in self._graph.results:
-                continue
-            affect_succ = self._preds_to_remove.get(pred_original) or []
-            affect_succ_opt = [
-                self._records.get_optimization_result(s, s) for s in affect_succ
-            ]
-            if all(s in affect_succ_opt for s in self._graph.successors(pred)):
-                preds_opt_to_remove.append((pred_original, pred_opt))
-
-        for pred_original, pred_opt in preds_opt_to_remove:
-            self._graph.remove_node(pred_opt)
-            self._records.append_record(
-                OptimizationRecord(pred_original, None, OptimizationRecordType.delete)
-            )
+    def _replace_subgraph(
+        self,
+        graph: Optional[EntityGraph],
+        nodes_to_remove: Optional[Set[EntityType]],
+        new_results: Optional[List[Entity]] = None,
+    ):
+        """
+        Replace a subgraph of ``self._graph``, represented by a set of nodes, with the input graph.
+
+        The nodes in ``nodes_to_remove`` are deleted first, together with all edges
+        linked to them; the nodes and edges of the input graph are then added
+        (or updated if they already exist in ``self._graph``).
+
+        Parameters
+        ----------
+        graph : EntityGraph, optional
+            The input graph. If it is None, no new nodes or edges are added.
+        nodes_to_remove : Set[EntityType], optional
+            The nodes to remove. All edges connected to them are removed as well.
+        new_results : List[Entity], optional, default None
+            The new results that replace the original ones, matched by key.
+
+        Raises
+        ------
+        ValueError
+            1. If an input of a removed node's successor is produced neither by the
+               subgraph nor by a remaining node.
+            2. If any node of the subgraph is also in the removed set.
+            3. If any removed node is still among the results.
+            4. If the key of a new result can't be found in the original results.
+        """
+        affected_successors = set()
+        output_to_node = dict()
+        nodes_to_remove = nodes_to_remove or set()
+        new_results = new_results or list()
+        result_indices = {
+            result.key: idx for idx, result in enumerate(self._graph.results)
+        }
+
+        if graph is not None:
+            # Map output key -> node for the subgraph
+            for node in graph.iter_nodes():
+                if node in nodes_to_remove:
+                    raise ValueError(f"The node {node} is in the removed set")
+                for output in node.outputs:
+                    output_to_node[output.key] = node
+
+        # Map output key -> node for the kept nodes of the original graph
+        for node in self._graph.iter_nodes():
+            if node not in nodes_to_remove:
+                for output in node.outputs:
+                    output_to_node[output.key] = node
+
+        # Check that the new results are valid
+        for result in new_results:
+            if result.key not in result_indices:
+                raise ValueError(f"Unknown result {result} to replace")
+            if result.key not in output_to_node:
+                raise ValueError(f"The result {result} is missing in the updated graph")
+
+        for node in nodes_to_remove:
+            for affected_successor in self._graph.iter_successors(node):
+                if affected_successor not in nodes_to_remove:
+                    affected_successors.add(affected_successor)
+        # Check whether the affected successors' inputs are still available
+        for affected_successor in affected_successors:
+            for inp in affected_successor.inputs:
+                if inp.key not in output_to_node:
+                    raise ValueError(
+                        f"The input {inp} of node {affected_successor} is missing in the updated graph"
+                    )
+        # All pre-checks have passed, start replacing the subgraph
+        for node in nodes_to_remove:
+            self._graph.remove_node(node)
+
+        if graph is None:
+            return
+
+        for node in graph.iter_nodes():
+            self._graph.add_node(node)
+
+        for node in itertools.chain(graph.iter_nodes(), affected_successors):
+            for inp in node.inputs:
+                pred_node = output_to_node[inp.key]
+                self._graph.add_edge(pred_node, node)
+
+        for result in new_results:
+            self._graph.results[result_indices[result.key]] = result
 
 
 class OperandBasedOptimizationRule(OptimizationRule):
diff --git a/mars/optimization/logical/tests/__init__.py b/mars/optimization/logical/tests/__init__.py
new file mode 100644
index 0000000000..c71e83c08e
--- /dev/null
+++ b/mars/optimization/logical/tests/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 1999-2021 Alibaba Group Holding Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/mars/optimization/logical/tests/test_core.py b/mars/optimization/logical/tests/test_core.py
new file mode 100644
index 0000000000..6cfef989d7
--- /dev/null
+++ b/mars/optimization/logical/tests/test_core.py
@@ -0,0 +1,228 @@
+# Copyright 1999-2021 Alibaba Group Holding Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import itertools
+import pytest
+
+
+from ..core import OptimizationRule
+from .... import tensor as mt
+from .... import dataframe as md
+
+
+class _MockRule(OptimizationRule):
+    def apply(self) -> bool:
+        pass
+
+    def replace_subgraph(self, graph, nodes_to_remove, new_results=None):
+        self._replace_subgraph(graph, nodes_to_remove, new_results)
+
+
+def test_replace_tileable_subgraph():
+    """
+    Original Graph:
+    s1 ---> c1 ---> v1 ---> v4 ----> v6(output) <--- v5 <--- c5 <--- s5
+                    |        ^
+                    |        |
+                    V        |
+                    v3 ------|
+                    ^
+                    |
+    s2 ---> c2 ---> v2
+
+    Target Graph:
+    s1 ---> c1 ---> v1 ---> v7 ----> v8(output) <--- v5 <--- c5 <--- s5
+                            ^
+                            |
+    s2 ---> c2 ---> v2 -----|
+
+    The nodes [v3, v4, v6] will be removed.
+    The subgraph contains only [v7, v8].
+    """
+    s1 = mt.random.randint(0, 100, size=(5, 4))
+    v1 = md.DataFrame(s1, columns=list("ABCD"), chunk_size=5)
+    s2 = mt.random.randint(0, 100, size=(5, 4))
+    v2 = md.DataFrame(s2, columns=list("ABCD"), chunk_size=5)
+    v3 = v1.add(v2)
+    v4 = v3.add(v1)
+    s5 = mt.random.randint(0, 100, size=(5, 4))
+    v5 = md.DataFrame(s5, columns=list("ABCD"), chunk_size=4)
+    v6 = v5.sub(v4)
+    g1 = v6.build_graph()
+    v7 = v1.sub(v2)
+    v8 = v7.add(v5)
+    v8._key = v6.key
+    v8.outputs[0]._key = v6.key
+    g2 = v8.build_graph()
+    # Here we use a shortcut to construct the subgraph, for testing only
+    key_to_node = dict()
+    for node in g2.iter_nodes():
+        key_to_node[node.key] = node
+    for key, node in key_to_node.items():
+        if key != v7.key and key != v8.key:
+            g2.remove_node(node)
+    r = _MockRule(g1, None, None)
+    for node in g1.iter_nodes():
+        key_to_node[node.key] = node
+
+    c1 = g1.successors(key_to_node[s1.key])[0]
+    c2 = g1.successors(key_to_node[s2.key])[0]
+    c5 = g1.successors(key_to_node[s5.key])[0]
+
+    new_results = [v8.outputs[0]]
+    r.replace_subgraph(g2, {key_to_node[op.key] for op in [v3, v4, v6]}, new_results)
+    assert g1.results == new_results
+    for node in g1.iter_nodes():
+        if node.key == v8.key:
+            key_to_node[v8.key] = node
+            break
+    expected_nodes = {s1, c1, v1, s2, c2, v2, s5, c5, v5, v7, v8}
+    assert set(g1) == {key_to_node[n.key] for n in expected_nodes}
+
+    expected_edges = {
+        s1: [c1],
+        c1: [v1],
+        v1: [v7],
+        s2: [c2],
+        c2: [v2],
+        v2: [v7],
+        s5: [c5],
+        c5: [v5],
+        v5: [v8],
+        v7: [v8],
+        v8: [],
+    }
+    for pred, successors in expected_edges.items():
+        pred_node = key_to_node[pred.key]
+        assert g1.count_successors(pred_node) == len(successors)
+        for successor in successors:
+            assert g1.has_successor(pred_node, key_to_node[successor.key])
+
+
+def test_replace_null_subgraph():
+    """
+    Original Graph:
+    s1 ---> c1 ---> v1 ---> v3(out) <--- v2 <--- c2 <--- s2
+
+    Target Graph:
+    c1 ---> v1 ---> v3(out) <--- v2 <--- c2
+
+    The nodes [s1, s2] will be removed.
+    The subgraph is None.
+    """
+    s1 = mt.random.randint(0, 100, size=(10, 4))
+    v1 = md.DataFrame(s1, columns=list("ABCD"), chunk_size=5)
+    s2 = mt.random.randint(0, 100, size=(10, 4))
+    v2 = md.DataFrame(s2, columns=list("ABCD"), chunk_size=5)
+    v3 = v1.add(v2)
+    g1 = v3.build_graph()
+    key_to_node = {node.key: node for node in g1.iter_nodes()}
+    c1 = g1.successors(key_to_node[s1.key])[0]
+    c2 = g1.successors(key_to_node[s2.key])[0]
+    r = _MockRule(g1, None, None)
+    expected_results = [v3.outputs[0]]
+
+    # Removing s1 and s2 fails here: v2 is not among the original results,
+    # and c1/c2 still take the removed nodes as inputs
+    with pytest.raises(ValueError):
+        r.replace_subgraph(
+            None, {key_to_node[op.key] for op in [s1, s2]}, [v2.outputs[0]]
+        )
+
+    assert g1.results == expected_results
+    assert set(g1) == {key_to_node[n.key] for n in {s1, c1, v1, s2, c2, v2, v3}}
+    expected_edges = {
+        s1: [c1],
+        c1: [v1],
+        v1: [v3],
+        s2: [c2],
+        c2: [v2],
+        v2: [v3],
+        v3: [],
+    }
+    for pred, successors in expected_edges.items():
+        pred_node = key_to_node[pred.key]
+        assert g1.count_successors(pred_node) == len(successors)
+        for successor in successors:
+            assert g1.has_successor(pred_node, key_to_node[successor.key])
+
+    c1.inputs.clear()
+    c2.inputs.clear()
+    r.replace_subgraph(None, {key_to_node[op.key] for op in [s1, s2]})
+    assert g1.results == expected_results
+    assert set(g1) == {key_to_node[n.key] for n in {c1, v1, c2, v2, v3}}
+    expected_edges = {
+        c1: [v1],
+        v1: [v3],
+        c2: [v2],
+        v2: [v3],
+        v3: [],
+    }
+    for pred, successors in expected_edges.items():
+        pred_node = key_to_node[pred.key]
+        assert g1.count_successors(pred_node) == len(successors)
+        for successor in successors:
+            assert g1.has_successor(pred_node, key_to_node[successor.key])
+
+
+def test_replace_subgraph_without_removing_nodes():
+    """
+    Original Graph:
+    s1 ---> c1 ---> v1 ---> v4 <--- v2 <--- c2 <--- s2
+
+    Target Graph:
+    s1 ---> c1 ---> v1 ---> v4 <--- v2 <--- c2 <--- s2
+    s3 ---> c3 ---> v3
+
+    Nothing will be removed.
+    The subgraph contains only [s3, c3, v3].
+    """
+    s1 = mt.random.randint(0, 100, size=(10, 4))
+    v1 = md.DataFrame(s1, columns=list("ABCD"), chunk_size=5)
+    s2 = mt.random.randint(0, 100, size=(10, 4))
+    v2 = md.DataFrame(s2, columns=list("ABCD"), chunk_size=5)
+    v4 = v1.add(v2)
+    g1 = v4.build_graph()
+
+    s3 = mt.random.randint(0, 100, size=(10, 4))
+    v3 = md.DataFrame(s3, columns=list("ABCD"), chunk_size=5)
+    g2 = v3.build_graph()
+    key_to_node = {
+        node.key: node for node in itertools.chain(g1.iter_nodes(), g2.iter_nodes())
+    }
+    expected_results = [v4.outputs[0]]
+    c1 = g1.successors(key_to_node[s1.key])[0]
+    c2 = g1.successors(key_to_node[s2.key])[0]
+    c3 = g2.successors(key_to_node[s3.key])[0]
+    r = _MockRule(g1, None, None)
+    r.replace_subgraph(g2, None)
+    assert g1.results == expected_results
+    assert set(g1) == {
+        key_to_node[n.key] for n in {s1, c1, v1, s2, c2, v2, s3, c3, v3, v4}
+    }
+    expected_edges = {
+        s1: [c1],
+        c1: [v1],
+        v1: [v4],
+        s2: [c2],
+        c2: [v2],
+        v2: [v4],
+        s3: [c3],
+        c3: [v3],
+        v3: [],
+        v4: [],
+    }
+    for pred, successors in expected_edges.items():
+        pred_node = key_to_node[pred.key]
+        assert g1.count_successors(pred_node) == len(successors)
+        for successor in successors:
+            assert g1.has_successor(pred_node, key_to_node[successor.key])
diff --git a/mars/optimization/logical/tileable/arithmetic_query.py b/mars/optimization/logical/tileable/arithmetic_query.py
index 5ecf4a2945..604d9b7ac7 100644
--- a/mars/optimization/logical/tileable/arithmetic_query.py
+++ b/mars/optimization/logical/tileable/arithmetic_query.py
@@ -13,20 +13,27 @@
 # limitations under the License.
 import weakref
-from typing import NamedTuple, Optional
+from abc import ABC
+from typing import NamedTuple, Optional, Type, Set
 
 import numpy as np
 from pandas.api.types import is_scalar
 
 from .... import dataframe as md
-from ....core import Tileable, get_output_types, ENTITY_TYPE
+from ....core import Tileable, get_output_types, ENTITY_TYPE, TileableGraph
+from ....core.graph import EntityGraph
 from ....dataframe.arithmetic.core import DataFrameUnaryUfunc, DataFrameBinopUfunc
 from ....dataframe.base.eval import DataFrameEval
 from ....dataframe.indexing.getitem import DataFrameIndex
 from ....dataframe.indexing.setitem import DataFrameSetitem
-from ....typing import OperandType
+from ....typing import OperandType, EntityType
 from ....utils import implements
-from ..core import OptimizationRecord, OptimizationRecordType
+from ..core import (
+    OptimizationRecord,
+    OptimizationRecordType,
+    OptimizationRecords,
+    Optimizer,
+)
 from ..tileable.core import register_operand_based_optimization_rule
 from .core import OperandBasedOptimizationRule
 
@@ -66,8 +73,70 @@ def builder(lhs: str, rhs: str):
 _extract_result_cache = weakref.WeakKeyDictionary()
 
 
+class _EvalRewriteOptimizationRule(OperandBasedOptimizationRule, ABC):
+    def __init__(
+        self,
+        graph: EntityGraph,
+        records: OptimizationRecords,
+        optimizer_cls: Type[Optimizer],
+    ):
+        super().__init__(graph, records, optimizer_cls)
+        self._marked_predecessors = dict()
+
+    def _mark_predecessor(self, node: EntityType, predecessor: EntityType):
+        pred_original = self._records.get_original_entity(predecessor, predecessor)
+        if predecessor not in self._marked_predecessors:
+            self._marked_predecessors[pred_original] = {node}
+        else:
+            self._marked_predecessors[pred_original].add(node)
+
+    def _find_nodes_to_remove(self, node: EntityType) -> Set[EntityType]:
+        node = self._records.get_optimization_result(node) or node
+        removed_nodes = {node}
+        results_set = set(self._graph.results)
+        removed_pairs = []
+        for pred in self._graph.iter_predecessors(node):
+            pred_original = self._records.get_original_entity(pred, pred)
+            pred_opt = self._records.get_optimization_result(pred, pred)
+
+            if pred_opt in results_set or pred_original in results_set:
+                continue
+
+            affect_succ = self._marked_predecessors.get(pred_original) or []
+            affect_succ_opt = [
+                self._records.get_optimization_result(s, s) for s in affect_succ
+            ]
+            if all(s in affect_succ_opt for s in self._graph.iter_successors(pred)):
+                removed_pairs.append((pred_original, pred_opt))
+
+        for pred_original, pred_opt in removed_pairs:
+            removed_nodes.add(pred_opt)
+            self._records.append_record(
+                OptimizationRecord(pred_original, None, OptimizationRecordType.delete)
+            )
+        return removed_nodes
+
+    def _replace_with_new_node(self, original_node: EntityType, new_node: EntityType):
+        # Find all the nodes to remove
+        nodes_to_remove = self._find_nodes_to_remove(original_node)
+
+        # Build the replacement subgraph that contains only the new node
+        subgraph = TileableGraph()
+        subgraph.add_node(new_node)
+
+        new_results = [new_node] if new_node in self._graph.results else None
+        self._replace_subgraph(subgraph, nodes_to_remove, new_results)
+        self._records.append_record(
+            OptimizationRecord(
+                self._records.get_original_entity(original_node, original_node),
+                new_node,
+                OptimizationRecordType.replace,
+            )
+        )
+
+
 @register_operand_based_optimization_rule([DataFrameUnaryUfunc, DataFrameBinopUfunc])
-class SeriesArithmeticToEval(OperandBasedOptimizationRule):
+class SeriesArithmeticToEval(_EvalRewriteOptimizationRule):
     _var_counter = 0
 
     @classmethod
@@ -151,7 +220,7 @@ def _extract_unary(self, tileable) -> EvalExtractRecord:
         if in_tileable is None:
             return EvalExtractRecord()
 
-        self._add_collapsable_predecessor(tileable, op.inputs[0])
+        self._mark_predecessor(tileable, op.inputs[0])
         return EvalExtractRecord(
             in_tileable, _func_name_to_builder[func_name](expr), variables
         )
@@ -164,10 +233,10 @@ def _extract_binary(self, tileable) -> EvalExtractRecord:
 
         lhs_tileable, lhs_expr, lhs_vars = self._extract_eval_expression(op.lhs)
         if lhs_tileable is not None:
-            self._add_collapsable_predecessor(tileable, op.lhs)
+            self._mark_predecessor(tileable, op.lhs)
         rhs_tileable, rhs_expr, rhs_vars = self._extract_eval_expression(op.rhs)
         if rhs_tileable is not None:
-            self._add_collapsable_predecessor(tileable, op.rhs)
+            self._mark_predecessor(tileable, op.rhs)
 
         if lhs_expr is None or rhs_expr is None:
             return EvalExtractRecord()
@@ -204,24 +273,10 @@ def apply_to_operand(self, op: OperandType):
         new_node = new_op.new_tileable(
             [opt_in_tileable], _key=node.key, _id=node.id, **node.params
         ).data
+        self._replace_with_new_node(node, new_node)
 
-        self._remove_collapsable_predecessors(node)
-        self._replace_node(node, new_node)
-        self._graph.add_edge(opt_in_tileable, new_node)
-        self._records.append_record(
-            OptimizationRecord(node, new_node, OptimizationRecordType.replace)
-        )
-
-        # check node if it's in result
-        try:
-            i = self._graph.results.index(node)
-            self._graph.results[i] = new_node
-        except ValueError:
-            pass
-
-
-class _DataFrameEvalRewriteRule(OperandBasedOptimizationRule):
+
+class _DataFrameEvalRewriteRule(_EvalRewriteOptimizationRule):
     @implements(OperandBasedOptimizationRule.match_operand)
     def match_operand(self, op: OperandType) -> bool:
         optimized_eval_op = self._get_optimized_eval_op(op)
@@ -245,16 +300,6 @@ def _get_optimized_eval_op(self, op: OperandType) -> OperandType:
     def _get_input_columnar_node(self, op: OperandType) -> ENTITY_TYPE:
         raise NotImplementedError
 
-    def _update_op_node(self, old_node: ENTITY_TYPE, new_node: ENTITY_TYPE):
-        self._replace_node(old_node, new_node)
-        for in_tileable in new_node.inputs:
-            self._graph.add_edge(in_tileable, new_node)
-
-        original_node = self._records.get_original_entity(old_node, old_node)
-        self._records.append_record(
-            OptimizationRecord(original_node, new_node, OptimizationRecordType.replace)
-        )
-
     @implements(OperandBasedOptimizationRule.apply_to_operand)
     def apply_to_operand(self, op: DataFrameIndex):
         node = op.outputs[0]
@@ -268,10 +313,8 @@ def apply_to_operand(self, op: DataFrameIndex):
         new_node = new_op.new_tileable(
             [opt_in_tileable], _key=node.key, _id=node.id, **node.params
         ).data
-
-        self._add_collapsable_predecessor(node, in_columnar_node)
-        self._remove_collapsable_predecessors(node)
-        self._update_op_node(node, new_node)
+        self._mark_predecessor(node, in_columnar_node)
+        self._replace_with_new_node(node, new_node)
 
 
 @register_operand_based_optimization_rule([DataFrameIndex])
@@ -360,7 +403,5 @@ def apply_to_operand(self, op: DataFrameIndex):
         new_node = new_op.new_tileable(
             pred_opt_node.inputs, _key=node.key, _id=node.id, **node.params
         ).data
-
-        self._add_collapsable_predecessor(opt_node, pred_opt_node)
-        self._remove_collapsable_predecessors(opt_node)
-        self._update_op_node(opt_node, new_node)
+        self._mark_predecessor(opt_node, pred_opt_node)
+        self._replace_with_new_node(opt_node, new_node)
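
Below is a minimal sketch, not part of the patch, of how a rule is expected to drive the new _replace_subgraph helper. It mirrors the _MockRule pattern from the tests above; _DemoRule and the toy tileables are illustrative assumptions only.

import mars.dataframe as md
import mars.tensor as mt
from mars.optimization.logical.core import OptimizationRule


class _DemoRule(OptimizationRule):
    # apply() is abstract on OptimizationRule (as in the tests' _MockRule);
    # this sketch never calls it.
    def apply(self) -> bool:
        pass


# A small tileable graph: two random dataframes added together.
t1 = mt.random.randint(0, 100, size=(10, 4))
df1 = md.DataFrame(t1, columns=list("ABCD"), chunk_size=5)
t2 = mt.random.randint(0, 100, size=(10, 4))
df2 = md.DataFrame(t2, columns=list("ABCD"), chunk_size=5)
total = df1.add(df2)
graph = total.build_graph()

# An unrelated dataframe whose graph will be merged into `graph`.
extra = md.DataFrame(mt.ones((10, 4)), columns=list("ABCD"), chunk_size=5)
subgraph = extra.build_graph()

rule = _DemoRule(graph, None, None)
# Merge `subgraph` into `graph` without removing anything; passing a set of nodes
# as the second argument would delete them (and their edges) first, and a
# `new_results` list would swap out graph results by matching keys.
rule._replace_subgraph(subgraph, None)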