Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: "carefully" allow for dask Expr that modify index #743

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
27 changes: 24 additions & 3 deletions narwhals/_dask/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
from typing import Sequence

from narwhals._dask.utils import parse_exprs_and_named_exprs
from narwhals._dask.utils import set_axis
from narwhals._pandas_like.utils import translate_dtype
from narwhals.dependencies import get_dask_dataframe
from narwhals.dependencies import get_dask_expr
from narwhals.dependencies import get_pandas
from narwhals.utils import Implementation
from narwhals.utils import flatten
Expand Down Expand Up @@ -47,8 +49,9 @@ def _from_native_dataframe(self, df: Any) -> Self:

def with_columns(self, *exprs: DaskExpr, **named_exprs: DaskExpr) -> Self:
df = self._native_dataframe
index = df.index
new_series = parse_exprs_and_named_exprs(self, *exprs, **named_exprs)
df = df.assign(**new_series)
df = df.assign(**{k: set_axis(v, index) for k, v in new_series.items()})
return self._from_native_dataframe(df)

def collect(self) -> Any:
Expand Down Expand Up @@ -106,8 +109,26 @@ def select(
)
return self._from_native_dataframe(df)

df = self._native_dataframe.assign(**new_series).loc[:, list(new_series.keys())]
return self._from_native_dataframe(df)
pd = get_pandas()
de = get_dask_expr()

col_order = list(new_series.keys())

index = next( # pragma: no cover
s for s in new_series.values() if not isinstance(s, de._collection.Scalar)
FBruzzesi marked this conversation as resolved.
Show resolved Hide resolved
).index

new_series = {
k: set_axis(v, index)
for k, v in sorted(
new_series.items(),
key=lambda item: isinstance(item[1], de._collection.Scalar),
)
}

return self._from_native_dataframe(
dd.from_pandas(pd.DataFrame()).assign(**new_series).loc[:, col_order]
)

def drop_nulls(self) -> Self:
return self._from_native_dataframe(self._native_dataframe.dropna())
Expand Down
25 changes: 25 additions & 0 deletions narwhals/_dask/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,31 @@ def clip(
returns_scalar=False,
)

def sort(self: Self, *, descending: bool = False, nulls_last: bool = False) -> Self:
    """Sort the expression's values.

    Arguments:
        descending: Sort in descending order when True.
        nulls_last: Place null values at the end instead of the start.

    Returns:
        A new expression producing the sorted values.
    """
    na_position = "last" if nulls_last else "first"

    # Fix: `na_position` is the string "first"/"last", not a bool — the
    # original annotation (`bool`) was wrong and misleading.
    def func(_input: Any, ascending: bool, na_position: str) -> Any:  # noqa: FBT001
        name = _input.name

        # Round-trip through a one-column frame so pandas'
        # sort_values/na_position semantics apply to the series.
        return _input.to_frame(name=name).sort_values(
            by=name, ascending=ascending, na_position=na_position
        )[name]

    return self._from_call(
        func,
        "sort",
        not descending,
        na_position,
        returns_scalar=False,
    )

def drop_nulls(self: Self) -> Self:
    """Remove missing values from the expression's output."""

    def func(_input: Any) -> Any:
        return _input.dropna()

    return self._from_call(
        func,
        "drop_nulls",
        returns_scalar=False,
    )

@property
def str(self: Self) -> DaskExprStringNamespace:
return DaskExprStringNamespace(self)
Expand Down
15 changes: 15 additions & 0 deletions narwhals/_dask/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,19 @@

from typing import TYPE_CHECKING
from typing import Any
from typing import TypeVar

from narwhals.dependencies import get_dask_expr

if TYPE_CHECKING:
from dask_expr._collection import Index
from dask_expr._collection import Scalar
from dask_expr._collection import Series

from narwhals._dask.dataframe import DaskLazyFrame

T = TypeVar("T", Scalar, Series)


def maybe_evaluate(df: DaskLazyFrame, obj: Any) -> Any:
from narwhals._dask.expr import DaskExpr
Expand Down Expand Up @@ -64,3 +71,11 @@ def parse_exprs_and_named_exprs(
raise AssertionError(msg)
results[name] = _results[0]
return results


def set_axis(obj: T, index: Index) -> T:
    """Return *obj* with *index* attached; scalars have no axis and pass through."""
    dask_expr = get_dask_expr()
    is_scalar = isinstance(obj, dask_expr._collection.Scalar)
    return obj if is_scalar else dask_expr._expr.AssignIndex(obj, index)
72 changes: 40 additions & 32 deletions tests/expr_and_series/drop_nulls_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,52 +7,60 @@
import narwhals as nw
from tests.utils import compare_dicts

data = {
"a": [1, 2, None],
"b": [3, 4, 5],
"c": [None, None, None],
"d": [6, None, None],
}

def test_drop_nulls(constructor: Any, request: Any) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)
data = {
"A": [1, 2, None, 4],
"B": [5, 6, 7, 8],
"C": [None, None, None, None],
"D": [9, 10, 11, 12],
}

