Skip to content

Commit

Permalink
Optimize eval-setitem expressions as single eval expressions (#2695)
Browse files Browse the repository at this point in the history
  • Loading branch information
wjsi authored Feb 10, 2022
1 parent aaa23e8 commit 45eeb8d
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 67 deletions.
2 changes: 1 addition & 1 deletion mars/dataframe/indexing/setitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def tile(cls, op: "DataFrameSetitem"):
]
value_chunk_index_values = [v.index_value for v in value.chunks]
is_identical = len(target_chunk_index_values) == len(
target_chunk_index_values
value_chunk_index_values
) and all(
c.key == v.key
for c, v in zip(target_chunk_index_values, value_chunk_index_values)
Expand Down
167 changes: 120 additions & 47 deletions mars/optimization/logical/tileable/arithmetic_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,30 @@
# limitations under the License.

import weakref
from typing import Optional, Tuple
from typing import NamedTuple, Optional

import numpy as np
from pandas.api.types import is_scalar

from .... import dataframe as md
from ....core import Tileable, get_output_types, ENTITY_TYPE
from ....dataframe.arithmetic.core import DataFrameUnaryUfunc, DataFrameBinopUfunc
from ....dataframe.base.eval import DataFrameEval
from ....dataframe.indexing.getitem import DataFrameIndex
from ....dataframe.indexing.setitem import DataFrameSetitem
from ....typing import OperandType
from ....utils import implements
from ..core import OptimizationRecord, OptimizationRecordType
from ..tileable.core import register_tileable_optimization_rule
from .core import OptimizationRule


class EvalExtractRecord(NamedTuple):
tileable: Optional[Tileable] = None
expr: Optional[str] = None
variables: Optional[dict] = None


def _get_binop_builder(op_str: str):
def builder(lhs: str, rhs: str):
return f"({lhs}) {op_str} ({rhs})"
Expand Down Expand Up @@ -60,9 +68,16 @@ def builder(lhs: str, rhs: str):

@register_tileable_optimization_rule([DataFrameUnaryUfunc, DataFrameBinopUfunc])
class SeriesArithmeticToEval(OptimizationRule):
_var_counter = 0

@classmethod
def _next_var_id(cls):
cls._var_counter += 1
return cls._var_counter

@implements(OptimizationRule.match)
def match(self, op: OperandType) -> bool:
_, expr = self._extract_eval_expression(op.outputs[0])
_, expr, _ = self._extract_eval_expression(op.outputs[0])
return expr is not None

@staticmethod
Expand Down Expand Up @@ -91,14 +106,17 @@ def _is_select_dataframe_column(tileable) -> bool:
)

@classmethod
def _extract_eval_expression(
cls, tileable
) -> Tuple[Optional[Tileable], Optional[str]]:
def _extract_eval_expression(cls, tileable) -> EvalExtractRecord:
if is_scalar(tileable):
return None, repr(tileable)
if isinstance(tileable, (int, bool, str, bytes, np.integer, np.bool_)):
return EvalExtractRecord(expr=repr(tileable))
else:
var_name = f"__eval_scalar_var{cls._next_var_id()}"
var_dict = {var_name: tileable}
return EvalExtractRecord(expr=f"@{var_name}", variables=var_dict)

if not isinstance(tileable, ENTITY_TYPE): # pragma: no cover
return None, None
return EvalExtractRecord()

