Skip to content

Commit

Permalink
Merge branch 'main' into add-test-formulatic
Browse files Browse the repository at this point in the history
  • Loading branch information
luke396 authored Jan 17, 2025
2 parents e3ac48a + dd9607e commit e31deff
Show file tree
Hide file tree
Showing 22 changed files with 378 additions and 335 deletions.
22 changes: 16 additions & 6 deletions narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,11 @@ def with_columns(
col_name = col_value.name

column = validate_dataframe_comparand(
length=length, other=col_value, backend_version=self._backend_version
length=length,
other=col_value,
backend_version=self._backend_version,
allow_broadcast=True,
method_name="with_columns",
)

native_frame = (
Expand Down Expand Up @@ -395,10 +399,9 @@ def drop(self: Self, columns: list[str], strict: bool) -> Self: # noqa: FBT001
)
return self._from_native_frame(self._native_frame.drop(to_drop))

def drop_nulls(self: Self, subset: str | list[str] | None) -> Self:
def drop_nulls(self: Self, subset: list[str] | None) -> Self:
if subset is None:
return self._from_native_frame(self._native_frame.drop_null())
subset = [subset] if isinstance(subset, str) else subset
plx = self.__narwhals_namespace__()
return self.filter(~plx.any_horizontal(plx.col(*subset).is_null()))

Expand Down Expand Up @@ -483,7 +486,7 @@ def filter(self: Self, *predicates: IntoArrowExpr, **constraints: Any) -> Self:
and all(isinstance(x, bool) for x in predicates[0])
and not constraints
):
mask = predicates[0]
mask_native = predicates[0]
else:
plx = self.__narwhals_namespace__()
expr = plx.all_horizontal(
Expand All @@ -492,8 +495,15 @@ def filter(self: Self, *predicates: IntoArrowExpr, **constraints: Any) -> Self:
)
)
# `[0]` is safe as all_horizontal's expression only returns a single column
mask = expr._call(self)[0]._native_series
return self._from_native_frame(self._native_frame.filter(mask))
mask = expr._call(self)[0]
mask_native = validate_dataframe_comparand(
length=len(self),
other=mask,
backend_version=self._backend_version,
allow_broadcast=False,
method_name="filter",
)
return self._from_native_frame(self._native_frame.filter(mask_native))

def null_count(self: Self) -> Self:
import pyarrow as pa
Expand Down
19 changes: 8 additions & 11 deletions narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,19 +435,18 @@ def __call__(self: Self, df: ArrowDataFrame) -> Sequence[ArrowSeries]:

plx = df.__narwhals_namespace__()
condition = parse_into_expr(self._condition, namespace=plx)(df)[0]

