From ce831341cfa813694dc6f17238351e79a9d09cc0 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Thu, 8 Aug 2024 13:26:26 +0200 Subject: [PATCH 01/10] feat: dask index workaround --- narwhals/_dask/dataframe.py | 27 +++++++++++-- narwhals/_dask/expr.py | 25 ++++++++++++ narwhals/_dask/utils.py | 15 ++++++++ tests/expr_and_series/drop_nulls_test.py | 49 ++++++++++++++++-------- tests/expr_and_series/sort_test.py | 10 +++-- 5 files changed, 102 insertions(+), 24 deletions(-) diff --git a/narwhals/_dask/dataframe.py b/narwhals/_dask/dataframe.py index 88e1c5b6e..fc8b199b2 100644 --- a/narwhals/_dask/dataframe.py +++ b/narwhals/_dask/dataframe.py @@ -7,8 +7,10 @@ from typing import Sequence from narwhals._dask.utils import parse_exprs_and_named_exprs +from narwhals._dask.utils import set_axis from narwhals._pandas_like.utils import translate_dtype from narwhals.dependencies import get_dask_dataframe +from narwhals.dependencies import get_dask_expr from narwhals.dependencies import get_pandas from narwhals.utils import Implementation from narwhals.utils import flatten @@ -47,8 +49,9 @@ def _from_native_dataframe(self, df: Any) -> Self: def with_columns(self, *exprs: DaskExpr, **named_exprs: DaskExpr) -> Self: df = self._native_dataframe + index = df.index new_series = parse_exprs_and_named_exprs(self, *exprs, **named_exprs) - df = df.assign(**new_series) + df = df.assign(**{k: set_axis(v, index) for k, v in new_series.items()}) return self._from_native_dataframe(df) def collect(self) -> Any: @@ -106,8 +109,26 @@ def select( ) return self._from_native_dataframe(df) - df = self._native_dataframe.assign(**new_series).loc[:, list(new_series.keys())] - return self._from_native_dataframe(df) + pd = get_pandas() + de = get_dask_expr() + + col_order = list(new_series.keys()) + + index = next( # pragma: no cover + s for s in new_series.values() if not isinstance(s, de._collection.Scalar) + ).index + + new_series = { + k: set_axis(v, index) + for k, v in sorted( + new_series.items(), + key=lambda item: isinstance(item[1], de._collection.Scalar), + ) + } + + return self._from_native_dataframe( + dd.from_pandas(pd.DataFrame()).assign(**new_series).loc[:, col_order] + ) def drop_nulls(self) -> Self: return self._from_native_dataframe(self._native_dataframe.dropna()) diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index 27a574a6a..257f8551e 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -432,6 +432,31 @@ def clip( returns_scalar=False, ) + def sort(self: Self, *, descending: bool = False, nulls_last: bool = False) -> Self: + na_position = "last" if nulls_last else "first" + + def func(_input: Any, ascending: bool, na_position: bool) -> Any: # noqa: FBT001 + name = _input.name + + return _input.to_frame(name=name).sort_values( + by=name, ascending=ascending, na_position=na_position + )[name] + + return self._from_call( + func, + "sort", + not descending, + na_position, + returns_scalar=False, + ) + + def drop_nulls(self: Self) -> Self: + return self._from_call( + lambda _input: _input.dropna(), + "drop_nulls", + returns_scalar=False, + ) + @property def str(self: Self) -> DaskExprStringNamespace: return DaskExprStringNamespace(self) diff --git a/narwhals/_dask/utils.py b/narwhals/_dask/utils.py index 2fca67a2d..6844ad87b 100644 --- a/narwhals/_dask/utils.py +++ b/narwhals/_dask/utils.py @@ -2,12 +2,19 @@ from typing import TYPE_CHECKING from typing import Any +from typing import TypeVar from narwhals.dependencies import get_dask_expr if TYPE_CHECKING: + from dask_expr._collection import Index + from dask_expr._collection import Scalar + from dask_expr._collection import Series + from narwhals._dask.dataframe import DaskLazyFrame + T = TypeVar("T", Scalar, Series) + def maybe_evaluate(df: DaskLazyFrame, obj: Any) -> Any: from narwhals._dask.expr import DaskExpr @@ -64,3 +71,11 @@ def parse_exprs_and_named_exprs( raise AssertionError(msg) results[name] = _results[0] return results + + +def set_axis(obj: T, index: Index) -> T: + de = get_dask_expr() + if isinstance(obj, de._collection.Scalar): + return obj + else: + return de._expr.AssignIndex(obj, index) diff --git a/tests/expr_and_series/drop_nulls_test.py b/tests/expr_and_series/drop_nulls_test.py index 26455615d..93548eadc 100644 --- a/tests/expr_and_series/drop_nulls_test.py +++ b/tests/expr_and_series/drop_nulls_test.py @@ -7,27 +7,26 @@ import narwhals as nw from tests.utils import compare_dicts +data = { + "a": [1, 2, None], + "b": [3, 4, 5], + "c": [None, None, None], + "d": [6, None, None], +} -def test_drop_nulls(constructor: Any, request: Any) -> None: - if "dask" in str(constructor): - request.applymarker(pytest.mark.xfail) - data = { - "A": [1, 2, None, 4], - "B": [5, 6, 7, 8], - "C": [None, None, None, None], - "D": [9, 10, 11, 12], - } +def test_drop_nulls(constructor: Any) -> None: df = nw.from_native(constructor(data)) - result_a = df.select(nw.col("A").drop_nulls()) - result_b = df.select(nw.col("B").drop_nulls()) - result_c = df.select(nw.col("C").drop_nulls()) - result_d = df.select(nw.col("D").drop_nulls()) - expected_a = {"A": [1.0, 2.0, 4.0]} - expected_b = {"B": [5, 6, 7, 8]} - expected_c = {"C": []} # type: ignore[var-annotated] - expected_d = {"D": [9, 10, 11, 12]} + result_a = df.select(nw.col("a").drop_nulls()) + result_b = df.select(nw.col("b").drop_nulls()) + result_c = df.select(nw.col("c").drop_nulls()) + result_d = df.select(nw.col("d").drop_nulls()) + + expected_a = {"a": [1.0, 2.0]} + expected_b = {"b": [3, 4, 5]} + expected_c = {"c": []} # type: ignore[var-annotated] + expected_d = {"d": [6]} compare_dicts(result_a, expected_a) compare_dicts(result_b, expected_b) @@ -35,6 +34,22 @@ def test_drop_nulls(constructor: Any, request: Any) -> None: compare_dicts(result_d, expected_d) +def test_drop_nulls_broadcast(constructor: Any, request: Any) -> None: + if "dask" in str(constructor): + request.applymarker(pytest.mark.xfail) + df = nw.from_native(constructor(data)) + result = df.select(nw.col("a").drop_nulls(), nw.col("d").drop_nulls()) + expected = {"a": [1.0, 2.0], "d": [6, 6]} + compare_dicts(result, expected) + + +def test_drop_nulls_invalid(constructor: Any) -> None: + df = nw.from_native(constructor(data)).lazy() + + with pytest.raises(Exception): # noqa: B017, PT011 + df.select(nw.col("a").drop_nulls(), nw.col("b").drop_nulls()).collect() + + def test_drop_nulls_series(constructor_eager: Any) -> None: data = { "A": [1, 2, None, 4], diff --git a/tests/expr_and_series/sort_test.py b/tests/expr_and_series/sort_test.py index f06e21f74..e8b112d71 100644 --- a/tests/expr_and_series/sort_test.py +++ b/tests/expr_and_series/sort_test.py @@ -17,16 +17,18 @@ ], ) def test_sort_expr( - constructor_eager: Any, descending: Any, nulls_last: Any, expected: Any + constructor: Any, descending: Any, nulls_last: Any, expected: Any ) -> None: - df = nw.from_native(constructor_eager(data), eager_only=True) + df = nw.from_native(constructor(data)).lazy() result = nw.to_native( df.select( "a", nw.col("b").sort(descending=descending, nulls_last=nulls_last), - ) + ).collect() ) - assert result.equals(constructor_eager(expected)) + + expected_df = nw.to_native(nw.from_native(constructor(expected)).lazy().collect()) + assert result.equals(expected_df) @pytest.mark.parametrize( From 8677b8f7e3858c061f0d16913ba141da5e2e2ad0 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Thu, 8 Aug 2024 13:30:32 +0200 Subject: [PATCH 02/10] test refactor --- tests/expr_and_series/drop_nulls_test.py | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/tests/expr_and_series/drop_nulls_test.py b/tests/expr_and_series/drop_nulls_test.py index 93548eadc..e3863bb03 100644 --- a/tests/expr_and_series/drop_nulls_test.py +++ b/tests/expr_and_series/drop_nulls_test.py @@ -51,23 +51,16 @@ def test_drop_nulls_invalid(constructor: Any) -> None: def test_drop_nulls_series(constructor_eager: Any) -> None: - data = { - "A": [1, 2, None, 4], - "B": [5, 6, 7, 8], - "C": [None, None, None, None], - "D": [9, 10, 11, 12], - } - df = nw.from_native(constructor_eager(data), eager_only=True) - result_a = df.select(df["A"].drop_nulls()) - result_b = df.select(df["B"].drop_nulls()) - result_c = df.select(df["C"].drop_nulls()) - result_d = df.select(df["D"].drop_nulls()) - expected_a = {"A": [1.0, 2.0, 4.0]} - expected_b = {"B": [5, 6, 7, 8]} - expected_c = {"C": []} # type: ignore[var-annotated] - expected_d = {"D": [9, 10, 11, 12]} + result_a = df.select(df["a"].drop_nulls()) + result_b = df.select(df["b"].drop_nulls()) + result_c = df.select(df["c"].drop_nulls()) + result_d = df.select(df["d"].drop_nulls()) + expected_a = {"a": [1.0, 2.0]} + expected_b = {"b": [3, 4, 5]} + expected_c = {"c": []} # type: ignore[var-annotated] + expected_d = {"d": [6]} compare_dicts(result_a, expected_a) compare_dicts(result_b, expected_b) From 1dae15fdf191af763df59d2edd0b229e8218c8d5 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Sat, 10 Aug 2024 20:10:44 +0200 Subject: [PATCH 03/10] avoid sorting new series --- narwhals/_dask/dataframe.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/narwhals/_dask/dataframe.py b/narwhals/_dask/dataframe.py index fc8b199b2..aaebb8796 100644 --- a/narwhals/_dask/dataframe.py +++ b/narwhals/_dask/dataframe.py @@ -114,20 +114,15 @@ def select( col_order = list(new_series.keys()) - index = next( # pragma: no cover + left_most_series = next( # pragma: no cover s for s in new_series.values() if not isinstance(s, de._collection.Scalar) - ).index - - new_series = { - k: set_axis(v, index) - for k, v in sorted( - new_series.items(), - key=lambda item: isinstance(item[1], de._collection.Scalar), - ) - } + ) + index = left_most_series.index return self._from_native_dataframe( - dd.from_pandas(pd.DataFrame()).assign(**new_series).loc[:, col_order] + left_most_series.to_frame() + .assign(**{k: set_axis(v, index) for k, v in new_series.items()}) + .loc[:, col_order] ) def drop_nulls(self) -> Self: From a80c94a0322069e59cf397c663655670b02e839a Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Tue, 27 Aug 2024 09:47:20 +0200 Subject: [PATCH 04/10] add modifies_index flag --- narwhals/_dask/dataframe.py | 17 ++++ narwhals/_dask/expr.py | 134 +++++++++++++++++++++++++---- narwhals/_dask/namespace.py | 12 ++- narwhals/_dask/selectors.py | 6 ++ tests/expr_and_series/head_test.py | 2 - tests/expr_and_series/tail_test.py | 2 - tests/frame/with_columns_test.py | 16 ++++ 7 files changed, 169 insertions(+), 20 deletions(-) diff --git a/narwhals/_dask/dataframe.py b/narwhals/_dask/dataframe.py index 40caab77b..cf0064d75 100644 --- a/narwhals/_dask/dataframe.py +++ b/narwhals/_dask/dataframe.py @@ -49,6 +49,15 @@ def _from_native_frame(self, df: Any) -> Self: return self.__class__(df, backend_version=self._backend_version) def with_columns(self, *exprs: DaskExpr, **named_exprs: DaskExpr) -> Self: + n_modifies_index = sum( + getattr(e, "_modifies_index", 0) + for e in list(exprs) + list(named_exprs.values()) + ) + + if n_modifies_index > 0: + msg = "Expressions that modify the index are not supported in `with_columns`." + raise ValueError(msg) + new_series = parse_exprs_and_named_exprs(self, *exprs, **named_exprs) return self._from_native_frame(self._native_frame.assign(**new_series)) @@ -101,6 +110,14 @@ def select( # This is a simple slice => fastpath! return self._from_native_frame(self._native_frame.loc[:, exprs]) + n_modifies_index = sum( + getattr(e, "_modifies_index", 0) + for e in list(exprs) + list(named_exprs.values()) + ) + if n_modifies_index > 1: + msg = "Found multiple expressions that modify the index" + raise ValueError(msg) + new_series = parse_exprs_and_named_exprs(self, *exprs, **named_exprs) if not new_series: diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index b188d63a9..25c72607c 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -34,6 +34,7 @@ def __init__( # Whether the expression is a length-1 Series resulting from # a reduction, such as `nw.col('a').sum()` returns_scalar: bool, + modifies_index: bool, backend_version: tuple[int, ...], ) -> None: self._call = call @@ -42,6 +43,7 @@ def __init__( self._root_names = root_names self._output_names = output_names self._returns_scalar = returns_scalar + self._modifies_index = modifies_index self._backend_version = backend_version def __narwhals_expr__(self) -> None: ... @@ -68,6 +70,7 @@ def func(df: DaskLazyFrame) -> list[Any]: root_names=list(column_names), output_names=list(column_names), returns_scalar=False, + modifies_index=False, backend_version=backend_version, ) @@ -78,6 +81,7 @@ def _from_call( expr_name: str, *args: Any, returns_scalar: bool, + modifies_index: bool, **kwargs: Any, ) -> Self: def func(df: DaskLazyFrame) -> list[Any]: @@ -127,6 +131,7 @@ def func(df: DaskLazyFrame) -> list[Any]: root_names=root_names, output_names=output_names, returns_scalar=self._returns_scalar or returns_scalar, + modifies_index=self._modifies_index or modifies_index, backend_version=self._backend_version, ) @@ -142,6 +147,7 @@ def func(df: DaskLazyFrame) -> list[Any]: root_names=self._root_names, output_names=[name], returns_scalar=self._returns_scalar, + modifies_index=self._modifies_index, backend_version=self._backend_version, ) @@ -151,6 +157,7 @@ def __add__(self, other: Any) -> Self: "__add__", other, returns_scalar=False, + modifies_index=False, ) def __radd__(self, other: Any) -> Self: @@ -159,6 +166,7 @@ def __radd__(self, other: Any) -> Self: "__radd__", other, returns_scalar=False, + modifies_index=False, ) def __sub__(self, other: Any) -> Self: @@ -167,6 +175,7 @@ def __sub__(self, other: Any) -> Self: "__sub__", other, returns_scalar=False, + modifies_index=False, ) def __rsub__(self, other: Any) -> Self: @@ -175,6 +184,7 @@ def __rsub__(self, other: Any) -> Self: "__rsub__", other, returns_scalar=False, + modifies_index=False, ) def __mul__(self, other: Any) -> Self: @@ -183,6 +193,7 @@ def __mul__(self, other: Any) -> Self: "__mul__", other, returns_scalar=False, + modifies_index=False, ) def __rmul__(self, other: Any) -> Self: @@ -191,6 +202,7 @@ def __rmul__(self, other: Any) -> Self: "__rmul__", other, returns_scalar=False, + modifies_index=False, ) def __truediv__(self, other: Any) -> Self: @@ -199,6 +211,7 @@ def __truediv__(self, other: Any) -> Self: "__truediv__", other, returns_scalar=False, + modifies_index=False, ) def __rtruediv__(self, other: Any) -> Self: @@ -207,6 +220,7 @@ def __rtruediv__(self, other: Any) -> Self: "__rtruediv__", other, returns_scalar=False, + modifies_index=False, ) def __floordiv__(self, other: Any) -> Self: @@ -215,6 +229,7 @@ def __floordiv__(self, other: Any) -> Self: "__floordiv__", other, returns_scalar=False, + modifies_index=False, ) def __rfloordiv__(self, other: Any) -> Self: @@ -223,6 +238,7 @@ def __rfloordiv__(self, other: Any) -> Self: "__rfloordiv__", other, returns_scalar=False, + modifies_index=False, ) def __pow__(self, other: Any) -> Self: @@ -231,6 +247,7 @@ def __pow__(self, other: Any) -> Self: "__pow__", other, returns_scalar=False, + modifies_index=False, ) def __rpow__(self, other: Any) -> Self: @@ -239,6 +256,7 @@ def __rpow__(self, other: Any) -> Self: "__rpow__", other, returns_scalar=False, + modifies_index=False, ) def __mod__(self, other: Any) -> Self: @@ -247,6 +265,7 @@ def __mod__(self, other: Any) -> Self: "__mod__", other, returns_scalar=False, + modifies_index=False, ) def __rmod__(self, other: Any) -> Self: @@ -255,6 +274,7 @@ def __rmod__(self, other: Any) -> Self: "__rmod__", other, returns_scalar=False, + modifies_index=False, ) def __eq__(self, other: DaskExpr) -> Self: # type: ignore[override] @@ -263,6 +283,7 @@ def __eq__(self, other: DaskExpr) -> Self: # type: ignore[override] "__eq__", other, returns_scalar=False, + modifies_index=False, ) def __ne__(self, other: DaskExpr) -> Self: # type: ignore[override] @@ -271,6 +292,7 @@ def __ne__(self, other: DaskExpr) -> Self: # type: ignore[override] "__ne__", other, returns_scalar=False, + modifies_index=False, ) def __ge__(self, other: DaskExpr) -> Self: @@ -279,6 +301,7 @@ def __ge__(self, other: DaskExpr) -> Self: "__ge__", other, returns_scalar=False, + modifies_index=False, ) def __gt__(self, other: DaskExpr) -> Self: @@ -287,6 +310,7 @@ def __gt__(self, other: DaskExpr) -> Self: "__gt__", other, returns_scalar=False, + modifies_index=False, ) def __le__(self, other: DaskExpr) -> Self: @@ -295,6 +319,7 @@ def __le__(self, other: DaskExpr) -> Self: "__le__", other, returns_scalar=False, + modifies_index=False, ) def __lt__(self, other: DaskExpr) -> Self: @@ -303,6 +328,7 @@ def __lt__(self, other: DaskExpr) -> Self: "__lt__", other, returns_scalar=False, + modifies_index=False, ) def __and__(self, other: DaskExpr) -> Self: @@ -311,6 +337,7 @@ def __and__(self, other: DaskExpr) -> Self: "__and__", other, returns_scalar=False, + modifies_index=False, ) def __rand__(self, other: DaskExpr) -> Self: # pragma: no cover @@ -319,6 +346,7 @@ def __rand__(self, other: DaskExpr) -> Self: # pragma: no cover "__rand__", other, returns_scalar=False, + modifies_index=False, ) def __or__(self, other: DaskExpr) -> Self: @@ -327,6 +355,7 @@ def __or__(self, other: DaskExpr) -> Self: "__or__", other, returns_scalar=False, + modifies_index=False, ) def __ror__(self, other: DaskExpr) -> Self: # pragma: no cover @@ -335,6 +364,7 @@ def __ror__(self, other: DaskExpr) -> Self: # pragma: no cover "__ror__", other, returns_scalar=False, + modifies_index=False, ) def __invert__(self: Self) -> Self: @@ -342,6 +372,7 @@ def __invert__(self: Self) -> Self: lambda _input: _input.__invert__(), "__invert__", returns_scalar=False, + modifies_index=False, ) def mean(self) -> Self: @@ -349,6 +380,7 @@ def mean(self) -> Self: lambda _input: _input.mean(), "mean", returns_scalar=True, + modifies_index=False, ) def min(self) -> Self: @@ -356,6 +388,7 @@ def min(self) -> Self: lambda _input: _input.min(), "min", returns_scalar=True, + modifies_index=False, ) def max(self) -> Self: @@ -363,6 +396,7 @@ def max(self) -> Self: lambda _input: _input.max(), "max", returns_scalar=True, + modifies_index=False, ) def std(self, ddof: int = 1) -> Self: @@ -371,6 +405,7 @@ def std(self, ddof: int = 1) -> Self: "std", ddof, returns_scalar=True, + modifies_index=False, ) def shift(self, n: int) -> Self: @@ -379,6 +414,7 @@ def shift(self, n: int) -> Self: "shift", n, returns_scalar=False, + modifies_index=False, ) def cum_sum(self) -> Self: @@ -386,6 +422,7 @@ def cum_sum(self) -> Self: lambda _input: _input.cumsum(), "cum_sum", returns_scalar=False, + modifies_index=False, ) def is_between( @@ -407,6 +444,7 @@ def is_between( upper_bound, closed, returns_scalar=False, + modifies_index=False, ) def sum(self) -> Self: @@ -414,6 +452,7 @@ def sum(self) -> Self: lambda _input: _input.sum(), "sum", returns_scalar=True, + modifies_index=False, ) def count(self) -> Self: @@ -421,6 +460,7 @@ def count(self) -> Self: lambda _input: _input.count(), "count", returns_scalar=True, + modifies_index=False, ) def round(self, decimals: int) -> Self: @@ -429,18 +469,15 @@ def round(self, decimals: int) -> Self: "round", decimals, returns_scalar=False, + modifies_index=False, ) - def head(self) -> NoReturn: - # We can't (yet?) allow methods which modify the index - msg = "`Expr.head` is not supported for the Dask backend. Please use `LazyFrame.head` instead." - raise NotImplementedError(msg) - def abs(self) -> Self: return self._from_call( lambda _input: _input.abs(), "abs", returns_scalar=False, + modifies_index=False, ) def all(self) -> Self: @@ -450,6 +487,7 @@ def all(self) -> Self: ), "all", returns_scalar=True, + modifies_index=False, ) def any(self) -> Self: @@ -457,6 +495,7 @@ def any(self) -> Self: lambda _input: _input.any(axis=0, skipna=True, split_every=False), "any", returns_scalar=True, + modifies_index=False, ) def fill_null(self, value: Any) -> DaskExpr: @@ -465,6 +504,7 @@ def fill_null(self, value: Any) -> DaskExpr: "fillna", value, returns_scalar=False, + modifies_index=False, ) def clip( @@ -478,6 +518,7 @@ def clip( lower_bound, upper_bound, returns_scalar=False, + modifies_index=False, ) def diff(self: Self) -> Self: @@ -485,6 +526,7 @@ def diff(self: Self) -> Self: lambda _input: _input.diff(), "diff", returns_scalar=False, + modifies_index=False, ) def n_unique(self: Self) -> Self: @@ -492,6 +534,7 @@ def n_unique(self: Self) -> Self: lambda _input: _input.nunique(dropna=False), "n_unique", returns_scalar=True, + modifies_index=False, ) def is_null(self: Self) -> Self: @@ -499,6 +542,7 @@ def is_null(self: Self) -> Self: lambda _input: _input.isna(), "is_null", returns_scalar=False, + modifies_index=False, ) def len(self: Self) -> Self: @@ -506,6 +550,7 @@ def len(self: Self) -> Self: lambda _input: _input.size, "len", returns_scalar=True, + modifies_index=False, ) def quantile( @@ -519,6 +564,7 @@ def quantile( "quantile", quantile, returns_scalar=True, + modifies_index=False, ) else: msg = "`higher`, `lower`, `midpoint`, `nearest` - interpolation methods are not supported by Dask. Please use `linear` instead." @@ -539,6 +585,7 @@ def func(_input: Any) -> Any: func, "is_first_distinct", returns_scalar=False, + modifies_index=False, ) def is_last_distinct(self: Self) -> Self: @@ -554,6 +601,7 @@ def func(_input: Any) -> Any: func, "is_last_distinct", returns_scalar=False, + modifies_index=False, ) def is_duplicated(self: Self) -> Self: @@ -567,6 +615,7 @@ def func(_input: Any) -> Any: func, "is_duplicated", returns_scalar=False, + modifies_index=False, ) def is_unique(self: Self) -> Self: @@ -580,6 +629,7 @@ def func(_input: Any) -> Any: func, "is_unique", returns_scalar=False, + modifies_index=False, ) def is_in(self: Self, other: Any) -> Self: @@ -588,6 +638,7 @@ def is_in(self: Self, other: Any) -> Self: "is_in", other, returns_scalar=False, + modifies_index=False, ) def null_count(self: Self) -> Self: @@ -595,18 +646,9 @@ def null_count(self: Self) -> Self: lambda _input: _input.isna().sum(), "null_count", returns_scalar=True, + modifies_index=False, ) - def tail(self: Self) -> NoReturn: - # We can't (yet?) allow methods which modify the index - msg = "`Expr.tail` is not supported for the Dask backend. Please use `LazyFrame.tail` instead." - raise NotImplementedError(msg) - - def gather_every(self: Self, n: int, offset: int = 0) -> NoReturn: - # We can't (yet?) allow methods which modify the index - msg = "`Expr.gather_every` is not supported for the Dask backend. Please use `LazyFrame.gather_every` instead." - raise NotImplementedError(msg) - def over(self: Self, keys: list[str]) -> Self: def func(df: DaskLazyFrame) -> list[Any]: if self._output_names is None: @@ -631,6 +673,7 @@ def func(df: DaskLazyFrame) -> list[Any]: root_names=self._root_names, output_names=self._output_names, returns_scalar=False, + modifies_index=False, backend_version=self._backend_version, ) @@ -653,15 +696,42 @@ def func(_input: Any, ascending: bool, na_position: bool) -> Any: # noqa: FBT00 not descending, na_position, returns_scalar=False, + modifies_index=False, ) + # Index modifiers + def drop_nulls(self: Self) -> Self: return self._from_call( lambda _input: _input.dropna(), "drop_nulls", returns_scalar=False, + modifies_index=True, ) + def head(self: Self, n: int) -> Self: + return self._from_call( + lambda _input, _n: _input.head(_n, compute=False), + "head", + n, + returns_scalar=False, + modifies_index=True, + ) + + def tail(self: Self, n: int) -> Self: + return self._from_call( + lambda _input, _n: _input.tail(_n, compute=False), + "tail", + n, + returns_scalar=False, + modifies_index=True, + ) + + def gather_every(self: Self, n: int, offset: int = 0) -> NoReturn: + # We can't (yet?) allow methods which modify the index + msg = "`Expr.gather_every` is not supported for the Dask backend. Please use `LazyFrame.gather_every` instead." + raise NotImplementedError(msg) + @property def str(self: Self) -> DaskExprStringNamespace: return DaskExprStringNamespace(self) @@ -687,6 +757,7 @@ def func(_input: Any, dtype: DType | type[DType]) -> Any: "cast", dtype, returns_scalar=False, + modifies_index=False, ) @@ -712,6 +783,7 @@ def replace( literal, n, returns_scalar=False, + modifies_index=False, ) def replace_all( @@ -730,6 +802,7 @@ def replace_all( value, literal, returns_scalar=False, + modifies_index=False, ) def strip_chars(self, characters: str | None = None) -> DaskExpr: @@ -738,6 +811,7 @@ def strip_chars(self, characters: str | None = None) -> DaskExpr: "strip", characters, returns_scalar=False, + modifies_index=False, ) def starts_with(self, prefix: str) -> DaskExpr: @@ -746,6 +820,7 @@ def starts_with(self, prefix: str) -> DaskExpr: "starts_with", prefix, returns_scalar=False, + modifies_index=False, ) def ends_with(self, suffix: str) -> DaskExpr: @@ -754,6 +829,7 @@ def ends_with(self, suffix: str) -> DaskExpr: "ends_with", suffix, returns_scalar=False, + modifies_index=False, ) def contains(self, pattern: str, *, literal: bool = False) -> DaskExpr: @@ -763,6 +839,7 @@ def contains(self, pattern: str, *, literal: bool = False) -> DaskExpr: pattern, not literal, returns_scalar=False, + modifies_index=False, ) def slice(self, offset: int, length: int | None = None) -> DaskExpr: @@ -773,6 +850,7 @@ def slice(self, offset: int, length: int | None = None) -> DaskExpr: offset, stop, returns_scalar=False, + modifies_index=False, ) def to_datetime(self, format: str | None = None) -> DaskExpr: # noqa: A002 @@ -781,6 +859,7 @@ def to_datetime(self, format: str | None = None) -> DaskExpr: # noqa: A002 "to_datetime", format, returns_scalar=False, + modifies_index=False, ) def to_uppercase(self) -> DaskExpr: @@ -788,6 +867,7 @@ def to_uppercase(self) -> DaskExpr: lambda _input: _input.str.upper(), "to_uppercase", returns_scalar=False, + modifies_index=False, ) def to_lowercase(self) -> DaskExpr: @@ -795,6 +875,7 @@ def to_lowercase(self) -> DaskExpr: lambda _input: _input.str.lower(), "to_lowercase", returns_scalar=False, + modifies_index=False, ) @@ -807,6 +888,7 @@ def date(self) -> DaskExpr: lambda _input: _input.dt.date, "date", returns_scalar=False, + modifies_index=False, ) def year(self) -> DaskExpr: @@ -814,6 +896,7 @@ def year(self) -> DaskExpr: lambda _input: _input.dt.year, "year", returns_scalar=False, + modifies_index=False, ) def month(self) -> DaskExpr: @@ -821,6 +904,7 @@ def month(self) -> DaskExpr: lambda _input: _input.dt.month, "month", returns_scalar=False, + modifies_index=False, ) def day(self) -> DaskExpr: @@ -828,6 +912,7 @@ def day(self) -> DaskExpr: lambda _input: _input.dt.day, "day", returns_scalar=False, + modifies_index=False, ) def hour(self) -> DaskExpr: @@ -835,6 +920,7 @@ def hour(self) -> DaskExpr: lambda _input: _input.dt.hour, "hour", returns_scalar=False, + modifies_index=False, ) def minute(self) -> DaskExpr: @@ -842,6 +928,7 @@ def minute(self) -> DaskExpr: lambda _input: _input.dt.minute, "minute", returns_scalar=False, + modifies_index=False, ) def second(self) -> DaskExpr: @@ -849,6 +936,7 @@ def second(self) -> DaskExpr: lambda _input: _input.dt.second, "second", returns_scalar=False, + modifies_index=False, ) def millisecond(self) -> DaskExpr: @@ -856,6 +944,7 @@ def millisecond(self) -> DaskExpr: lambda _input: _input.dt.microsecond // 1000, "millisecond", returns_scalar=False, + modifies_index=False, ) def microsecond(self) -> DaskExpr: @@ -863,6 +952,7 @@ def microsecond(self) -> DaskExpr: lambda _input: _input.dt.microsecond, "microsecond", returns_scalar=False, + modifies_index=False, ) def nanosecond(self) -> DaskExpr: @@ -870,6 +960,7 @@ def nanosecond(self) -> DaskExpr: lambda _input: _input.dt.microsecond * 1000 + _input.dt.nanosecond, "nanosecond", returns_scalar=False, + modifies_index=False, ) def ordinal_day(self) -> DaskExpr: @@ -877,6 +968,7 @@ def ordinal_day(self) -> DaskExpr: lambda _input: _input.dt.dayofyear, "ordinal_day", returns_scalar=False, + modifies_index=False, ) def to_string(self, format: str) -> DaskExpr: # noqa: A002 @@ -885,6 +977,7 @@ def to_string(self, format: str) -> DaskExpr: # noqa: A002 "strftime", format.replace("%.f", ".%f"), returns_scalar=False, + modifies_index=False, ) def total_minutes(self) -> DaskExpr: @@ -892,6 +985,7 @@ def total_minutes(self) -> DaskExpr: lambda _input: _input.dt.total_seconds() // 60, "total_minutes", returns_scalar=False, + modifies_index=False, ) def total_seconds(self) -> DaskExpr: @@ -899,6 +993,7 @@ def total_seconds(self) -> DaskExpr: lambda _input: _input.dt.total_seconds() // 1, "total_seconds", returns_scalar=False, + modifies_index=False, ) def total_milliseconds(self) -> DaskExpr: @@ -906,6 +1001,7 @@ def total_milliseconds(self) -> DaskExpr: lambda _input: _input.dt.total_seconds() * 1000 // 1, "total_milliseconds", returns_scalar=False, + modifies_index=False, ) def total_microseconds(self) -> DaskExpr: @@ -913,6 +1009,7 @@ def total_microseconds(self) -> DaskExpr: lambda _input: _input.dt.total_seconds() * 1_000_000 // 1, "total_microseconds", returns_scalar=False, + modifies_index=False, ) def total_nanoseconds(self) -> DaskExpr: @@ -920,6 +1017,7 @@ def total_nanoseconds(self) -> DaskExpr: lambda _input: _input.dt.total_seconds() * 1_000_000_000 // 1, "total_nanoseconds", returns_scalar=False, + modifies_index=False, ) @@ -948,6 +1046,7 @@ def keep(self: Self) -> DaskExpr: root_names=root_names, output_names=root_names, returns_scalar=self._expr._returns_scalar, + modifies_index=self._expr._modifies_index, backend_version=self._expr._backend_version, ) @@ -974,6 +1073,7 @@ def map(self: Self, function: Callable[[str], str]) -> DaskExpr: root_names=root_names, output_names=output_names, returns_scalar=self._expr._returns_scalar, + modifies_index=self._expr._modifies_index, backend_version=self._expr._backend_version, ) @@ -998,6 +1098,7 @@ def prefix(self: Self, prefix: str) -> DaskExpr: root_names=root_names, output_names=output_names, returns_scalar=self._expr._returns_scalar, + modifies_index=self._expr._modifies_index, backend_version=self._expr._backend_version, ) @@ -1023,6 +1124,7 @@ def suffix(self: Self, suffix: str) -> DaskExpr: root_names=root_names, output_names=output_names, returns_scalar=self._expr._returns_scalar, + modifies_index=self._expr._modifies_index, backend_version=self._expr._backend_version, ) @@ -1048,6 +1150,7 @@ def to_lowercase(self: Self) -> DaskExpr: root_names=root_names, output_names=output_names, returns_scalar=self._expr._returns_scalar, + modifies_index=self._expr._modifies_index, backend_version=self._expr._backend_version, ) @@ -1073,5 +1176,6 @@ def to_uppercase(self: Self) -> DaskExpr: root_names=root_names, output_names=output_names, returns_scalar=self._expr._returns_scalar, + modifies_index=self._expr._modifies_index, backend_version=self._expr._backend_version, ) diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index e6019b509..f1d310478 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -59,6 +59,7 @@ def func(df: DaskLazyFrame) -> list[Any]: root_names=None, output_names=None, returns_scalar=False, + modifies_index=False, backend_version=self._backend_version, ) @@ -78,6 +79,7 @@ def lit(self, value: Any, dtype: dtypes.DType | None) -> DaskExpr: root_names=None, output_names=["lit"], returns_scalar=False, + modifies_index=False, backend_version=self._backend_version, ) @@ -127,6 +129,7 @@ def func(df: DaskLazyFrame) -> list[Any]: root_names=None, output_names=["len"], returns_scalar=True, + modifies_index=False, backend_version=self._backend_version, ) @@ -187,7 +190,9 @@ def when( msg = "at least one predicate needs to be provided" raise TypeError(msg) - return DaskWhen(condition, self._backend_version, returns_scalar=False) + return DaskWhen( + condition, self._backend_version, returns_scalar=False, modifies_index=False + ) class DaskWhen: @@ -199,12 +204,14 @@ def __init__( otherwise_value: Any = None, *, returns_scalar: bool, + modifies_index: bool, ) -> None: self._backend_version = backend_version self._condition = condition self._then_value = then_value self._otherwise_value = otherwise_value self._returns_scalar = returns_scalar + self._modifies_index = modifies_index def __call__(self, df: DaskLazyFrame) -> list[Any]: from narwhals._dask.namespace import DaskNamespace @@ -247,6 +254,7 @@ def then(self, value: DaskExpr | Any) -> DaskThen: output_names=None, returns_scalar=self._returns_scalar, backend_version=self._backend_version, + modifies_index=self._modifies_index, ) @@ -260,6 +268,7 @@ def __init__( root_names: list[str] | None, output_names: list[str] | None, returns_scalar: bool, + modifies_index: bool, backend_version: tuple[int, ...], ) -> None: self._backend_version = backend_version @@ -270,6 +279,7 @@ def __init__( self._root_names = root_names self._output_names = output_names self._returns_scalar = returns_scalar + self._modifies_index = modifies_index def otherwise(self, value: DaskExpr | Any) -> DaskExpr: # type ignore because we are setting the `_call` attribute to a diff --git a/narwhals/_dask/selectors.py b/narwhals/_dask/selectors.py index 073b3abd8..ae6a23d38 100644 --- a/narwhals/_dask/selectors.py +++ b/narwhals/_dask/selectors.py @@ -32,6 +32,7 @@ def func(df: DaskLazyFrame) -> list[Any]: output_names=None, backend_version=self._backend_version, returns_scalar=False, + modifies_index=False, ) def numeric(self: Self) -> DaskSelector: @@ -71,6 +72,7 @@ def func(df: DaskLazyFrame) -> list[Any]: output_names=None, backend_version=self._backend_version, returns_scalar=False, + modifies_index=False, ) @@ -92,6 +94,7 @@ def _to_expr(self: Self) -> DaskExpr: root_names=self._root_names, output_names=self._output_names, backend_version=self._backend_version, + modifies_index=self._modifies_index, returns_scalar=self._returns_scalar, ) @@ -110,6 +113,7 @@ def call(df: DaskLazyFrame) -> list[Any]: root_names=None, output_names=None, backend_version=self._backend_version, + modifies_index=self._modifies_index, returns_scalar=self._returns_scalar, ) else: @@ -132,6 +136,7 @@ def call(df: DaskLazyFrame) -> list[Any]: root_names=None, output_names=None, backend_version=self._backend_version, + modifies_index=self._modifies_index, returns_scalar=self._returns_scalar, ) else: @@ -152,6 +157,7 @@ def call(df: DaskLazyFrame) -> list[Any]: root_names=None, output_names=None, backend_version=self._backend_version, + modifies_index=self._modifies_index, returns_scalar=self._returns_scalar, ) else: diff --git a/tests/expr_and_series/head_test.py b/tests/expr_and_series/head_test.py index ef2ed1bf1..a27856d79 100644 --- a/tests/expr_and_series/head_test.py +++ b/tests/expr_and_series/head_test.py @@ -10,8 +10,6 @@ @pytest.mark.parametrize("n", [2, -1]) def test_head(constructor: Any, n: int, request: Any) -> None: - if "dask" in str(constructor): - request.applymarker(pytest.mark.xfail) if "polars" in str(constructor) and n < 0: request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor({"a": [1, 2, 3]})) diff --git a/tests/expr_and_series/tail_test.py b/tests/expr_and_series/tail_test.py index be17ffb4e..2b0a8af95 100644 --- a/tests/expr_and_series/tail_test.py +++ b/tests/expr_and_series/tail_test.py @@ -10,8 +10,6 @@ @pytest.mark.parametrize("n", [2, -1]) def test_head(constructor: Any, n: int, request: Any) -> None: - if "dask" in str(constructor): - request.applymarker(pytest.mark.xfail) if "polars" in str(constructor) and n < 0: request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor({"a": [1, 2, 3]})) diff --git a/tests/frame/with_columns_test.py b/tests/frame/with_columns_test.py index 864e689e8..303137a08 100644 --- a/tests/frame/with_columns_test.py +++ b/tests/frame/with_columns_test.py @@ -2,6 +2,7 @@ import numpy as np import pandas as pd +import pytest import narwhals.stable.v1 as nw from tests.utils import compare_dicts @@ -41,3 +42,18 @@ def test_with_columns_order_single_row(constructor: Any) -> None: assert result.collect_schema().names() == ["a", "b", "z", "d"] expected = {"a": [2], "b": [4], "z": [7.0], "d": [0]} compare_dicts(result, expected) + + +def test_dask_with_columns_modifies_index() -> None: + pytest.importorskip("dask") + pytest.importorskip("dask_expr", exc_type=ImportError) + import dask.dataframe as dd + + data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9], "i": [0, 1, 2]} + + df = nw.from_native(dd.from_dict(data, npartitions=2)) + + with pytest.raises( + ValueError, match="Expressions that modify the index are not supported" + ): + df.with_columns(nw.col("b").head(1)) From d4ca60abeb683567ec855bdbcc5a425b5a46f829 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Tue, 27 Aug 2024 10:25:51 +0200 Subject: [PATCH 05/10] deal with reductions --- narwhals/_dask/dataframe.py | 8 +++----- narwhals/_dask/expr.py | 3 ++- tests/frame/select_test.py | 26 ++++++++++++++++++++++++++ tests/frame/with_columns_test.py | 7 ++----- 4 files changed, 33 insertions(+), 11 deletions(-) diff --git a/narwhals/_dask/dataframe.py b/narwhals/_dask/dataframe.py index cf0064d75..c0666d479 100644 --- a/narwhals/_dask/dataframe.py +++ b/narwhals/_dask/dataframe.py @@ -110,11 +110,9 @@ def select( # This is a simple slice => fastpath! return self._from_native_frame(self._native_frame.loc[:, exprs]) - n_modifies_index = sum( - getattr(e, "_modifies_index", 0) - for e in list(exprs) + list(named_exprs.values()) - ) - if n_modifies_index > 1: + all_exprs = list(exprs) + list(named_exprs.values()) + n_modifies_index = sum(getattr(e, "_modifies_index", 0) for e in all_exprs) + if len(all_exprs) > 1 and n_modifies_index > 1: msg = "Found multiple expressions that modify the index" raise ValueError(msg) diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index 25c72607c..8266add7c 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -131,7 +131,8 @@ def func(df: DaskLazyFrame) -> list[Any]: root_names=root_names, output_names=output_names, returns_scalar=self._returns_scalar or returns_scalar, - modifies_index=self._modifies_index or modifies_index, + modifies_index=(self._modifies_index or modifies_index) + and not (self._returns_scalar or returns_scalar), backend_version=self._backend_version, ) diff --git a/tests/frame/select_test.py b/tests/frame/select_test.py index 450e91066..e0fe496ce 100644 --- a/tests/frame/select_test.py +++ b/tests/frame/select_test.py @@ -31,3 +31,29 @@ def test_non_string_select_invalid() -> None: df = nw.from_native(pd.DataFrame({0: [1, 2], "b": [3, 4]})) with pytest.raises(TypeError, match="\n\nHint: if you were trying to select"): nw.to_native(df.select(0)) # type: ignore[arg-type] + + +def test_dask_select_reduction_and_modify_index() -> None: + pytest.importorskip("dask") + pytest.importorskip("dask_expr", exc_type=ImportError) + import dask.dataframe as dd + + data = {"a": [1, 3, 2], "b": [4, 4.0, 6], "z": [7.0, 8, 9]} + df = nw.from_native(dd.from_dict(data, npartitions=1)) + + result = df.select( + nw.col("a").head(2).sum(), + nw.col("b").tail(2).mean(), + nw.col("z").head(2), + ) + expected = {"a": [4, 4], "b": [5, 5], "z": [7.0, 8]} + compare_dicts(result, expected) + + # all reductions + result = df.select( + nw.col("a").head(2).sum(), + nw.col("b").tail(2).mean(), + nw.col("z").max(), + ) + expected = {"a": [4], "b": [5], "z": [9]} + compare_dicts(result, expected) diff --git a/tests/frame/with_columns_test.py b/tests/frame/with_columns_test.py index 303137a08..72160c27c 100644 --- a/tests/frame/with_columns_test.py +++ b/tests/frame/with_columns_test.py @@ -49,11 +49,8 @@ def test_dask_with_columns_modifies_index() -> None: pytest.importorskip("dask_expr", exc_type=ImportError) import dask.dataframe as dd - data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9], "i": [0, 1, 2]} - - df = nw.from_native(dd.from_dict(data, npartitions=2)) - + df = nw.from_native(dd.from_dict({"a": [1, 3, 2]}, npartitions=2)) with pytest.raises( ValueError, match="Expressions that modify the index are not supported" ): - df.with_columns(nw.col("b").head(1)) + df.with_columns(nw.col("a").head(1)) From 941a675dffa1a238f60209ace974f1ec41dec76b Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Fri, 30 Aug 2024 17:16:34 +0200 Subject: [PATCH 06/10] head with npartitions --- narwhals/_dask/expr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index 8266add7c..02afedc37 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -712,7 +712,7 @@ def drop_nulls(self: Self) -> Self: def head(self: Self, n: int) -> Self: return self._from_call( - lambda _input, _n: _input.head(_n, compute=False), + lambda _input, _n: _input.head(_n, npartitions=-1, compute=False), "head", n, returns_scalar=False, From 994bb07c8d577aae1f011ed0e06a8ff31181cf74 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Mon, 16 Sep 2024 13:19:14 +0200 Subject: [PATCH 07/10] almost there with tests --- tests/expr_and_series/sort_test.py | 44 ++++++++++++------------------ 1 file changed, 17 insertions(+), 27 deletions(-) diff --git a/tests/expr_and_series/sort_test.py b/tests/expr_and_series/sort_test.py index 7a977369a..e6ffe5fe9 100644 --- a/tests/expr_and_series/sort_test.py +++ b/tests/expr_and_series/sort_test.py @@ -3,6 +3,7 @@ import pytest import narwhals.stable.v1 as nw +from tests.utils import compare_dicts data = {"a": [0, 0, 2, -1], "b": [1, 3, 2, None]} @@ -10,34 +11,27 @@ @pytest.mark.parametrize( ("descending", "nulls_last", "expected"), [ - (True, True, {"b": [3, 2, 1, None]}), - (True, False, {"b": [None, 3, 2, 1]}), - (False, True, {"b": [1, 2, 3, None]}), - (False, False, {"b": [None, 1, 2, 3]}), + (True, True, {"b": [3, 2, 1, float("nan")]}), + (True, False, {"b": [float("nan"), 3, 2, 1]}), + (False, True, {"b": [1, 2, 3, float("nan")]}), + (False, False, {"b": [float("nan"), 1, 2, 3]}), ], ) def test_sort_single_expr( constructor: Any, descending: Any, nulls_last: Any, expected: Any ) -> None: - df = nw.from_native(constructor(data)).lazy() - result = nw.to_native( - df.select( - nw.col("b").sort(descending=descending, nulls_last=nulls_last), - ).collect() - ) - - expected_df = nw.to_native(nw.from_native(constructor(expected)).lazy().collect()) - result = nw.maybe_align_index(result, expected_df) - assert result.equals(expected_df) + df = nw.from_native(constructor(data)) + result = df.select(nw.col("b").sort(descending=descending, nulls_last=nulls_last)) + compare_dicts(result, expected) @pytest.mark.parametrize( ("descending", "nulls_last", "expected"), [ - (True, True, {"a": [0, 0, 2, -1], "b": [3, 2, 1, None]}), - (True, False, {"a": [0, 0, 2, -1], "b": [None, 3, 2, 1]}), - (False, True, {"a": [0, 0, 2, -1], "b": [1, 2, 3, None]}), - (False, False, {"a": [0, 0, 2, -1], "b": [None, 1, 2, 3]}), + (True, True, {"a": [0, 0, 2, -1], "b": [3, 2, 1, float("nan")]}), + (True, False, {"a": [0, 0, 2, -1], "b": [float("nan"), 3, 2, 1]}), + (False, True, {"a": [0, 0, 2, -1], "b": [1, 2, 3, float("nan")]}), + (False, False, {"a": [0, 0, 2, -1], "b": [float("nan"), 1, 2, 3]}), ], ) def test_sort_multiple_expr( @@ -46,16 +40,12 @@ def test_sort_multiple_expr( if "dask" in str(constructor): request.applymarker(pytest.mark.xfail) - df = nw.from_native(constructor(data)).lazy() - result = nw.to_native( - df.select( - "a", - nw.col("b").sort(descending=descending, nulls_last=nulls_last), - ).collect() + df = nw.from_native(constructor(data)) + result = df.select( + "a", + nw.col("b").sort(descending=descending, nulls_last=nulls_last), ) - - expected_df = nw.to_native(nw.from_native(constructor(expected)).lazy().collect()) - assert result.equals(expected_df) + compare_dicts(result, expected) @pytest.mark.parametrize( From 47be8862e2d36478360f62fcaa9e1f9a86d687d2 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Mon, 16 Sep 2024 14:46:47 +0200 Subject: [PATCH 08/10] remaining methods --- narwhals/_dask/dataframe.py | 11 +- narwhals/_dask/expr.py | 111 ++++++++++++------ tests/expr_and_series/arg_true_test.py | 6 +- .../cat/get_categories_test.py | 7 +- tests/expr_and_series/filter_test.py | 15 ++- tests/expr_and_series/len_test.py | 6 +- tests/expr_and_series/mode_test.py | 7 +- tests/expr_and_series/sort_test.py | 34 ++++-- tpch/queries/q20.py | 6 +- tpch/queries/q22.py | 6 +- 10 files changed, 131 insertions(+), 78 deletions(-) diff --git a/narwhals/_dask/dataframe.py b/narwhals/_dask/dataframe.py index 00de180d9..672e38d39 100644 --- a/narwhals/_dask/dataframe.py +++ b/narwhals/_dask/dataframe.py @@ -139,12 +139,17 @@ def select( col_order = list(new_series.keys()) - left_most_series = next( # pragma: no cover - s for s in new_series.values() if not isinstance(s, de._collection.Scalar) + left_most_name, left_most_series = next( # pragma: no cover + (name, s) + for name, s in new_series.items() + if not isinstance(s, de._collection.Scalar) ) + new_series.pop(left_most_name) return self._from_native_frame( - left_most_series.to_frame().assign(**new_series).loc[:, col_order] + left_most_series.to_frame(name=left_most_name) + .assign(**new_series) + .loc[:, col_order] ) def drop_nulls(self: Self, subset: str | list[str] | None) -> Self: diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index 9f1ac553d..c2a5b168d 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -688,31 +688,55 @@ def func(df: DaskLazyFrame) -> list[Any]: backend_version=self._backend_version, ) - def mode(self: Self) -> Self: - msg = "`Expr.mode` is not supported for the Dask backend." - raise NotImplementedError(msg) + def cast( + self: Self, + dtype: DType | type[DType], + ) -> Self: + def func(_input: Any, dtype: DType | type[DType]) -> Any: + dtype = reverse_translate_dtype(dtype) + return _input.astype(dtype) + + return self._from_call( + func, + "cast", + dtype, + returns_scalar=False, + modifies_index=False, + ) + + # Index modifiers def sort(self: Self, *, descending: bool = False, nulls_last: bool = False) -> Self: - na_position = "last" if nulls_last else "first" + msg = "`Expr.sort` is not supported for the Dask backend. Please use `LazyFrame.sort` instead." + raise NotImplementedError(msg) - def func(_input: Any, ascending: bool, na_position: bool) -> Any: # noqa: FBT001 - name = _input.name + def gather_every(self: Self, n: int, offset: int = 0) -> NoReturn: + msg = "`Expr.gather_every` is not supported for the Dask backend. Please use `LazyFrame.gather_every` instead." + raise NotImplementedError(msg) + + def sample( + self: Self, + n: int | None = None, + *, + fraction: float | None = None, + with_replacement: bool = False, + seed: int | None = None, + ) -> NoReturn: + msg = "`Expr.sample` is not supported for the Dask backend." + raise NotImplementedError(msg) - return _input.to_frame(name=name).sort_values( - by=name, ascending=ascending, na_position=na_position - )[name] + def mode(self: Self) -> Self: + def func(_input: Any) -> Any: + name = _input.name + return _input.to_frame(name=name).mode()[name] return self._from_call( func, - "sort", - not descending, - na_position, + "mode", returns_scalar=False, modifies_index=True, ) - # Index modifiers - def drop_nulls(self: Self) -> Self: return self._from_call( lambda _input: _input.dropna(), @@ -753,10 +777,45 @@ def unique(self: Self) -> Self: modifies_index=True, ) - def gather_every(self: Self, n: int, offset: int = 0) -> NoReturn: - # We can't (yet?) allow methods which modify the index - msg = "`Expr.gather_every` is not supported for the Dask backend. Please use `LazyFrame.gather_every` instead." - raise NotImplementedError(msg) + def filter(self: Self, *predicates: Any) -> Self: + plx = self.__narwhals_namespace__() + expr = plx.all_horizontal(*predicates) + + def func(df: DaskLazyFrame) -> list[Any]: + if self._output_names is None: + msg = ( + "Anonymous expressions are not supported in filter.\n" + "Instead of `nw.all()`, try using a named expression, such as " + "`nw.col('a', 'b')`\n" + ) + raise ValueError(msg) + mask = expr._call(df)[0] + return [df._native_frame[name].loc[mask] for name in self._output_names] + + return self.__class__( + func, + depth=self._depth + 1, + function_name=self._function_name + "->filter", + root_names=self._root_names, + output_names=self._output_names, + returns_scalar=False, + modifies_index=True, + backend_version=self._backend_version, + ) + + def arg_true(self: Self) -> Self: + def func(_input: dask_expr.Series) -> dask_expr.Series: + name = _input.name + return add_row_index(_input.to_frame(name=name), name).loc[_input, name] + + return self._from_call( + func, + "arg_true", + returns_scalar=False, + modifies_index=True, + ) + + # Namespaces @property def str(self: Self) -> DaskExprStringNamespace: @@ -770,22 +829,6 @@ def dt(self: Self) -> DaskExprDateTimeNamespace: def name(self: Self) -> DaskExprNameNamespace: return DaskExprNameNamespace(self) - def cast( - self: Self, - dtype: DType | type[DType], - ) -> Self: - def func(_input: Any, dtype: DType | type[DType]) -> Any: - dtype = reverse_translate_dtype(dtype) - return _input.astype(dtype) - - return self._from_call( - func, - "cast", - dtype, - returns_scalar=False, - modifies_index=False, - ) - class DaskExprStringNamespace: def __init__(self, expr: DaskExpr) -> None: diff --git a/tests/expr_and_series/arg_true_test.py b/tests/expr_and_series/arg_true_test.py index 7e1262aa8..9bcfee0c1 100644 --- a/tests/expr_and_series/arg_true_test.py +++ b/tests/expr_and_series/arg_true_test.py @@ -1,15 +1,11 @@ from typing import Any -import pytest - import narwhals.stable.v1 as nw from tests.utils import Constructor from tests.utils import compare_dicts -def test_arg_true(constructor: Constructor, request: pytest.FixtureRequest) -> None: - if "dask" in str(constructor): - request.applymarker(pytest.mark.xfail) +def test_arg_true(constructor: Constructor) -> None: df = nw.from_native(constructor({"a": [1, None, None, 3]})) result = df.select(nw.col("a").is_null().arg_true()) expected = {"a": [1, 2]} diff --git a/tests/expr_and_series/cat/get_categories_test.py b/tests/expr_and_series/cat/get_categories_test.py index 122f3c83e..a5b093d90 100644 --- a/tests/expr_and_series/cat/get_categories_test.py +++ b/tests/expr_and_series/cat/get_categories_test.py @@ -1,18 +1,19 @@ from __future__ import annotations -from typing import Any - import pyarrow as pa import pytest import narwhals.stable.v1 as nw from narwhals.utils import parse_version +from tests.utils import Constructor from tests.utils import compare_dicts data = {"a": ["one", "two", "two"]} -def test_get_categories(request: pytest.FixtureRequest, constructor_eager: Any) -> None: +def test_get_categories( + request: pytest.FixtureRequest, constructor_eager: Constructor +) -> None: if "pyarrow_table" in str(constructor_eager) and parse_version( pa.__version__ ) < parse_version("15.0.0"): diff --git a/tests/expr_and_series/filter_test.py b/tests/expr_and_series/filter_test.py index 80267d1d0..7db9b8e82 100644 --- a/tests/expr_and_series/filter_test.py +++ b/tests/expr_and_series/filter_test.py @@ -14,15 +14,24 @@ } -def test_filter(constructor: Constructor, request: pytest.FixtureRequest) -> None: - if "dask" in str(constructor): - request.applymarker(pytest.mark.xfail) +def test_filter_single_expr(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select(nw.col("a").filter(nw.col("i") < 2, nw.col("c") == 5)) expected = {"a": [0]} compare_dicts(result, expected) +def test_filter_multi_expr( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if "dask" in str(constructor): + request.applymarker(pytest.mark.xfail) + df = nw.from_native(constructor(data)) + result = df.select(nw.col("a").filter(nw.col("i") < 2, nw.col("c") == 5), nw.col("b")) + expected = {"a": [0] * 5, "b": [1, 2, 3, 5, 3]} + compare_dicts(result, expected) + + def test_filter_series(constructor_eager: Any) -> None: df = nw.from_native(constructor_eager(data), eager_only=True) result = df.select(df["a"].filter((df["i"] < 2) & (df["c"] == 5))) diff --git a/tests/expr_and_series/len_test.py b/tests/expr_and_series/len_test.py index b1e1674bf..ceefa3a1e 100644 --- a/tests/expr_and_series/len_test.py +++ b/tests/expr_and_series/len_test.py @@ -1,7 +1,5 @@ from typing import Any -import pytest - import narwhals.stable.v1 as nw from tests.utils import Constructor from tests.utils import compare_dicts @@ -18,11 +16,9 @@ def test_len_no_filter(constructor: Constructor) -> None: compare_dicts(df, expected) -def test_len_chaining(constructor: Constructor, request: pytest.FixtureRequest) -> None: +def test_len_chaining(constructor: Constructor) -> None: data = {"a": list("xyz"), "b": [1, 2, 1]} expected = {"a1": [2], "a2": [1]} - if "dask" in str(constructor): - request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)).select( nw.col("a").filter(nw.col("b") == 1).len().alias("a1"), nw.col("a").filter(nw.col("b") == 2).len().alias("a2"), diff --git a/tests/expr_and_series/mode_test.py b/tests/expr_and_series/mode_test.py index 8e39405af..a47333d8d 100644 --- a/tests/expr_and_series/mode_test.py +++ b/tests/expr_and_series/mode_test.py @@ -14,12 +14,7 @@ } -def test_mode_single_expr( - constructor: Constructor, request: pytest.FixtureRequest -) -> None: - if "dask" in str(constructor): - request.applymarker(pytest.mark.xfail) - +def test_mode_single_expr(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select(nw.col("a").mode()).sort("a") expected = {"a": [1, 2]} diff --git a/tests/expr_and_series/sort_test.py b/tests/expr_and_series/sort_test.py index e6ffe5fe9..e5b1d493e 100644 --- a/tests/expr_and_series/sort_test.py +++ b/tests/expr_and_series/sort_test.py @@ -1,8 +1,9 @@ -from typing import Any +from __future__ import annotations import pytest import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import compare_dicts data = {"a": [0, 0, 2, -1], "b": [1, 3, 2, None]} @@ -18,8 +19,14 @@ ], ) def test_sort_single_expr( - constructor: Any, descending: Any, nulls_last: Any, expected: Any + constructor: Constructor, + descending: bool, # noqa: FBT001 + nulls_last: bool, # noqa: FBT001 + expected: dict[str, float], + request: pytest.FixtureRequest, ) -> None: + if "dask" in str(constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = df.select(nw.col("b").sort(descending=descending, nulls_last=nulls_last)) compare_dicts(result, expected) @@ -35,7 +42,11 @@ def test_sort_single_expr( ], ) def test_sort_multiple_expr( - constructor: Any, descending: Any, nulls_last: Any, expected: Any, request: Any + constructor: Constructor, + descending: bool, # noqa: FBT001 + nulls_last: bool, # noqa: FBT001 + expected: dict[str, float], + request: pytest.FixtureRequest, ) -> None: if "dask" in str(constructor): request.applymarker(pytest.mark.xfail) @@ -51,17 +62,18 @@ def test_sort_multiple_expr( @pytest.mark.parametrize( ("descending", "nulls_last", "expected"), [ - (True, True, [3, 2, 1, None]), - (True, False, [None, 3, 2, 1]), - (False, True, [1, 2, 3, None]), - (False, False, [None, 1, 2, 3]), + (True, True, [3, 2, 1, float("nan")]), + (True, False, [float("nan"), 3, 2, 1]), + (False, True, [1, 2, 3, float("nan")]), + (False, False, [float("nan"), 1, 2, 3]), ], ) def test_sort_series( - constructor_eager: Any, descending: Any, nulls_last: Any, expected: Any + constructor_eager: Constructor, + descending: bool, # noqa: FBT001 + nulls_last: bool, # noqa: FBT001 + expected: dict[str, float], ) -> None: series = nw.from_native(constructor_eager(data), eager_only=True)["b"] result = series.sort(descending=descending, nulls_last=nulls_last) - assert ( - result == nw.from_native(constructor_eager({"a": expected}), eager_only=True)["a"] - ) + compare_dicts({"b": result}, {"b": expected}) diff --git a/tpch/queries/q20.py b/tpch/queries/q20.py index b0dabb29e..d9014f7b8 100644 --- a/tpch/queries/q20.py +++ b/tpch/queries/q20.py @@ -28,8 +28,7 @@ def query( return ( part_ds.filter(nw.col("p_name").str.starts_with(var4)) - .select("p_partkey") - .unique("p_partkey") + .select(nw.col("p_partkey").unique()) .join(partsupp_ds, left_on="p_partkey", right_on="ps_partkey") .join( query1, @@ -37,8 +36,7 @@ def query( right_on=["l_suppkey", "l_partkey"], ) .filter(nw.col("ps_availqty") > nw.col("sum_quantity")) - .select("ps_suppkey") - .unique("ps_suppkey") + .select(nw.col("ps_suppkey").unique()) .join(query3, left_on="ps_suppkey", right_on="s_suppkey") .select("s_name", "s_address") .sort("s_name") diff --git a/tpch/queries/q22.py b/tpch/queries/q22.py index 2e0973227..4738c6fd3 100644 --- a/tpch/queries/q22.py +++ b/tpch/queries/q22.py @@ -14,10 +14,8 @@ def query(customer_ds: FrameT, orders_ds: FrameT) -> FrameT: nw.col("c_acctbal").mean().alias("avg_acctbal") ) - q3 = ( - orders_ds.select("o_custkey") - .unique("o_custkey") - .with_columns(nw.col("o_custkey").alias("c_custkey")) + q3 = orders_ds.select(nw.col("o_custkey").unique()).with_columns( + nw.col("o_custkey").alias("c_custkey") ) return ( From 44cb75456800388a40fe2a165e415023a138795a Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Mon, 16 Sep 2024 14:58:09 +0200 Subject: [PATCH 09/10] polars regression --- tests/expr_and_series/filter_test.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/tests/expr_and_series/filter_test.py b/tests/expr_and_series/filter_test.py index 7db9b8e82..fed9c6f45 100644 --- a/tests/expr_and_series/filter_test.py +++ b/tests/expr_and_series/filter_test.py @@ -1,7 +1,5 @@ from typing import Any -import pytest - import narwhals.stable.v1 as nw from tests.utils import Constructor from tests.utils import compare_dicts @@ -14,24 +12,13 @@ } -def test_filter_single_expr(constructor: Constructor) -> None: +def test_filter_expr(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select(nw.col("a").filter(nw.col("i") < 2, nw.col("c") == 5)) expected = {"a": [0]} compare_dicts(result, expected) -def test_filter_multi_expr( - constructor: Constructor, request: pytest.FixtureRequest -) -> None: - if "dask" in str(constructor): - request.applymarker(pytest.mark.xfail) - df = nw.from_native(constructor(data)) - result = df.select(nw.col("a").filter(nw.col("i") < 2, nw.col("c") == 5), nw.col("b")) - expected = {"a": [0] * 5, "b": [1, 2, 3, 5, 3]} - compare_dicts(result, expected) - - def test_filter_series(constructor_eager: Any) -> None: df = nw.from_native(constructor_eager(data), eager_only=True) result = df.select(df["a"].filter((df["i"] < 2) & (df["c"] == 5))) From 63264469dd79736215c8ada0aa343046bb5ec092 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Mon, 16 Sep 2024 15:11:16 +0200 Subject: [PATCH 10/10] no cover anonymous expr in filter --- narwhals/_dask/expr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index c2a5b168d..d4e1e623a 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -782,7 +782,7 @@ def filter(self: Self, *predicates: Any) -> Self: expr = plx.all_horizontal(*predicates) def func(df: DaskLazyFrame) -> list[Any]: - if self._output_names is None: + if self._output_names is None: # pragma: no cover msg = ( "Anonymous expressions are not supported in filter.\n" "Instead of `nw.all()`, try using a named expression, such as "