diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 3ce3c1c89..dc52da002 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -338,7 +338,7 @@ def extract_compliant( def operation_is_order_dependent(*args: IntoExpr | Any) -> bool: - # If any arg is an Expr, we look at `_is_order_dependent`. If it isn't, + # If an arg is an Expr, we look at `_is_order_dependent`. If it isn't, # it means that it was a scalar (e.g. nw.col('a') + 1) or a column name, # neither of which is order-dependent, so we default to `False`. return any(getattr(x, "_is_order_dependent", False) for x in args) diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 7fef829fa..99e037c43 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -65,16 +65,16 @@ def _from_compliant_dataframe(self, df: Any) -> Self: level=self._level, ) - @abstractmethod - def _extract_compliant(self, arg: Any) -> Any: - raise NotImplementedError - def _flatten_and_extract(self, *args: Any, **kwargs: Any) -> Any: """Process `args` and `kwargs`, extracting underlying objects as we go.""" args = [self._extract_compliant(v) for v in flatten(args)] # type: ignore[assignment] kwargs = {k: self._extract_compliant(v) for k, v in kwargs.items()} return args, kwargs + @abstractmethod + def _extract_compliant(self, arg: Any) -> Any: + raise NotImplementedError + @property def schema(self) -> Schema: return Schema(self._compliant_frame.schema.items()) diff --git a/narwhals/expr.py b/narwhals/expr.py index 48b0fb4b2..a2831a880 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -36,10 +36,6 @@ def __init__( ) -> None: # callable from CompliantNamespace to CompliantExpr self._to_compliant_expr = to_compliant_expr - - # For binary operations, need to do "or". - # For transforms, preserve. - # For aggs, preserve. self._is_order_dependent = is_order_dependent def _taxicab_norm(self) -> Self: @@ -1236,8 +1232,7 @@ def arg_min(self) -> Self: b_arg_min: [[1]] """ return self.__class__( - lambda plx: self._to_compliant_expr(plx).arg_min(), - is_order_dependent=self._is_order_dependent, + lambda plx: self._to_compliant_expr(plx).arg_min(), is_order_dependent=True ) def arg_max(self) -> Self: @@ -2524,7 +2519,8 @@ def drop_nulls(self) -> Self: a: [[2,4,3,5]] """ return self.__class__( - lambda plx: self._to_compliant_expr(plx).drop_nulls(), is_order_dependent=True + lambda plx: self._to_compliant_expr(plx).drop_nulls(), + self._is_order_dependent, ) def sample( @@ -2598,7 +2594,7 @@ def sample( lambda plx: self._to_compliant_expr(plx).sample( n, fraction=fraction, with_replacement=with_replacement, seed=seed ), - is_order_dependent=True, + self._is_order_dependent, ) def over(self, *keys: str | Iterable[str]) -> Self: @@ -3567,7 +3563,7 @@ def mode(self: Self) -> Self: a: [[1]] """ return self.__class__( - lambda plx: self._to_compliant_expr(plx).mode(), is_order_dependent=False + lambda plx: self._to_compliant_expr(plx).mode(), self._is_order_dependent ) def is_finite(self: Self) -> Self: diff --git a/tests/expr_and_series/convert_time_zone_test.py b/tests/expr_and_series/convert_time_zone_test.py index a0bce5a40..9a18ee07f 100644 --- a/tests/expr_and_series/convert_time_zone_test.py +++ b/tests/expr_and_series/convert_time_zone_test.py @@ -11,7 +11,6 @@ from tests.utils import POLARS_VERSION from tests.utils import PYARROW_VERSION from tests.utils import Constructor -from tests.utils import ConstructorEager from tests.utils import assert_equal_data from tests.utils import is_windows diff --git a/tests/expr_and_series/replace_strict_test.py b/tests/expr_and_series/replace_strict_test.py index c033ad0ce..33c56bae6 100644 --- a/tests/expr_and_series/replace_strict_test.py +++ b/tests/expr_and_series/replace_strict_test.py @@ -19,9 +19,7 @@ ) @pytest.mark.parametrize("return_dtype", [nw.String(), None]) def test_replace_strict( - constructor: Constructor, - request: pytest.FixtureRequest, - return_dtype: DType | None, + constructor: Constructor, request: pytest.FixtureRequest, return_dtype: DType | None ) -> None: if "dask" in str(constructor): request.applymarker(pytest.mark.xfail) diff --git a/tests/expr_and_series/replace_time_zone_test.py b/tests/expr_and_series/replace_time_zone_test.py index 8549af69b..6876c318a 100644 --- a/tests/expr_and_series/replace_time_zone_test.py +++ b/tests/expr_and_series/replace_time_zone_test.py @@ -10,7 +10,6 @@ from tests.utils import PANDAS_VERSION from tests.utils import PYARROW_VERSION from tests.utils import Constructor -from tests.utils import ConstructorEager from tests.utils import assert_equal_data from tests.utils import is_windows diff --git a/tests/expr_and_series/str/to_datetime_test.py b/tests/expr_and_series/str/to_datetime_test.py index e008bfafc..bfb2a4dfb 100644 --- a/tests/expr_and_series/str/to_datetime_test.py +++ b/tests/expr_and_series/str/to_datetime_test.py @@ -8,10 +8,10 @@ import narwhals.stable.v1 as nw from narwhals._arrow.utils import parse_datetime_format -from tests.utils import Constructor from tests.utils import assert_equal_data if TYPE_CHECKING: + from tests.utils import Constructor from tests.utils import ConstructorEager data = {"a": ["2020-01-01T12:34:56"]} diff --git a/tests/frame/drop_test.py b/tests/frame/drop_test.py index ac3f0c2ce..eb9bb2660 100644 --- a/tests/frame/drop_test.py +++ b/tests/frame/drop_test.py @@ -1,6 +1,7 @@ from __future__ import annotations from contextlib import nullcontext as does_not_raise +from typing import TYPE_CHECKING from typing import Any import pytest @@ -9,7 +10,9 @@ import narwhals.stable.v1 as nw from narwhals.exceptions import ColumnNotFoundError from tests.utils import POLARS_VERSION -from tests.utils import Constructor + +if TYPE_CHECKING: + from tests.utils import Constructor @pytest.mark.parametrize( diff --git a/tests/frame/schema_test.py b/tests/frame/schema_test.py index 30735db1e..565bf0159 100644 --- a/tests/frame/schema_test.py +++ b/tests/frame/schema_test.py @@ -13,9 +13,9 @@ import narwhals.stable.v1 as nw from tests.utils import PANDAS_VERSION -from tests.utils import Constructor if TYPE_CHECKING: + from tests.utils import Constructor from tests.utils import ConstructorEager diff --git a/utils/check_api_reference.py b/utils/check_api_reference.py index a244ca8e0..8df16e88b 100644 --- a/utils/check_api_reference.py +++ b/utils/check_api_reference.py @@ -184,7 +184,9 @@ for namespace in NAMESPACES: expr_methods = [ i - for i in getattr(nw.Expr(lambda: 0, is_order_dependent=True), namespace).__dir__() + for i in getattr( + nw.Expr(lambda: 0, is_order_dependent=False), namespace + ).__dir__() if not i[0].isupper() and i[0] != "_" ] with open(f"docs/api-reference/expr_{namespace}.md") as fd: @@ -226,7 +228,7 @@ # Check Expr vs Series expr = [ i - for i in nw.Expr(lambda: 0, is_order_dependent=True).__dir__() + for i in nw.Expr(lambda: 0, is_order_dependent=False).__dir__() if not i[0].isupper() and i[0] != "_" ] series = [ @@ -247,7 +249,9 @@ for namespace in NAMESPACES.difference({"name"}): expr_internal = [ i - for i in getattr(nw.Expr(lambda: 0, is_order_dependent=True), namespace).__dir__() + for i in getattr( + nw.Expr(lambda: 0, is_order_dependent=False), namespace + ).__dir__() if not i[0].isupper() and i[0] != "_" ] series_internal = [