Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: "carefully" allow for dask Expr that modify index #743

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
39 changes: 32 additions & 7 deletions narwhals/_dask/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,17 @@ def _from_native_frame(self, df: Any) -> Self:
         )

     def with_columns(self, *exprs: DaskExpr, **named_exprs: DaskExpr) -> Self:
-        df = self._native_frame
+        n_modifies_index = sum(
+            getattr(e, "_modifies_index", 0)
+            for e in list(exprs) + list(named_exprs.values())
+        )
+
+        if n_modifies_index > 0:
+            msg = "Expressions that modify the index are not supported in `with_columns`."
+            raise ValueError(msg)
+
         new_series = parse_exprs_and_named_exprs(self, *exprs, **named_exprs)
-        df = df.assign(**new_series)
-        return self._from_native_frame(df)
+        return self._from_native_frame(self._native_frame.assign(**new_series))

     def collect(self) -> Any:
         import pandas as pd # ignore-banned-import()
Expand Down Expand Up @@ -113,17 +120,23 @@ def select(
         **named_exprs: IntoDaskExpr,
     ) -> Self:
         import dask.dataframe as dd # ignore-banned-import
+        import dask_expr as de # ignore-banned-import
+        import pandas as pd # ignore-banned-import

         if exprs and all(isinstance(x, str) for x in exprs) and not named_exprs:
             # This is a simple slice => fastpath!
             return self._from_native_frame(self._native_frame.loc[:, exprs])

+        all_exprs = list(exprs) + list(named_exprs.values())
+        n_modifies_index = sum(getattr(e, "_modifies_index", 0) for e in all_exprs)
+        if len(all_exprs) > 1 and n_modifies_index > 1:
+            msg = "Found multiple expressions that modify the index"
+            raise ValueError(msg)
+
         new_series = parse_exprs_and_named_exprs(self, *exprs, **named_exprs)

         if not new_series:
             # return empty dataframe, like Polars does
-            import pandas as pd # ignore-banned-import
-
             return self._from_native_frame(
                 dd.from_pandas(pd.DataFrame(), npartitions=self._native_frame.npartitions)
             )
Expand All @@ -136,8 +149,20 @@ def select(
             )
             return self._from_native_frame(df)

-        df = self._native_frame.assign(**new_series).loc[:, list(new_series.keys())]
-        return self._from_native_frame(df)
+        col_order = list(new_series.keys())
+
+        left_most_name, left_most_series = next( # pragma: no cover
+            (name, s)
+            for name, s in new_series.items()
+            if not isinstance(s, de._collection.Scalar)
+        )
+        new_series.pop(left_most_name)
+
+        return self._from_native_frame(
+            left_most_series.to_frame(name=left_most_name)
+            .assign(**new_series)
+            .loc[:, col_order]
+        )

     def drop_nulls(self: Self, subset: str | list[str] | None) -> Self:
         if subset is None:
Expand Down
Loading
Loading