Skip to content

Commit

Permalink
reduce diff
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Jan 14, 2025
1 parent a7bd1c8 commit c2ed440
Show file tree
Hide file tree
Showing 10 changed files with 24 additions and 25 deletions.
2 changes: 1 addition & 1 deletion narwhals/_expression_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def extract_compliant(


def operation_is_order_dependent(*args: IntoExpr | Any) -> bool:
# If any arg is an Expr, we look at `_is_order_dependent`. If it isn't,
# If an arg is an Expr, we look at `_is_order_dependent`. If it isn't,
# it means that it was a scalar (e.g. nw.col('a') + 1) or a column name,
# neither of which is order-dependent, so we default to `False`.
return any(getattr(x, "_is_order_dependent", False) for x in args)
8 changes: 4 additions & 4 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,16 @@ def _from_compliant_dataframe(self, df: Any) -> Self:
level=self._level,
)

@abstractmethod
def _extract_compliant(self, arg: Any) -> Any:
raise NotImplementedError

def _flatten_and_extract(self, *args: Any, **kwargs: Any) -> Any:
"""Process `args` and `kwargs`, extracting underlying objects as we go."""
args = [self._extract_compliant(v) for v in flatten(args)] # type: ignore[assignment]
kwargs = {k: self._extract_compliant(v) for k, v in kwargs.items()}
return args, kwargs

@abstractmethod
def _extract_compliant(self, arg: Any) -> Any:
raise NotImplementedError

@property
def schema(self) -> Schema:
return Schema(self._compliant_frame.schema.items())
Expand Down
14 changes: 5 additions & 9 deletions narwhals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,6 @@ def __init__(
) -> None:
# callable from CompliantNamespace to CompliantExpr
self._to_compliant_expr = to_compliant_expr

# For binary operations, need to do "or".
# For transforms, preserve.
# For aggs, preserve.
self._is_order_dependent = is_order_dependent

def _taxicab_norm(self) -> Self:
Expand Down Expand Up @@ -1236,8 +1232,7 @@ def arg_min(self) -> Self:
b_arg_min: [[1]]
"""
return self.__class__(
lambda plx: self._to_compliant_expr(plx).arg_min(),
is_order_dependent=self._is_order_dependent,
lambda plx: self._to_compliant_expr(plx).arg_min(), is_order_dependent=True
)

def arg_max(self) -> Self:
Expand Down Expand Up @@ -2524,7 +2519,8 @@ def drop_nulls(self) -> Self:
a: [[2,4,3,5]]
"""
return self.__class__(
lambda plx: self._to_compliant_expr(plx).drop_nulls(), is_order_dependent=True
lambda plx: self._to_compliant_expr(plx).drop_nulls(),
self._is_order_dependent,
)

def sample(
Expand Down Expand Up @@ -2598,7 +2594,7 @@ def sample(
lambda plx: self._to_compliant_expr(plx).sample(
n, fraction=fraction, with_replacement=with_replacement, seed=seed
),
is_order_dependent=True,
self._is_order_dependent,
)

def over(self, *keys: str | Iterable[str]) -> Self:
Expand Down Expand Up @@ -3567,7 +3563,7 @@ def mode(self: Self) -> Self:
a: [[1]]
"""
return self.__class__(
lambda plx: self._to_compliant_expr(plx).mode(), is_order_dependent=False
lambda plx: self._to_compliant_expr(plx).mode(), self._is_order_dependent
)

def is_finite(self: Self) -> Self:
Expand Down
1 change: 0 additions & 1 deletion tests/expr_and_series/convert_time_zone_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from tests.utils import POLARS_VERSION
from tests.utils import PYARROW_VERSION
from tests.utils import Constructor
from tests.utils import ConstructorEager
from tests.utils import assert_equal_data
from tests.utils import is_windows

Expand Down
4 changes: 1 addition & 3 deletions tests/expr_and_series/replace_strict_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@
)
@pytest.mark.parametrize("return_dtype", [nw.String(), None])
def test_replace_strict(
constructor: Constructor,
request: pytest.FixtureRequest,
return_dtype: DType | None,
constructor: Constructor, request: pytest.FixtureRequest, return_dtype: DType | None
) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)
Expand Down
1 change: 0 additions & 1 deletion tests/expr_and_series/replace_time_zone_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from tests.utils import PANDAS_VERSION
from tests.utils import PYARROW_VERSION
from tests.utils import Constructor
from tests.utils import ConstructorEager
from tests.utils import assert_equal_data
from tests.utils import is_windows

Expand Down
2 changes: 1 addition & 1 deletion tests/expr_and_series/str/to_datetime_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@

import narwhals.stable.v1 as nw
from narwhals._arrow.utils import parse_datetime_format
from tests.utils import Constructor
from tests.utils import assert_equal_data

if TYPE_CHECKING:
from tests.utils import Constructor
from tests.utils import ConstructorEager

data = {"a": ["2020-01-01T12:34:56"]}
Expand Down
5 changes: 4 additions & 1 deletion tests/frame/drop_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from contextlib import nullcontext as does_not_raise
from typing import TYPE_CHECKING
from typing import Any

import pytest
Expand All @@ -9,7 +10,9 @@
import narwhals.stable.v1 as nw
from narwhals.exceptions import ColumnNotFoundError
from tests.utils import POLARS_VERSION
from tests.utils import Constructor

if TYPE_CHECKING:
from tests.utils import Constructor


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion tests/frame/schema_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@

import narwhals.stable.v1 as nw
from tests.utils import PANDAS_VERSION
from tests.utils import Constructor

if TYPE_CHECKING:
from tests.utils import Constructor
from tests.utils import ConstructorEager


Expand Down
10 changes: 7 additions & 3 deletions utils/check_api_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,9 @@
for namespace in NAMESPACES:
expr_methods = [
i
for i in getattr(nw.Expr(lambda: 0, is_order_dependent=True), namespace).__dir__()
for i in getattr(
nw.Expr(lambda: 0, is_order_dependent=False), namespace
).__dir__()
if not i[0].isupper() and i[0] != "_"
]
with open(f"docs/api-reference/expr_{namespace}.md") as fd:
Expand Down Expand Up @@ -226,7 +228,7 @@
# Check Expr vs Series
expr = [
i
for i in nw.Expr(lambda: 0, is_order_dependent=True).__dir__()
for i in nw.Expr(lambda: 0, is_order_dependent=False).__dir__()
if not i[0].isupper() and i[0] != "_"
]
series = [
Expand All @@ -247,7 +249,9 @@
for namespace in NAMESPACES.difference({"name"}):
expr_internal = [
i
for i in getattr(nw.Expr(lambda: 0, is_order_dependent=True), namespace).__dir__()
for i in getattr(
nw.Expr(lambda: 0, is_order_dependent=False), namespace
).__dir__()
if not i[0].isupper() and i[0] != "_"
]
series_internal = [
Expand Down

0 comments on commit c2ed440

Please sign in to comment.