Skip to content

Commit

Permalink
chore: dask nightly (#1768)
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi authored Jan 9, 2025
1 parent 5dca2a9 commit 40a83e3
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 52 deletions.
48 changes: 25 additions & 23 deletions narwhals/_dask/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@
from narwhals.utils import import_dtypes_module

if TYPE_CHECKING:
import dask_expr
try:
import dask.dataframe.dask_expr as dx
except ModuleNotFoundError:
import dask_expr as dx

from typing_extensions import Self

from narwhals._dask.dataframe import DaskLazyFrame
Expand All @@ -32,12 +36,12 @@
from narwhals.utils import Version


class DaskExpr(CompliantExpr["dask_expr.Series"]):
class DaskExpr(CompliantExpr["dx.Series"]):
_implementation: Implementation = Implementation.DASK

def __init__(
self,
call: Callable[[DaskLazyFrame], Sequence[dask_expr.Series]],
call: Callable[[DaskLazyFrame], Sequence[dx.Series]],
*,
depth: int,
function_name: str,
Expand All @@ -60,7 +64,7 @@ def __init__(
self._version = version
self._kwargs = kwargs

def __call__(self, df: DaskLazyFrame) -> Sequence[dask_expr.Series]:
def __call__(self, df: DaskLazyFrame) -> Sequence[dx.Series]:
return self._call(df)

def __narwhals_expr__(self) -> None: ...
Expand All @@ -78,7 +82,7 @@ def from_column_names(
backend_version: tuple[int, ...],
version: Version,
) -> Self:
def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
def func(df: DaskLazyFrame) -> list[dx.Series]:
try:
return [df._native_frame[column_name] for column_name in column_names]
except KeyError as e:
Expand Down Expand Up @@ -107,7 +111,7 @@ def from_column_indices(
backend_version: tuple[int, ...],
version: Version,
) -> Self:
def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
def func(df: DaskLazyFrame) -> list[dx.Series]:
return [
df._native_frame.iloc[:, column_index] for column_index in column_indices
]
Expand All @@ -126,14 +130,14 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]:

def _from_call(
self,
# First argument to `call` should be `dask_expr.Series`
call: Callable[..., dask_expr.Series],
# First argument to `call` should be `dx.Series`
call: Callable[..., dx.Series],
expr_name: str,
*,
returns_scalar: bool,
**kwargs: Any,
) -> Self:
def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
def func(df: DaskLazyFrame) -> list[dx.Series]:
results = []
inputs = self._call(df)
_kwargs = {key: maybe_evaluate(df, value) for key, value in kwargs.items()}
Expand Down Expand Up @@ -163,7 +167,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
)

def alias(self, name: str) -> Self:
def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
def func(df: DaskLazyFrame) -> list[dx.Series]:
inputs = self._call(df)
return [_input.rename(name) for _input in inputs]

Expand Down Expand Up @@ -312,7 +316,7 @@ def mean(self) -> Self:
def median(self) -> Self:
from narwhals.exceptions import InvalidOperationError

def func(s: dask_expr.Series) -> dask_expr.Series:
def func(s: dx.Series) -> dx.Series:
dtype = native_to_narwhals_dtype(s, self._version, Implementation.DASK)
if not dtype.is_numeric():
msg = "`median` operation not supported for non-numeric input type."
Expand Down Expand Up @@ -511,11 +515,11 @@ def fill_null(
limit: int | None = None,
) -> DaskExpr:
def func(
_input: dask_expr.Series,
_input: dx.Series,
value: Any | None,
strategy: str | None,
limit: int | None,
) -> dask_expr.Series:
) -> dx.Series:
if value is not None:
res_ser = _input.fillna(value)
else:
Expand Down Expand Up @@ -566,7 +570,7 @@ def is_null(self: Self) -> Self:
)

def is_nan(self: Self) -> Self:
def func(_input: dask_expr.Series) -> dask_expr.Series:
def func(_input: dx.Series) -> dx.Series:
dtype = native_to_narwhals_dtype(_input, self._version, self._implementation)
if dtype.is_numeric():
return _input != _input # noqa: PLR0124
Expand All @@ -585,7 +589,7 @@ def quantile(
) -> Self:
if interpolation == "linear":

def func(_input: dask_expr.Series, quantile: float) -> dask_expr.Series:
def func(_input: dx.Series, quantile: float) -> dx.Series:
if _input.npartitions > 1:
msg = "`Expr.quantile` is not supported for Dask backend with multiple partitions."
raise NotImplementedError(msg)
Expand All @@ -599,7 +603,7 @@ def func(_input: dask_expr.Series, quantile: float) -> dask_expr.Series:
raise NotImplementedError(msg)

def is_first_distinct(self: Self) -> Self:
def func(_input: dask_expr.Series) -> dask_expr.Series:
def func(_input: dx.Series) -> dx.Series:
_name = _input.name
col_token = generate_temporary_column_name(n_bytes=8, columns=[_name])
_input = add_row_index(
Expand All @@ -618,7 +622,7 @@ def func(_input: dask_expr.Series) -> dask_expr.Series:
)