try:
value_series = parse_into_expr(self._then_value, namespace=plx)(df)[0]
except TypeError:
# `self._otherwise_value` is a scalar and can't be converted to an expression
value_series = condition.__class__._from_iterable(
pa.repeat(pa.scalar(self._then_value), len(condition)),
name="literal",
backend_version=self._backend_version,
version=self._version,
# `self._then_value` is a scalar and can't be converted to an expression
value_series = plx._create_series_from_scalar(
self._then_value, reference_series=condition
)

value_series_native = value_series._native_series
condition_native = condition._native_series
condition_native, value_series_native = broadcast_series(
[condition, value_series]
)

if self._otherwise_value is None:
otherwise_native = pa.repeat(
Expand All @@ -472,9 +471,7 @@ def __call__(self: Self, df: ArrowDataFrame) -> Sequence[ArrowSeries]:
]
else:
otherwise_series = otherwise_expr(df)[0]
condition_native, otherwise_native = broadcast_series(
[condition, otherwise_series]
)
_, otherwise_native = broadcast_series([condition, otherwise_series])
return [
value_series._from_native_series(
pc.if_else(condition_native, value_series_native, otherwise_native)
Expand Down
22 changes: 20 additions & 2 deletions narwhals/_arrow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Sequence
from typing import overload

from narwhals.exceptions import ShapeError
from narwhals.utils import import_dtypes_module
from narwhals.utils import isinstance_or_issubclass

Expand Down Expand Up @@ -212,7 +213,12 @@ def broadcast_and_extract_native(


def validate_dataframe_comparand(
length: int, other: Any, backend_version: tuple[int, ...]
length: int,
other: Any,
backend_version: tuple[int, ...],
*,
allow_broadcast: bool,
method_name: str,
) -> Any:
"""Validate RHS of binary operation.
Expand All @@ -222,14 +228,26 @@ def validate_dataframe_comparand(
from narwhals._arrow.series import ArrowSeries

if isinstance(other, ArrowSeries):
if len(other) == 1:
len_other = len(other)
if len_other == 1:
if length > 1 and not allow_broadcast:
msg = (
f"{method_name}'s length: 1 differs from that of the series: {length}"
)
raise ShapeError(msg)

import numpy as np # ignore-banned-import
import pyarrow as pa

value = other._native_series[0]
if backend_version < (13,) and hasattr(value, "as_py"):
value = value.as_py()
return pa.array(np.full(shape=length, fill_value=value))

if length != len_other:
msg = f"{method_name}'s length: {len_other} differs from that of the series: {length}"
raise ShapeError(msg)

return other._native_series

from narwhals._arrow.dataframe import ArrowDataFrame # pragma: no cover
Expand Down
6 changes: 3 additions & 3 deletions narwhals/_dask/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,14 @@ def collect(self) -> Any:
def columns(self) -> list[str]:
return self._native_frame.columns.tolist() # type: ignore[no-any-return]

def filter(self, *predicates: DaskExpr, **constraints: Any) -> Self:
def filter(self: Self, *predicates: DaskExpr, **constraints: Any) -> Self:
plx = self.__narwhals_namespace__()
expr = plx.all_horizontal(
*chain(predicates, (plx.col(name) == v for name, v in constraints.items()))
)
# `[0]` is safe as all_horizontal's expression only returns a single column
mask = expr._call(self)[0]

return self._from_native_frame(self._native_frame.loc[mask])

def select(
Expand Down Expand Up @@ -149,10 +150,9 @@ def select(
)
return self._from_native_frame(df)

def drop_nulls(self: Self, subset: str | list[str] | None) -> Self:
def drop_nulls(self: Self, subset: list[str] | None) -> Self:
if subset is None:
return self._from_native_frame(self._native_frame.dropna())
subset = [subset] if isinstance(subset, str) else subset
plx = self.__narwhals_namespace__()
return self.filter(~plx.any_horizontal(plx.col(*subset).is_null()))

Expand Down
17 changes: 14 additions & 3 deletions narwhals/_dask/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,13 +398,24 @@ def __call__(self, df: DaskLazyFrame) -> Sequence[dx.Series]:
plx = df.__narwhals_namespace__()
condition = parse_into_expr(self._condition, namespace=plx)(df)[0]
condition = cast("dx.Series", condition)

try:
value_series = parse_into_expr(self._then_value, namespace=plx)(df)[0]
then_expr = parse_into_expr(self._then_value, namespace=plx)
except TypeError:
# `self._otherwise_value` is a scalar and can't be converted to an expression
# `self._then_value` is a scalar and can't be converted to an expression
value_sequence: Sequence[Any] = [self._then_value]
is_scalar = True
else:
is_scalar = then_expr._returns_scalar # type: ignore[attr-defined]
value_sequence = then_expr(df)[0]

if is_scalar:
_df = condition.to_frame("a")
_df["tmp"] = self._then_value
_df["tmp"] = value_sequence[0]
value_series = _df["tmp"]
else:
value_series = value_sequence

value_series = cast("dx.Series", value_series)
validate_comparand(condition, value_series)

Expand Down
11 changes: 10 additions & 1 deletion narwhals/_duckdb/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def with_columns(
result.append(value.alias(col))
return self._from_native_frame(self._native_frame.select(*result))

def filter(self, *predicates: DuckDBExpr, **constraints: Any) -> Self:
def filter(self: Self, *predicates: DuckDBExpr, **constraints: Any) -> Self:
plx = self.__narwhals_namespace__()
expr = plx.all_horizontal(
*chain(predicates, (plx.col(name) == v for name, v in constraints.items()))
Expand Down Expand Up @@ -321,3 +321,12 @@ def sort(
)
)
return self._from_native_frame(result)

def drop_nulls(self: Self, subset: list[str] | None) -> Self:
    """Return a frame without the rows holding a null in the checked columns.

    Args:
        subset: Names of the columns to check for nulls. ``None`` checks
            every column reported by the native relation.

    Returns:
        Whatever ``self._from_native_frame`` builds from the filtered
        relation.
    """
    rel = self._native_frame
    subset_ = subset if subset is not None else rel.columns
    if not subset_:
        # No columns to inspect: an empty predicate is not valid SQL, and with
        # nothing checked no row can be rejected, so keep the relation as is.
        # NOTE(review): confirm this no-op matches the other backends.
        return self._from_native_frame(rel)

    # Double any embedded `"` so a column name containing a quote remains one
    # quoted identifier instead of terminating it early.
    quoted = ('"' + col.replace('"', '""') + '"' for col in subset_)
    keep_condition = " and ".join(f"{name} is not null" for name in quoted)

    # Filter the relation itself rather than running SQL text that refers to
    # the local variable `rel` by name; no string-built `select` is needed.
    return self._from_native_frame(rel.filter(keep_condition))
29 changes: 21 additions & 8 deletions narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,10 +371,9 @@ def select(
)
return self._from_native_frame(df)

def drop_nulls(self, subset: str | list[str] | None) -> Self:
def drop_nulls(self, subset: list[str] | None) -> Self:
if subset is None:
return self._from_native_frame(self._native_frame.dropna(axis=0))
subset = [subset] if isinstance(subset, str) else subset
plx = self.__narwhals_namespace__()
return self.filter(~plx.any_horizontal(plx.col(*subset).is_null()))

Expand All @@ -401,15 +400,15 @@ def with_row_index(self, name: str) -> Self:
def row(self, row: int) -> tuple[Any, ...]:
return tuple(x for x in self._native_frame.iloc[row])

def filter(self, *predicates: IntoPandasLikeExpr, **constraints: Any) -> Self:
def filter(self: Self, *predicates: IntoPandasLikeExpr, **constraints: Any) -> Self:
plx = self.__narwhals_namespace__()
if (
len(predicates) == 1
and isinstance(predicates[0], list)
and all(isinstance(x, bool) for x in predicates[0])
and not constraints
):
_mask = predicates[0]
mask_native = predicates[0]
else:
expr = plx.all_horizontal(
*chain(
Expand All @@ -418,8 +417,14 @@ def filter(self, *predicates: IntoPandasLikeExpr, **constraints: Any) -> Self:
)
# `[0]` is safe as all_horizontal's expression only returns a single column
mask = expr._call(self)[0]
_mask = validate_dataframe_comparand(self._native_frame.index, mask)
return self._from_native_frame(self._native_frame.loc[_mask])
mask_native = validate_dataframe_comparand(
self._native_frame.index,
mask,
allow_broadcast=False,
method_name="filter",
)

return self._from_native_frame(self._native_frame.loc[mask_native])

def with_columns(
self,
Expand All @@ -438,13 +443,21 @@ def with_columns(
if name in new_column_name_to_new_column_map:
to_concat.append(
validate_dataframe_comparand(
index, new_column_name_to_new_column_map.pop(name)
index,
new_column_name_to_new_column_map.pop(name),
allow_broadcast=True,
method_name="with_columns",
)
)
else:
to_concat.append(self._native_frame[name])
to_concat.extend(
validate_dataframe_comparand(index, new_column_name_to_new_column_map[s])
validate_dataframe_comparand(
index,
new_column_name_to_new_column_map[s],
allow_broadcast=True,
method_name="with_columns",
)
for s in new_column_name_to_new_column_map
)

Expand Down
26 changes: 15 additions & 11 deletions narwhals/_pandas_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,20 +461,17 @@ def __call__(self, df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]:

plx = df.__narwhals_namespace__()
condition = parse_into_expr(self._condition, namespace=plx)(df)[0]

try:
value_series = parse_into_expr(self._then_value, namespace=plx)(df)[0]
except TypeError:
# `self._otherwise_value` is a scalar and can't be converted to an expression
value_series = condition.__class__._from_iterable(
[self._then_value] * len(condition),
name="literal",
index=condition._native_series.index,
implementation=self._implementation,
backend_version=self._backend_version,
version=self._version,
# `self._then_value` is a scalar and can't be converted to an expression
value_series = plx._create_series_from_scalar(
self._then_value, reference_series=condition
)
value_series_native, condition_native = broadcast_align_and_extract_native(
value_series, condition

condition_native, value_series_native = broadcast_align_and_extract_native(
condition, value_series
)

if self._otherwise_value is None:
Expand All @@ -494,7 +491,14 @@ def __call__(self, df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]:
]
else:
otherwise_series = otherwise_expr(df)[0]
return [value_series.zip_with(condition, otherwise_series)]
_, otherwise_native = broadcast_align_and_extract_native(
condition, otherwise_series
)
return [
value_series._from_native_series(
value_series_native.where(condition_native, otherwise_native)
)
]

def then(self, value: PandasLikeExpr | PandasLikeSeries | Any) -> PandasThen:
self._then_value = value
Expand Down
24 changes: 22 additions & 2 deletions narwhals/_pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
native_to_narwhals_dtype as arrow_native_to_narwhals_dtype,
)
from narwhals.exceptions import ColumnNotFoundError
from narwhals.exceptions import ShapeError
from narwhals.utils import Implementation
from narwhals.utils import import_dtypes_module
from narwhals.utils import isinstance_or_issubclass
Expand Down Expand Up @@ -152,7 +153,9 @@ def broadcast_align_and_extract_native(
return lhs._native_series, rhs


def validate_dataframe_comparand(index: Any, other: Any) -> Any:
def validate_dataframe_comparand(
index: Any, other: Any, *, allow_broadcast: bool, method_name: str
) -> Any:
"""Validate RHS of binary operation.
If the comparison isn't supported, return `NotImplemented` so that the
Expand All @@ -164,10 +167,27 @@ def validate_dataframe_comparand(index: Any, other: Any) -> Any:
if isinstance(other, PandasLikeDataFrame):
return NotImplemented
if isinstance(other, PandasLikeSeries):
if other.len() == 1:
len_index = len(index)
len_other = other.len()

if len_other == 1:
if len_index > 1 and not allow_broadcast:
msg = (
f"{method_name}'s length: 1 differs from that of the series: "
f"{len_index}"
)
raise ShapeError(msg)
# broadcast
s = other._native_series
return s.__class__(s.iloc[0], index=index, dtype=s.dtype, name=s.name)

if len_index != len_other:
msg = (
f"{method_name}'s length: {len_other} differs from that of the series: "
f"{len_index}"
)
raise ShapeError(msg)

if other._native_series.index is not index:
return set_index(
other._native_series,
Expand Down
4 changes: 2 additions & 2 deletions narwhals/_spark_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def select(
new_columns_list = [col.alias(col_name) for col_name, col in new_columns.items()]
return self._from_native_frame(self._native_frame.select(*new_columns_list))

def filter(self, *predicates: SparkLikeExpr, **constraints: Any) -> Self:
def filter(self: Self, *predicates: SparkLikeExpr, **constraints: Any) -> Self:
plx = self.__narwhals_namespace__()
expr = plx.all_horizontal(
*chain(predicates, (plx.col(name) == v for name, v in constraints.items()))
Expand Down Expand Up @@ -183,7 +183,7 @@ def sort(
sort_cols = [sort_f(col) for col, sort_f in zip(flat_by, sort_funcs)]
return self._from_native_frame(self._native_frame.sort(*sort_cols))

def drop_nulls(self: Self, subset: str | list[str] | None) -> Self:
def drop_nulls(self: Self, subset: list[str] | None) -> Self:
return self._from_native_frame(self._native_frame.dropna(subset=subset))

def rename(self: Self, mapping: dict[str, str]) -> Self:
Expand Down
Loading

0 comments on commit e31deff

Please sign in to comment.