Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: "carefully" allow for dask Expr that modify index #743

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
34 changes: 27 additions & 7 deletions narwhals/_dask/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,17 @@ def _from_native_frame(self, df: Any) -> Self:
return self.__class__(df, backend_version=self._backend_version)

def with_columns(self, *exprs: DaskExpr, **named_exprs: DaskExpr) -> Self:
df = self._native_frame
n_modifies_index = sum(
getattr(e, "_modifies_index", 0)
for e in list(exprs) + list(named_exprs.values())
)

if n_modifies_index > 0:
msg = "Expressions that modify the index are not supported in `with_columns`."
raise ValueError(msg)

new_series = parse_exprs_and_named_exprs(self, *exprs, **named_exprs)
df = df.assign(**new_series)
return self._from_native_frame(df)
return self._from_native_frame(self._native_frame.assign(**new_series))

def collect(self) -> Any:
from narwhals._pandas_like.dataframe import PandasLikeDataFrame
Expand Down Expand Up @@ -96,17 +103,23 @@ def select(
**named_exprs: IntoDaskExpr,
) -> Self:
import dask.dataframe as dd # ignore-banned-import
import dask_expr as de # ignore-banned-import
import pandas as pd # ignore-banned-import

if exprs and all(isinstance(x, str) for x in exprs) and not named_exprs:
# This is a simple slice => fastpath!
return self._from_native_frame(self._native_frame.loc[:, exprs])

all_exprs = list(exprs) + list(named_exprs.values())
n_modifies_index = sum(getattr(e, "_modifies_index", 0) for e in all_exprs)
if len(all_exprs) > 1 and n_modifies_index > 1:
msg = "Found multiple expressions that modify the index"
raise ValueError(msg)

new_series = parse_exprs_and_named_exprs(self, *exprs, **named_exprs)

if not new_series:
# return empty dataframe, like Polars does
import pandas as pd # ignore-banned-import

return self._from_native_frame(
dd.from_pandas(pd.DataFrame(), npartitions=self._native_frame.npartitions)
)
Expand All @@ -119,8 +132,15 @@ def select(
)
return self._from_native_frame(df)

df = self._native_frame.assign(**new_series).loc[:, list(new_series.keys())]
return self._from_native_frame(df)
col_order = list(new_series.keys())

left_most_series = next( # pragma: no cover
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is guaranteed to not end up in StopIteration error as if everything was a scalar the previous block would have been entered and returned

s for s in new_series.values() if not isinstance(s, de._collection.Scalar)
FBruzzesi marked this conversation as resolved.
Show resolved Hide resolved
)

return self._from_native_frame(
left_most_series.to_frame().assign(**new_series).loc[:, col_order]
)

def drop_nulls(self: Self, subset: str | list[str] | None) -> Self:
if subset is None:
Expand Down
Loading
Loading