if tileable in _extract_result_cache:
return _extract_result_cache[tileable]
Expand All @@ -109,69 +127,75 @@ def _extract_eval_expression(
result = cls._extract_unary(tileable)
elif isinstance(tileable.op, DataFrameBinopUfunc):
if tileable.op.fill_value is not None or tileable.op.level is not None:
result = None, None
result = EvalExtractRecord()
else:
result = cls._extract_binary(tileable)
else:
result = None, None
result = EvalExtractRecord()

_extract_result_cache[tileable] = result
return result

@classmethod
def _extract_column_select(
cls, tileable
) -> Tuple[Optional[Tileable], Optional[str]]:
return tileable.inputs[0], f"`{tileable.op.col_names}`"
def _extract_column_select(cls, tileable) -> EvalExtractRecord:
return EvalExtractRecord(tileable.inputs[0], f"`{tileable.op.col_names}`")

@classmethod
def _extract_unary(cls, tileable) -> Tuple[Optional[Tileable], Optional[str]]:
def _extract_unary(cls, tileable) -> EvalExtractRecord:
op = tileable.op
func_name = getattr(op, "_func_name") or getattr(op, "_bin_func_name")
if func_name not in _func_name_to_builder: # pragma: no cover
return None, None
return EvalExtractRecord()

in_tileable, expr = cls._extract_eval_expression(op.inputs[0])
in_tileable, expr, variables = cls._extract_eval_expression(op.inputs[0])
if in_tileable is None:
return None, None
return EvalExtractRecord()

cls._add_collapsable_predecessor(tileable, op.inputs[0])
return in_tileable, _func_name_to_builder[func_name](expr)
return EvalExtractRecord(
in_tileable, _func_name_to_builder[func_name](expr), variables
)

@classmethod
def _extract_binary(cls, tileable) -> Tuple[Optional[Tileable], Optional[str]]:
def _extract_binary(cls, tileable) -> EvalExtractRecord:
op = tileable.op
func_name = getattr(op, "_func_name", None) or getattr(op, "_bit_func_name")
if func_name not in _func_name_to_builder: # pragma: no cover
return None, None
return EvalExtractRecord()

lhs_tileable, lhs_expr = cls._extract_eval_expression(op.lhs)
lhs_tileable, lhs_expr, lhs_vars = cls._extract_eval_expression(op.lhs)
if lhs_tileable is not None:
cls._add_collapsable_predecessor(tileable, op.lhs)
rhs_tileable, rhs_expr = cls._extract_eval_expression(op.rhs)
rhs_tileable, rhs_expr, rhs_vars = cls._extract_eval_expression(op.rhs)
if rhs_tileable is not None:
cls._add_collapsable_predecessor(tileable, op.rhs)

if lhs_expr is None or rhs_expr is None:
return None, None
return EvalExtractRecord()
if (
lhs_tileable is not None
and rhs_tileable is not None
and lhs_tileable.key != rhs_tileable.key
):
return None, None
return EvalExtractRecord()

variables = (lhs_vars or dict()).copy()
variables.update(rhs_vars or dict())
in_tileable = next(t for t in [lhs_tileable, rhs_tileable] if t is not None)
return in_tileable, _func_name_to_builder[func_name](lhs_expr, rhs_expr)
return EvalExtractRecord(
in_tileable, _func_name_to_builder[func_name](lhs_expr, rhs_expr), variables
)

@implements(OptimizationRule.apply)
def apply(self, op: OperandType):
node = op.outputs[0]
in_tileable, expr = self._extract_eval_expression(node)
in_tileable, expr, variables = self._extract_eval_expression(node)

new_op = DataFrameEval(
_key=node.op.key,
_output_types=get_output_types(node),
expr=expr,
variables=variables or dict(),
parser="pandas",
is_query=False,
)
Expand All @@ -195,39 +219,41 @@ def apply(self, op: OperandType):
pass


@register_tileable_optimization_rule([DataFrameIndex])
class DataFrameBoolEvalToQuery(OptimizationRule):
def match(self, op: "DataFrameIndex") -> bool:
class _DataFrameEvalRewriteRule(OptimizationRule):
def match(self, op: OperandType) -> bool:
optimized_eval_op = self._get_optimized_eval_op(op)
if (
op.col_names is not None
or not isinstance(op.mask, md.Series)
or op.mask.dtype != bool
not isinstance(optimized_eval_op, DataFrameEval)
or optimized_eval_op.is_query
or optimized_eval_op.inputs[0].key != op.inputs[0].key
):
return False
optimized = self._records.get_optimization_result(op.mask)
mask_op = optimized.op if optimized is not None else op.mask.op
if not isinstance(mask_op, DataFrameEval) or mask_op.is_query:
return False
return True

def apply(self, op: "DataFrameIndex"):
def _build_new_eval_op(self, op: OperandType):
raise NotImplementedError

def _get_optimized_eval_op(self, op: OperandType) -> OperandType:
in_columnar_node = self._get_input_columnar_node(op)
optimized = self._records.get_optimization_result(in_columnar_node)
return optimized.op if optimized is not None else in_columnar_node.op

def _get_input_columnar_node(self, op: OperandType) -> ENTITY_TYPE:
raise NotImplementedError

def apply(self, op: DataFrameIndex):
node = op.outputs[0]
in_tileable = op.inputs[0]
in_columnar_node = self._get_input_columnar_node(op)

new_op = self._build_new_eval_op(op)
new_op._key = node.op.key

optimized = self._records.get_optimization_result(op.mask)
mask_op = optimized.op if optimized is not None else op.mask.op
new_op = DataFrameEval(
_key=node.op.key,
_output_types=get_output_types(node),
expr=mask_op.expr,
parser="pandas",
is_query=True,
)
new_node = new_op.new_tileable([in_tileable], **node.params).data
new_node._key = node.key
new_node._id = node.id

self._add_collapsable_predecessor(node, op.mask)
self._add_collapsable_predecessor(node, in_columnar_node)
self._remove_collapsable_predecessors(node)

self._replace_node(node, new_node)
Expand All @@ -242,3 +268,50 @@ def apply(self, op: "DataFrameIndex"):
self._graph.results[i] = new_node
except ValueError:
pass


@register_tileable_optimization_rule([DataFrameIndex])
class DataFrameBoolEvalToQuery(_DataFrameEvalRewriteRule):
def match(self, op: DataFrameIndex) -> bool:
if (
op.col_names is not None
or not isinstance(op.mask, md.Series)
or op.mask.dtype != bool
):
return False
return super().match(op)

def _get_input_columnar_node(self, op: OperandType) -> ENTITY_TYPE:
return op.mask

def _build_new_eval_op(self, op: OperandType):
in_eval_op = self._get_optimized_eval_op(op)
return DataFrameEval(
_output_types=get_output_types(op.outputs[0]),
expr=in_eval_op.expr,
variables=in_eval_op.variables,
parser="pandas",
is_query=True,
)


@register_tileable_optimization_rule([DataFrameSetitem])
class DataFrameEvalSetItemToEval(_DataFrameEvalRewriteRule):
def match(self, op: DataFrameSetitem):
if not isinstance(op.indexes, str) or not isinstance(op.value, md.Series):
return False
return super().match(op)

def _get_input_columnar_node(self, op: DataFrameSetitem) -> ENTITY_TYPE:
return op.value

def _build_new_eval_op(self, op: DataFrameSetitem):
in_eval_op = self._get_optimized_eval_op(op)
return DataFrameEval(
_output_types=get_output_types(op.outputs[0]),
expr=f"`{op.indexes}` = {in_eval_op.expr}",
variables=in_eval_op.variables,
parser="pandas",
is_query=False,
self_target=True,
)
60 changes: 41 additions & 19 deletions mars/optimization/logical/tileable/tests/test_arithmetic_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import re

import numpy as np
import pandas as pd

Expand All @@ -22,6 +24,13 @@
from .. import optimize


_var_pattern = re.compile(r"@__eval_scalar_var\d+")


def _norm_vars(var_str):
return _var_pattern.sub("@scalar", var_str)


@enter_mode(build=True)
def test_arithmetic_query(setup):
raw = pd.DataFrame(np.random.rand(100, 10), columns=list("ABCDEFGHIJ"))
Expand Down Expand Up @@ -62,7 +71,6 @@ def test_arithmetic_query(setup):

pd.testing.assert_series_equal(df2.execute().fetch(), -raw["A"] + raw["B"] * 5)

raw = pd.DataFrame(np.random.rand(100, 10), columns=list("ABCDEFGHIJ"))
df1 = md.DataFrame(raw, chunk_size=10)
df2 = -df1["A"] + df1["B"] * 5 + 3 * df1["C"]
graph = TileableGraph([df1["A"].data, df2.data])
Expand All @@ -74,22 +82,6 @@ def test_arithmetic_query(setup):
r_df2, _r_col_a = fetch(execute(df2, df1["A"]))
pd.testing.assert_series_equal(r_df2, -raw["A"] + raw["B"] * 5 + 3 * raw["C"])

df1 = md.DataFrame(raw, chunk_size=10)
df2 = md.DataFrame(raw2, chunk_size=10)
df3 = df1.merge(df2, on="A", suffixes=("", "_"))
df3["K"] = df4 = df3["A"] * (1 - df3["B"])
graph = TileableGraph([df3.data])
next(TileableGraphBuilder(graph).build())
records = optimize(graph)
opt_df4 = records.get_optimization_result(df4.data)
assert opt_df4.op.expr == "(`A`) * ((1) - (`B`))"
assert len(graph) == 5
assert len([n for n in graph if isinstance(n.op, DataFrameEval)]) == 1

r_df3 = raw.merge(raw2, on="A", suffixes=("", "_"))
r_df3["K"] = r_df3["A"] * (1 - r_df3["B"])
pd.testing.assert_frame_equal(df3.execute().fetch(), r_df3)


@enter_mode(build=True)
def test_bool_eval_to_query(setup):
Expand All @@ -111,7 +103,7 @@ def test_bool_eval_to_query(setup):
opt_df2 = records.get_optimization_result(df2.data)
assert isinstance(opt_df2.op, DataFrameEval)
assert opt_df2.op.is_query
assert opt_df2.op.expr == "((`A`) > (0.5)) & ((`C`) < (0.5))"
assert _norm_vars(opt_df2.op.expr) == "((`A`) > (@scalar)) & ((`C`) < (@scalar))"

pd.testing.assert_frame_equal(
df2.execute().fetch(), raw[(raw["A"] > 0.5) & (raw["C"] < 0.5)]
Expand All @@ -138,7 +130,37 @@ def test_bool_eval_to_query(setup):
next(TileableGraphBuilder(graph).build())
records = optimize(graph)
opt_df2 = records.get_optimization_result(df2.data)
assert opt_df2.op.expr == "(`b`) < (Timestamp('2022-03-20 00:00:00'))"
assert _norm_vars(opt_df2.op.expr) == "(`b`) < (@scalar)"

r_df2 = fetch(execute(df2))
pd.testing.assert_frame_equal(r_df2, raw[raw.b < pd.Timestamp("2022-3-20")])


@enter_mode(build=True)
def test_eval_setitem_to_eval(setup):
raw = pd.DataFrame(np.random.rand(100, 10), columns=list("ABCDEFGHIJ"))
raw2 = pd.DataFrame(np.random.rand(100, 5), columns=list("ABCDE"))

# does not support non-eval value setting
df1 = md.DataFrame(raw, chunk_size=10)
df1["K"] = 345
graph = TileableGraph([df1.data])
next(TileableGraphBuilder(graph).build())
records = optimize(graph)
assert records.get_optimization_result(df1.data) is None

df1 = md.DataFrame(raw, chunk_size=10)
df2 = md.DataFrame(raw2, chunk_size=10)
df3 = df1.merge(df2, on="A", suffixes=("", "_"))
df3["K"] = df3["A"] * (1 - df3["B"])
graph = TileableGraph([df3.data])
next(TileableGraphBuilder(graph).build())
records = optimize(graph)
opt_df3 = records.get_optimization_result(df3.data)
assert opt_df3.op.expr == "`K` = (`A`) * ((1) - (`B`))"
assert len(graph) == 4
assert len([n for n in graph if isinstance(n.op, DataFrameEval)]) == 1

r_df3 = raw.merge(raw2, on="A", suffixes=("", "_"))
r_df3["K"] = r_df3["A"] * (1 - r_df3["B"])
pd.testing.assert_frame_equal(df3.execute().fetch(), r_df3)

0 comments on commit 45eeb8d

Please sign in to comment.