From 40a83e36f41389db59881cc211ed83ec1b6913f6 Mon Sep 17 00:00:00 2001 From: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com> Date: Thu, 9 Jan 2025 10:11:17 +0100 Subject: [PATCH] chore: dask nightly (#1768) --- narwhals/_dask/expr.py | 48 +++++++++++++++++++------------------ narwhals/_dask/group_by.py | 19 +++++++++++---- narwhals/_dask/namespace.py | 33 +++++++++++++------------ narwhals/_dask/selectors.py | 7 ++++-- narwhals/_dask/utils.py | 21 ++++++++++------ narwhals/translate.py | 5 +++- 6 files changed, 81 insertions(+), 52 deletions(-) diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index 40e7eff9c..373c29020 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -23,7 +23,11 @@ from narwhals.utils import import_dtypes_module if TYPE_CHECKING: - import dask_expr + try: + import dask.dataframe.dask_expr as dx + except ModuleNotFoundError: + import dask_expr as dx + from typing_extensions import Self from narwhals._dask.dataframe import DaskLazyFrame @@ -32,12 +36,12 @@ from narwhals.utils import Version -class DaskExpr(CompliantExpr["dask_expr.Series"]): +class DaskExpr(CompliantExpr["dx.Series"]): _implementation: Implementation = Implementation.DASK def __init__( self, - call: Callable[[DaskLazyFrame], Sequence[dask_expr.Series]], + call: Callable[[DaskLazyFrame], Sequence[dx.Series]], *, depth: int, function_name: str, @@ -60,7 +64,7 @@ def __init__( self._version = version self._kwargs = kwargs - def __call__(self, df: DaskLazyFrame) -> Sequence[dask_expr.Series]: + def __call__(self, df: DaskLazyFrame) -> Sequence[dx.Series]: return self._call(df) def __narwhals_expr__(self) -> None: ... @@ -78,7 +82,7 @@ def from_column_names( backend_version: tuple[int, ...], version: Version, ) -> Self: - def func(df: DaskLazyFrame) -> list[dask_expr.Series]: + def func(df: DaskLazyFrame) -> list[dx.Series]: try: return [df._native_frame[column_name] for column_name in column_names] except KeyError as e: @@ -107,7 +111,7 @@ def from_column_indices( backend_version: tuple[int, ...], version: Version, ) -> Self: - def func(df: DaskLazyFrame) -> list[dask_expr.Series]: + def func(df: DaskLazyFrame) -> list[dx.Series]: return [ df._native_frame.iloc[:, column_index] for column_index in column_indices ] @@ -126,14 +130,14 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: def _from_call( self, - # First argument to `call` should be `dask_expr.Series` - call: Callable[..., dask_expr.Series], + # First argument to `call` should be `dx.Series` + call: Callable[..., dx.Series], expr_name: str, *, returns_scalar: bool, **kwargs: Any, ) -> Self: - def func(df: DaskLazyFrame) -> list[dask_expr.Series]: + def func(df: DaskLazyFrame) -> list[dx.Series]: results = [] inputs = self._call(df) _kwargs = {key: maybe_evaluate(df, value) for key, value in kwargs.items()} @@ -163,7 +167,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: ) def alias(self, name: str) -> Self: - def func(df: DaskLazyFrame) -> list[dask_expr.Series]: + def func(df: DaskLazyFrame) -> list[dx.Series]: inputs = self._call(df) return [_input.rename(name) for _input in inputs] @@ -312,7 +316,7 @@ def mean(self) -> Self: def median(self) -> Self: from narwhals.exceptions import InvalidOperationError - def func(s: dask_expr.Series) -> dask_expr.Series: + def func(s: dx.Series) -> dx.Series: dtype = native_to_narwhals_dtype(s, self._version, Implementation.DASK) if not dtype.is_numeric(): msg = "`median` operation not supported for non-numeric input type." @@ -511,11 +515,11 @@ def fill_null( limit: int | None = None, ) -> DaskExpr: def func( - _input: dask_expr.Series, + _input: dx.Series, value: Any | None, strategy: str | None, limit: int | None, - ) -> dask_expr.Series: + ) -> dx.Series: if value is not None: res_ser = _input.fillna(value) else: @@ -566,7 +570,7 @@ def is_null(self: Self) -> Self: ) def is_nan(self: Self) -> Self: - def func(_input: dask_expr.Series) -> dask_expr.Series: + def func(_input: dx.Series) -> dx.Series: dtype = native_to_narwhals_dtype(_input, self._version, self._implementation) if dtype.is_numeric(): return _input != _input # noqa: PLR0124 @@ -585,7 +589,7 @@ def quantile( ) -> Self: if interpolation == "linear": - def func(_input: dask_expr.Series, quantile: float) -> dask_expr.Series: + def func(_input: dx.Series, quantile: float) -> dx.Series: if _input.npartitions > 1: msg = "`Expr.quantile` is not supported for Dask backend with multiple partitions." raise NotImplementedError(msg) @@ -599,7 +603,7 @@ def func(_input: dask_expr.Series, quantile: float) -> dask_expr.Series: raise NotImplementedError(msg) def is_first_distinct(self: Self) -> Self: - def func(_input: dask_expr.Series) -> dask_expr.Series: + def func(_input: dx.Series) -> dx.Series: _name = _input.name col_token = generate_temporary_column_name(n_bytes=8, columns=[_name]) _input = add_row_index( @@ -618,7 +622,7 @@ def func(_input: dask_expr.Series) -> dask_expr.Series: ) def is_last_distinct(self: Self) -> Self: - def func(_input: dask_expr.Series) -> dask_expr.Series: + def func(_input: dx.Series) -> dx.Series: _name = _input.name col_token = generate_temporary_column_name(n_bytes=8, columns=[_name]) _input = add_row_index( @@ -635,7 +639,7 @@ def func(_input: dask_expr.Series) -> dask_expr.Series: ) def is_duplicated(self: Self) -> Self: - def func(_input: dask_expr.Series) -> dask_expr.Series: + def func(_input: dx.Series) -> dx.Series: _name = _input.name return ( _input.to_frame() @@ -647,7 +651,7 @@ def func(_input: dask_expr.Series) -> dask_expr.Series: return self._from_call(func, "is_duplicated", returns_scalar=self._returns_scalar) def is_unique(self: Self) -> Self: - def func(_input: dask_expr.Series) -> dask_expr.Series: + def func(_input: dx.Series) -> dx.Series: _name = _input.name return ( _input.to_frame() @@ -967,7 +971,7 @@ def replace_time_zone(self, time_zone: str | None) -> DaskExpr: ) def convert_time_zone(self, time_zone: str) -> DaskExpr: - def func(s: dask_expr.Series, time_zone: str) -> dask_expr.Series: + def func(s: dx.Series, time_zone: str) -> dx.Series: dtype = native_to_narwhals_dtype( s, self._compliant_expr._version, Implementation.DASK ) @@ -984,9 +988,7 @@ def func(s: dask_expr.Series, time_zone: str) -> dask_expr.Series: ) def timestamp(self, time_unit: Literal["ns", "us", "ms"] = "us") -> DaskExpr: - def func( - s: dask_expr.Series, time_unit: Literal["ns", "us", "ms"] = "us" - ) -> dask_expr.Series: + def func(s: dx.Series, time_unit: Literal["ns", "us", "ms"] = "us") -> dx.Series: dtype = native_to_narwhals_dtype( s, self._compliant_expr._version, Implementation.DASK ) diff --git a/narwhals/_dask/group_by.py b/narwhals/_dask/group_by.py index 243b21b71..60086efa2 100644 --- a/narwhals/_dask/group_by.py +++ b/narwhals/_dask/group_by.py @@ -12,7 +12,12 @@ if TYPE_CHECKING: import dask.dataframe as dd - import dask_expr + + try: + import dask.dataframe.dask_expr as dx + except ModuleNotFoundError: + import dask_expr as dx + import pandas as pd from narwhals._dask.dataframe import DaskLazyFrame @@ -43,7 +48,10 @@ def var( ]: from functools import partial - import dask_expr as dx + try: + import dask.dataframe.dask_expr as dx + except ModuleNotFoundError: + import dask_expr as dx return partial(dx._groupby.GroupBy.var, ddof=ddof) @@ -55,7 +63,10 @@ def std( ]: from functools import partial - import dask_expr as dx + try: + import dask.dataframe.dask_expr as dx + except ModuleNotFoundError: + import dask_expr as dx return partial(dx._groupby.GroupBy.std, ddof=ddof) @@ -127,7 +138,7 @@ def _from_native_frame(self, df: DaskLazyFrame) -> DaskLazyFrame: def agg_dask( df: DaskLazyFrame, grouped: Any, - exprs: Sequence[CompliantExpr[dask_expr.Series]], + exprs: Sequence[CompliantExpr[dx.Series]], keys: list[str], from_dataframe: Callable[[Any], DaskLazyFrame], ) -> DaskLazyFrame: diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index 9a16d7f13..d8b2b7a9a 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -21,14 +21,17 @@ from narwhals.typing import CompliantNamespace if TYPE_CHECKING: - import dask_expr + try: + import dask.dataframe.dask_expr as dx + except ModuleNotFoundError: + import dask_expr as dx from narwhals._dask.typing import IntoDaskExpr from narwhals.dtypes import DType from narwhals.utils import Version -class DaskNamespace(CompliantNamespace["dask_expr.Series"]): +class DaskNamespace(CompliantNamespace["dx.Series"]): @property def selectors(self) -> DaskSelectorNamespace: return DaskSelectorNamespace( @@ -40,7 +43,7 @@ def __init__(self, *, backend_version: tuple[int, ...], version: Version) -> Non self._version = version def all(self) -> DaskExpr: - def func(df: DaskLazyFrame) -> list[dask_expr.Series]: + def func(df: DaskLazyFrame) -> list[dx.Series]: return [df._native_frame[column_name] for column_name in df.columns] return DaskExpr( @@ -69,7 +72,7 @@ def lit(self, value: Any, dtype: DType | None) -> DaskExpr: import dask.dataframe as dd import pandas as pd - def func(df: DaskLazyFrame) -> list[dask_expr.Series]: + def func(df: DaskLazyFrame) -> list[dx.Series]: return [ dd.from_pandas( pd.Series( @@ -99,7 +102,7 @@ def len(self) -> DaskExpr: import dask.dataframe as dd import pandas as pd - def func(df: DaskLazyFrame) -> list[dask_expr.Series]: + def func(df: DaskLazyFrame) -> list[dx.Series]: if not df.columns: return [ dd.from_pandas( @@ -125,7 +128,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: def all_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr: parsed_exprs = parse_into_exprs(*exprs, namespace=self) - def func(df: DaskLazyFrame) -> list[dask_expr.Series]: + def func(df: DaskLazyFrame) -> list[dx.Series]: series = [s for _expr in parsed_exprs for s in _expr(df)] return [reduce(lambda x, y: x & y, series).rename(series[0].name)] @@ -144,7 +147,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: def any_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr: parsed_exprs = parse_into_exprs(*exprs, namespace=self) - def func(df: DaskLazyFrame) -> list[dask_expr.Series]: + def func(df: DaskLazyFrame) -> list[dx.Series]: series = [s for _expr in parsed_exprs for s in _expr(df)] return [reduce(lambda x, y: x | y, series).rename(series[0].name)] @@ -163,7 +166,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: def sum_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr: parsed_exprs = parse_into_exprs(*exprs, namespace=self) - def func(df: DaskLazyFrame) -> list[dask_expr.Series]: + def func(df: DaskLazyFrame) -> list[dx.Series]: series = [s.fillna(0) for _expr in parsed_exprs for s in _expr(df)] return [reduce(lambda x, y: x + y, series).rename(series[0].name)] @@ -239,7 +242,7 @@ def concat( def mean_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr: parsed_exprs = parse_into_exprs(*exprs, namespace=self) - def func(df: DaskLazyFrame) -> list[dask_expr.Series]: + def func(df: DaskLazyFrame) -> list[dx.Series]: series = (s.fillna(0) for _expr in parsed_exprs for s in _expr(df)) non_na = (1 - s.isna() for _expr in parsed_exprs for s in _expr(df)) return [ @@ -266,7 +269,7 @@ def min_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr: parsed_exprs = parse_into_exprs(*exprs, namespace=self) - def func(df: DaskLazyFrame) -> list[dask_expr.Series]: + def func(df: DaskLazyFrame) -> list[dx.Series]: series = [s for _expr in parsed_exprs for s in _expr(df)] return [dd.concat(series, axis=1).min(axis=1).rename(series[0].name)] @@ -288,7 +291,7 @@ def max_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr: parsed_exprs = parse_into_exprs(*exprs, namespace=self) - def func(df: DaskLazyFrame) -> list[dask_expr.Series]: + def func(df: DaskLazyFrame) -> list[dx.Series]: series = [s for _expr in parsed_exprs for s in _expr(df)] return [dd.concat(series, axis=1).max(axis=1).rename(series[0].name)] @@ -327,7 +330,7 @@ def concat_str( *parse_into_exprs(*more_exprs, namespace=self), ] - def func(df: DaskLazyFrame) -> list[dask_expr.Series]: + def func(df: DaskLazyFrame) -> list[dx.Series]: series = (s.astype(str) for _expr in parsed_exprs for s in _expr(df)) null_mask = [s for _expr in parsed_exprs for s in _expr.is_null()(df)] @@ -389,12 +392,12 @@ def __init__( self._returns_scalar = returns_scalar self._version = version - def __call__(self, df: DaskLazyFrame) -> Sequence[dask_expr.Series]: + def __call__(self, df: DaskLazyFrame) -> Sequence[dx.Series]: from narwhals._expression_parsing import parse_into_expr plx = df.__narwhals_namespace__() condition = parse_into_expr(self._condition, namespace=plx)(df)[0] - condition = cast("dask_expr.Series", condition) + condition = cast("dx.Series", condition) try: value_series = parse_into_expr(self._then_value, namespace=plx)(df)[0] except TypeError: @@ -402,7 +405,7 @@ def __call__(self, df: DaskLazyFrame) -> Sequence[dask_expr.Series]: _df = condition.to_frame("a") _df["tmp"] = self._then_value value_series = _df["tmp"] - value_series = cast("dask_expr.Series", value_series) + value_series = cast("dx.Series", value_series) validate_comparand(condition, value_series) if self._otherwise_value is None: diff --git a/narwhals/_dask/selectors.py b/narwhals/_dask/selectors.py index 2891d84ff..703e24860 100644 --- a/narwhals/_dask/selectors.py +++ b/narwhals/_dask/selectors.py @@ -8,7 +8,10 @@ from narwhals.utils import import_dtypes_module if TYPE_CHECKING: - import dask_expr + try: + import dask.dataframe.dask_expr as dx + except ModuleNotFoundError: + import dask_expr as dx from typing_extensions import Self from narwhals._dask.dataframe import DaskLazyFrame @@ -135,7 +138,7 @@ def call(df: DaskLazyFrame) -> list[Any]: def __or__(self: Self, other: DaskSelector | Any) -> DaskSelector | Any: if isinstance(other, DaskSelector): - def call(df: DaskLazyFrame) -> list[dask_expr.Series]: + def call(df: DaskLazyFrame) -> list[dx.Series]: lhs = self._call(df) rhs = other._call(df) return [*(x for x in lhs if x.name not in {x.name for x in rhs}), *rhs] diff --git a/narwhals/_dask/utils.py b/narwhals/_dask/utils.py index 4f2952d0b..cd303d8ec 100644 --- a/narwhals/_dask/utils.py +++ b/narwhals/_dask/utils.py @@ -14,7 +14,11 @@ if TYPE_CHECKING: import dask.dataframe as dd - import dask_expr + + try: + import dask.dataframe.dask_expr as dx + except ModuleNotFoundError: + import dask_expr as dx from narwhals._dask.dataframe import DaskLazyFrame from narwhals._dask.expr import DaskExpr @@ -42,7 +46,7 @@ def maybe_evaluate(df: DaskLazyFrame, obj: Any) -> Any: def parse_exprs_and_named_exprs( df: DaskLazyFrame, *exprs: Any, **named_exprs: Any -) -> dict[str, dask_expr.Series]: +) -> dict[str, dx.Series]: results = {} for expr in exprs: if hasattr(expr, "__narwhals_expr__"): @@ -82,10 +86,13 @@ def add_row_index( ) -def validate_comparand(lhs: dask_expr.Series, rhs: dask_expr.Series) -> None: - import dask_expr +def validate_comparand(lhs: dx.Series, rhs: dx.Series) -> None: + try: + import dask.dataframe.dask_expr as dx + except ModuleNotFoundError: + import dask_expr as dx - if not dask_expr._expr.are_co_aligned(lhs._expr, rhs._expr): # pragma: no cover + if not dx._expr.are_co_aligned(lhs._expr, rhs._expr): # pragma: no cover # are_co_aligned is a method which cheaply checks if two Dask expressions # have the same index, and therefore don't require index alignment. # If someone only operates on a Dask DataFrame via expressions, then this @@ -154,11 +161,11 @@ def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> An raise AssertionError(msg) -def name_preserving_sum(s1: dask_expr.Series, s2: dask_expr.Series) -> dask_expr.Series: +def name_preserving_sum(s1: dx.Series, s2: dx.Series) -> dx.Series: return (s1 + s2).rename(s1.name) -def name_preserving_div(s1: dask_expr.Series, s2: dask_expr.Series) -> dask_expr.Series: +def name_preserving_div(s1: dx.Series, s2: dx.Series) -> dx.Series: return (s1 / s2).rename(s1.name) diff --git a/narwhals/translate.py b/narwhals/translate.py index 8d0805a26..9ad868016 100644 --- a/narwhals/translate.py +++ b/narwhals/translate.py @@ -684,7 +684,10 @@ def _from_native_impl( # noqa: PLR0915 msg = "Cannot only use `eager_only` or `eager_or_interchange_only` with dask DataFrame" raise TypeError(msg) return native_object - if get_dask_expr() is None: # pragma: no cover + if ( + parse_version(get_dask().__version__) <= (2024, 12, 1) + and get_dask_expr() is None + ): # pragma: no cover msg = "Please install dask-expr" raise ImportError(msg) return LazyFrame(