def is_last_distinct(self: Self) -> Self:
def func(_input: dask_expr.Series) -> dask_expr.Series:
def func(_input: dx.Series) -> dx.Series:
_name = _input.name
col_token = generate_temporary_column_name(n_bytes=8, columns=[_name])
_input = add_row_index(
Expand All @@ -635,7 +639,7 @@ def func(_input: dask_expr.Series) -> dask_expr.Series:
)

def is_duplicated(self: Self) -> Self:
def func(_input: dask_expr.Series) -> dask_expr.Series:
def func(_input: dx.Series) -> dx.Series:
_name = _input.name
return (
_input.to_frame()
Expand All @@ -647,7 +651,7 @@ def func(_input: dask_expr.Series) -> dask_expr.Series:
return self._from_call(func, "is_duplicated", returns_scalar=self._returns_scalar)

def is_unique(self: Self) -> Self:
def func(_input: dask_expr.Series) -> dask_expr.Series:
def func(_input: dx.Series) -> dx.Series:
_name = _input.name
return (
_input.to_frame()
Expand Down Expand Up @@ -967,7 +971,7 @@ def replace_time_zone(self, time_zone: str | None) -> DaskExpr:
)

def convert_time_zone(self, time_zone: str) -> DaskExpr:
def func(s: dask_expr.Series, time_zone: str) -> dask_expr.Series:
def func(s: dx.Series, time_zone: str) -> dx.Series:
dtype = native_to_narwhals_dtype(
s, self._compliant_expr._version, Implementation.DASK
)
Expand All @@ -984,9 +988,7 @@ def func(s: dask_expr.Series, time_zone: str) -> dask_expr.Series:
)

def timestamp(self, time_unit: Literal["ns", "us", "ms"] = "us") -> DaskExpr:
def func(
s: dask_expr.Series, time_unit: Literal["ns", "us", "ms"] = "us"
) -> dask_expr.Series:
def func(s: dx.Series, time_unit: Literal["ns", "us", "ms"] = "us") -> dx.Series:
dtype = native_to_narwhals_dtype(
s, self._compliant_expr._version, Implementation.DASK
)
Expand Down
19 changes: 15 additions & 4 deletions narwhals/_dask/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@

if TYPE_CHECKING:
import dask.dataframe as dd
import dask_expr

try:
import dask.dataframe.dask_expr as dx
except ModuleNotFoundError:
import dask_expr as dx

import pandas as pd

from narwhals._dask.dataframe import DaskLazyFrame
Expand Down Expand Up @@ -43,7 +48,10 @@ def var(
]:
from functools import partial

import dask_expr as dx
try:
import dask.dataframe.dask_expr as dx
except ModuleNotFoundError:
import dask_expr as dx

return partial(dx._groupby.GroupBy.var, ddof=ddof)

Expand All @@ -55,7 +63,10 @@ def std(
]:
from functools import partial

import dask_expr as dx
try:
import dask.dataframe.dask_expr as dx
except ModuleNotFoundError:
import dask_expr as dx

return partial(dx._groupby.GroupBy.std, ddof=ddof)

Expand Down Expand Up @@ -127,7 +138,7 @@ def _from_native_frame(self, df: DaskLazyFrame) -> DaskLazyFrame:
def agg_dask(
df: DaskLazyFrame,
grouped: Any,
exprs: Sequence[CompliantExpr[dask_expr.Series]],
exprs: Sequence[CompliantExpr[dx.Series]],
keys: list[str],
from_dataframe: Callable[[Any], DaskLazyFrame],
) -> DaskLazyFrame:
Expand Down
33 changes: 18 additions & 15 deletions narwhals/_dask/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,17 @@
from narwhals.typing import CompliantNamespace

if TYPE_CHECKING:
import dask_expr
try:
import dask.dataframe.dask_expr as dx
except ModuleNotFoundError:
import dask_expr as dx

from narwhals._dask.typing import IntoDaskExpr
from narwhals.dtypes import DType
from narwhals.utils import Version


