diff --git a/mars/optimization/logical/core.py b/mars/optimization/logical/core.py index 1bf2d1e2ca..7809ade083 100644 --- a/mars/optimization/logical/core.py +++ b/mars/optimization/logical/core.py @@ -13,7 +13,6 @@ # limitations under the License. import functools import itertools -import weakref from abc import ABC, abstractmethod from collections import defaultdict from dataclasses import dataclass @@ -92,8 +91,6 @@ def get_original_entity( class OptimizationRule(ABC): - _preds_to_remove = weakref.WeakKeyDictionary() - def __init__( self, graph: EntityGraph, @@ -217,35 +214,6 @@ def _replace_subgraph( for result in new_results: self._graph.results[result_indices[result.key]] = result - def _add_collapsable_predecessor(self, node: EntityType, predecessor: EntityType): - pred_original = self._records.get_original_entity(predecessor, predecessor) - if predecessor not in self._preds_to_remove: - self._preds_to_remove[pred_original] = {node} - else: - self._preds_to_remove[pred_original].add(node) - - def _remove_collapsable_predecessors(self, node: EntityType): - node = self._records.get_optimization_result(node) or node - preds_opt_to_remove = [] - for pred in self._graph.predecessors(node): - pred_original = self._records.get_original_entity(pred, pred) - pred_opt = self._records.get_optimization_result(pred, pred) - - if pred_opt in self._graph.results or pred_original in self._graph.results: - continue - affect_succ = self._preds_to_remove.get(pred_original) or [] - affect_succ_opt = [ - self._records.get_optimization_result(s, s) for s in affect_succ - ] - if all(s in affect_succ_opt for s in self._graph.successors(pred)): - preds_opt_to_remove.append((pred_original, pred_opt)) - - for pred_original, pred_opt in preds_opt_to_remove: - self._graph.remove_node(pred_opt) - self._records.append_record( - OptimizationRecord(pred_original, None, OptimizationRecordType.delete) - ) - class OperandBasedOptimizationRule(OptimizationRule): """ diff --git a/mars/optimization/logical/tests/test_core.py b/mars/optimization/logical/tests/test_core.py index d06b249bd9..6cfef989d7 100644 --- a/mars/optimization/logical/tests/test_core.py +++ b/mars/optimization/logical/tests/test_core.py @@ -157,10 +157,7 @@ def test_replace_null_subgraph(): c1.inputs.clear() c2.inputs.clear() - r.replace_subgraph( - None, - {key_to_node[op.key] for op in [s1, s2]} - ) + r.replace_subgraph(None, {key_to_node[op.key] for op in [s1, s2]}) assert g1.results == expected_results assert set(g1) == {key_to_node[n.key] for n in {c1, v1, c2, v2, v3}} expected_edges = { diff --git a/mars/optimization/logical/tileable/arithmetic_query.py b/mars/optimization/logical/tileable/arithmetic_query.py index 5ecf4a2945..3156f0801d 100644 --- a/mars/optimization/logical/tileable/arithmetic_query.py +++ b/mars/optimization/logical/tileable/arithmetic_query.py @@ -13,20 +13,22 @@ # limitations under the License. import weakref -from typing import NamedTuple, Optional +from abc import ABC +from typing import NamedTuple, Optional, Type, Set import numpy as np from pandas.api.types import is_scalar from .... import dataframe as md -from ....core import Tileable, get_output_types, ENTITY_TYPE +from ....core import Tileable, get_output_types, ENTITY_TYPE, TileableGraph +from ....core.graph import EntityGraph from ....dataframe.arithmetic.core import DataFrameUnaryUfunc, DataFrameBinopUfunc from ....dataframe.base.eval import DataFrameEval from ....dataframe.indexing.getitem import DataFrameIndex from ....dataframe.indexing.setitem import DataFrameSetitem -from ....typing import OperandType +from ....typing import OperandType, EntityType from ....utils import implements -from ..core import OptimizationRecord, OptimizationRecordType +from ..core import OptimizationRecord, OptimizationRecordType, OptimizationRecords from ..tileable.core import register_operand_based_optimization_rule from .core import OperandBasedOptimizationRule @@ -66,8 +68,70 @@ def builder(lhs: str, rhs: str): _extract_result_cache = weakref.WeakKeyDictionary() +class _EvalRewriteOptimizationRule(OperandBasedOptimizationRule, ABC): + def __init__( + self, + graph: EntityGraph, + records: OptimizationRecords, + optimizer_cls: Type["Optimizer"], + ): + super().__init__(graph, records, optimizer_cls) + self._marked_predecessors = dict() + + def _mark_predecessor(self, node: EntityType, predecessor: EntityType): + pred_original = self._records.get_original_entity(predecessor, predecessor) + if predecessor not in self._marked_predecessors: + self._marked_predecessors[pred_original] = {node} + else: + self._marked_predecessors[pred_original].add(node) + + def _find_nodes_to_remove(self, node: EntityType) -> Set[EntityType]: + node = self._records.get_optimization_result(node) or node + removed_nodes = {node} + results_set = set(self._graph.results) + removed_pairs = [] + for pred in self._graph.iter_predecessors(node): + pred_original = self._records.get_original_entity(pred, pred) + pred_opt = self._records.get_optimization_result(pred, pred) + + if pred_opt in results_set or pred_original in results_set: + continue + + affect_succ = self._marked_predecessors.get(pred_original) or [] + affect_succ_opt = [ + self._records.get_optimization_result(s, s) for s in affect_succ + ] + if all(s in affect_succ_opt for s in self._graph.iter_successors(pred)): + removed_pairs.append((pred_original, pred_opt)) + + for pred_original, pred_opt in removed_pairs: + removed_nodes.add(pred_opt) + self._records.append_record( + OptimizationRecord(pred_original, None, OptimizationRecordType.delete) + ) + return removed_nodes + + def _replace_with_new_node(self, original_node: EntityType, new_node: EntityType): + # Find all the nodes to remove + nodes_to_remove = self._find_nodes_to_remove(original_node) + + # Build the replaced subgraph + subgraph = TileableGraph() + subgraph.add_node(new_node) + + new_results = [new_node] if new_node in self._graph.results else None + self._replace_subgraph(subgraph, nodes_to_remove, new_results) + self._records.append_record( + OptimizationRecord( + self._records.get_original_entity(original_node, original_node), + new_node, + OptimizationRecordType.replace, + ) + ) + + @register_operand_based_optimization_rule([DataFrameUnaryUfunc, DataFrameBinopUfunc]) -class SeriesArithmeticToEval(OperandBasedOptimizationRule): +class SeriesArithmeticToEval(_EvalRewriteOptimizationRule): _var_counter = 0 @classmethod @@ -151,7 +215,7 @@ def _extract_unary(self, tileable) -> EvalExtractRecord: if in_tileable is None: return EvalExtractRecord() - self._add_collapsable_predecessor(tileable, op.inputs[0]) + self._mark_predecessor(tileable, op.inputs[0]) return EvalExtractRecord( in_tileable, _func_name_to_builder[func_name](expr), variables ) @@ -164,10 +228,10 @@ def _extract_binary(self, tileable) -> EvalExtractRecord: lhs_tileable, lhs_expr, lhs_vars = self._extract_eval_expression(op.lhs) if lhs_tileable is not None: - self._add_collapsable_predecessor(tileable, op.lhs) + self._mark_predecessor(tileable, op.lhs) rhs_tileable, rhs_expr, rhs_vars = self._extract_eval_expression(op.rhs) if rhs_tileable is not None: - self._add_collapsable_predecessor(tileable, op.rhs) + self._mark_predecessor(tileable, op.rhs) if lhs_expr is None or rhs_expr is None: return EvalExtractRecord() @@ -204,24 +268,10 @@ def apply_to_operand(self, op: OperandType): new_node = new_op.new_tileable( [opt_in_tileable], _key=node.key, _id=node.id, **node.params ).data + self._replace_with_new_node(node, new_node) - self._remove_collapsable_predecessors(node) - self._replace_node(node, new_node) - self._graph.add_edge(opt_in_tileable, new_node) - self._records.append_record( - OptimizationRecord(node, new_node, OptimizationRecordType.replace) - ) - - # check node if it's in result - try: - i = self._graph.results.index(node) - self._graph.results[i] = new_node - except ValueError: - pass - - -class _DataFrameEvalRewriteRule(OperandBasedOptimizationRule): +class _DataFrameEvalRewriteRule(_EvalRewriteOptimizationRule): @implements(OperandBasedOptimizationRule.match_operand) def match_operand(self, op: OperandType) -> bool: optimized_eval_op = self._get_optimized_eval_op(op) @@ -245,16 +295,6 @@ def _get_optimized_eval_op(self, op: OperandType) -> OperandType: def _get_input_columnar_node(self, op: OperandType) -> ENTITY_TYPE: raise NotImplementedError - def _update_op_node(self, old_node: ENTITY_TYPE, new_node: ENTITY_TYPE): - self._replace_node(old_node, new_node) - for in_tileable in new_node.inputs: - self._graph.add_edge(in_tileable, new_node) - - original_node = self._records.get_original_entity(old_node, old_node) - self._records.append_record( - OptimizationRecord(original_node, new_node, OptimizationRecordType.replace) - ) - @implements(OperandBasedOptimizationRule.apply_to_operand) def apply_to_operand(self, op: DataFrameIndex): node = op.outputs[0] @@ -268,10 +308,8 @@ def apply_to_operand(self, op: DataFrameIndex): new_node = new_op.new_tileable( [opt_in_tileable], _key=node.key, _id=node.id, **node.params ).data - - self._add_collapsable_predecessor(node, in_columnar_node) - self._remove_collapsable_predecessors(node) - self._update_op_node(node, new_node) + self._mark_predecessor(node, in_columnar_node) + self._replace_with_new_node(node, new_node) @register_operand_based_optimization_rule([DataFrameIndex]) @@ -360,7 +398,5 @@ def apply_to_operand(self, op: DataFrameIndex): new_node = new_op.new_tileable( pred_opt_node.inputs, _key=node.key, _id=node.id, **node.params ).data - - self._add_collapsable_predecessor(opt_node, pred_opt_node) - self._remove_collapsable_predecessors(opt_node) - self._update_op_node(opt_node, new_node) + self._mark_predecessor(opt_node, pred_opt_node) + self._replace_with_new_node(opt_node, new_node)