def test_drop_nulls(constructor: Any) -> None:
df = nw.from_native(constructor(data))

result_a = df.select(nw.col("A").drop_nulls())
result_b = df.select(nw.col("B").drop_nulls())
result_c = df.select(nw.col("C").drop_nulls())
result_d = df.select(nw.col("D").drop_nulls())
expected_a = {"A": [1.0, 2.0, 4.0]}
expected_b = {"B": [5, 6, 7, 8]}
expected_c = {"C": []} # type: ignore[var-annotated]
expected_d = {"D": [9, 10, 11, 12]}
result_a = df.select(nw.col("a").drop_nulls())
result_b = df.select(nw.col("b").drop_nulls())
result_c = df.select(nw.col("c").drop_nulls())
result_d = df.select(nw.col("d").drop_nulls())

expected_a = {"a": [1.0, 2.0]}
expected_b = {"b": [3, 4, 5]}
expected_c = {"c": []} # type: ignore[var-annotated]
expected_d = {"d": [6]}

compare_dicts(result_a, expected_a)
compare_dicts(result_b, expected_b)
compare_dicts(result_c, expected_c)
compare_dicts(result_d, expected_d)


def test_drop_nulls_series(constructor_eager: Any) -> None:
data = {
"A": [1, 2, None, 4],
"B": [5, 6, 7, 8],
"C": [None, None, None, None],
"D": [9, 10, 11, 12],
}
def test_drop_nulls_broadcast(constructor: Any, request: Any) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor(data))
result = df.select(nw.col("a").drop_nulls(), nw.col("d").drop_nulls())
expected = {"a": [1.0, 2.0], "d": [6, 6]}
Comment on lines +47 to +48
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sadly this broadcast is not working as drop_nulls does not return a scalar. I would consider this an edge case and focus on the broader support

compare_dicts(result, expected)


def test_drop_nulls_invalid(constructor: Any) -> None:
    """Selecting two drop_nulls columns of mismatched lengths must raise."""
    frame = nw.from_native(constructor(data)).lazy()
    selection = frame.select(nw.col("a").drop_nulls(), nw.col("b").drop_nulls())

    with pytest.raises(Exception):  # noqa: B017, PT011
        selection.collect()


def test_drop_nulls_series(constructor_eager: Any) -> None:
df = nw.from_native(constructor_eager(data), eager_only=True)

result_a = df.select(df["A"].drop_nulls())
result_b = df.select(df["B"].drop_nulls())
result_c = df.select(df["C"].drop_nulls())
result_d = df.select(df["D"].drop_nulls())
expected_a = {"A": [1.0, 2.0, 4.0]}
expected_b = {"B": [5, 6, 7, 8]}
expected_c = {"C": []} # type: ignore[var-annotated]
expected_d = {"D": [9, 10, 11, 12]}
result_a = df.select(df["a"].drop_nulls())
result_b = df.select(df["b"].drop_nulls())
result_c = df.select(df["c"].drop_nulls())
result_d = df.select(df["d"].drop_nulls())
expected_a = {"a": [1.0, 2.0]}
expected_b = {"b": [3, 4, 5]}
expected_c = {"c": []} # type: ignore[var-annotated]
expected_d = {"d": [6]}

compare_dicts(result_a, expected_a)
compare_dicts(result_b, expected_b)
Expand Down
10 changes: 6 additions & 4 deletions tests/expr_and_series/sort_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,18 @@
],
)
def test_sort_expr(
constructor_eager: Any, descending: Any, nulls_last: Any, expected: Any
constructor: Any, descending: Any, nulls_last: Any, expected: Any
) -> None:
df = nw.from_native(constructor_eager(data), eager_only=True)
df = nw.from_native(constructor(data)).lazy()
result = nw.to_native(
df.select(
"a",
nw.col("b").sort(descending=descending, nulls_last=nulls_last),
)
).collect()
)
assert result.equals(constructor_eager(expected))

expected_df = nw.to_native(nw.from_native(constructor(expected)).lazy().collect())
assert result.equals(expected_df)


@pytest.mark.parametrize(
Expand Down
Loading