class DaskNamespace(CompliantNamespace["dask_expr.Series"]):
class DaskNamespace(CompliantNamespace["dx.Series"]):
@property
def selectors(self) -> DaskSelectorNamespace:
return DaskSelectorNamespace(
Expand All @@ -40,7 +43,7 @@ def __init__(self, *, backend_version: tuple[int, ...], version: Version) -> Non
self._version = version

def all(self) -> DaskExpr:
def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
def func(df: DaskLazyFrame) -> list[dx.Series]:
return [df._native_frame[column_name] for column_name in df.columns]

return DaskExpr(
Expand Down Expand Up @@ -69,7 +72,7 @@ def lit(self, value: Any, dtype: DType | None) -> DaskExpr:
import dask.dataframe as dd
import pandas as pd

def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
def func(df: DaskLazyFrame) -> list[dx.Series]:
return [
dd.from_pandas(
pd.Series(
Expand Down Expand Up @@ -99,7 +102,7 @@ def len(self) -> DaskExpr:
import dask.dataframe as dd
import pandas as pd

def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
def func(df: DaskLazyFrame) -> list[dx.Series]:
if not df.columns:
return [
dd.from_pandas(
Expand All @@ -125,7 +128,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
def all_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr:
parsed_exprs = parse_into_exprs(*exprs, namespace=self)

def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
def func(df: DaskLazyFrame) -> list[dx.Series]:
series = [s for _expr in parsed_exprs for s in _expr(df)]
return [reduce(lambda x, y: x & y, series).rename(series[0].name)]

Expand All @@ -144,7 +147,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
def any_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr:
parsed_exprs = parse_into_exprs(*exprs, namespace=self)

def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
def func(df: DaskLazyFrame) -> list[dx.Series]:
series = [s for _expr in parsed_exprs for s in _expr(df)]
return [reduce(lambda x, y: x | y, series).rename(series[0].name)]

Expand All @@ -163,7 +166,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
def sum_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr:
parsed_exprs = parse_into_exprs(*exprs, namespace=self)

def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
def func(df: DaskLazyFrame) -> list[dx.Series]:
series = [s.fillna(0) for _expr in parsed_exprs for s in _expr(df)]
return [reduce(lambda x, y: x + y, series).rename(series[0].name)]

Expand Down Expand Up @@ -239,7 +242,7 @@ def concat(
def mean_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr:
parsed_exprs = parse_into_exprs(*exprs, namespace=self)

def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
def func(df: DaskLazyFrame) -> list[dx.Series]:
series = (s.fillna(0) for _expr in parsed_exprs for s in _expr(df))
non_na = (1 - s.isna() for _expr in parsed_exprs for s in _expr(df))
return [
Expand All @@ -266,7 +269,7 @@ def min_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr:

parsed_exprs = parse_into_exprs(*exprs, namespace=self)

def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
def func(df: DaskLazyFrame) -> list[dx.Series]:
series = [s for _expr in parsed_exprs for s in _expr(df)]

return [dd.concat(series, axis=1).min(axis=1).rename(series[0].name)]
Expand All @@ -288,7 +291,7 @@ def max_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr:

parsed_exprs = parse_into_exprs(*exprs, namespace=self)

def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
def func(df: DaskLazyFrame) -> list[dx.Series]:
series = [s for _expr in parsed_exprs for s in _expr(df)]

return [dd.concat(series, axis=1).max(axis=1).rename(series[0].name)]
Expand Down Expand Up @@ -327,7 +330,7 @@ def concat_str(
*parse_into_exprs(*more_exprs, namespace=self),
]

def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
def func(df: DaskLazyFrame) -> list[dx.Series]:
series = (s.astype(str) for _expr in parsed_exprs for s in _expr(df))
null_mask = [s for _expr in parsed_exprs for s in _expr.is_null()(df)]

Expand Down Expand Up @@ -389,20 +392,20 @@ def __init__(
self._returns_scalar = returns_scalar
self._version = version

def __call__(self, df: DaskLazyFrame) -> Sequence[dask_expr.Series]:
def __call__(self, df: DaskLazyFrame) -> Sequence[dx.Series]:
from narwhals._expression_parsing import parse_into_expr

plx = df.__narwhals_namespace__()
condition = parse_into_expr(self._condition, namespace=plx)(df)[0]
condition = cast("dask_expr.Series", condition)
condition = cast("dx.Series", condition)
try:
value_series = parse_into_expr(self._then_value, namespace=plx)(df)[0]
except TypeError:
# `self._then_value` is a scalar and can't be converted to an expression
_df = condition.to_frame("a")
_df["tmp"] = self._then_value
value_series = _df["tmp"]
value_series = cast("dask_expr.Series", value_series)
value_series = cast("dx.Series", value_series)
validate_comparand(condition, value_series)

if self._otherwise_value is None:
Expand Down
7 changes: 5 additions & 2 deletions narwhals/_dask/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
from narwhals.utils import import_dtypes_module

if TYPE_CHECKING:
import dask_expr
try:
import dask.dataframe.dask_expr as dx
except ModuleNotFoundError:
import dask_expr as dx
from typing_extensions import Self

from narwhals._dask.dataframe import DaskLazyFrame
Expand Down Expand Up @@ -135,7 +138,7 @@ def call(df: DaskLazyFrame) -> list[Any]:
def __or__(self: Self, other: DaskSelector | Any) -> DaskSelector | Any:
if isinstance(other, DaskSelector):

def call(df: DaskLazyFrame) -> list[dask_expr.Series]:
def call(df: DaskLazyFrame) -> list[dx.Series]:
lhs = self._call(df)
rhs = other._call(df)
return [*(x for x in lhs if x.name not in {x.name for x in rhs}), *rhs]
Expand Down
Loading

0 comments on commit 40a83e3

Please sign in to comment.