Skip to content

Commit

Permalink
simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi committed Jan 16, 2025
1 parent 64037f3 commit 0283df1
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 30 deletions.
3 changes: 1 addition & 2 deletions narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,8 +439,7 @@ def __call__(self: Self, df: ArrowDataFrame) -> Sequence[ArrowSeries]:
try:
value_series = parse_into_expr(self._then_value, namespace=plx)(df)[0]
except TypeError:
# `self._otherwise_value` is a scalar and can't be converted to an expression
plx = condition.__narwhals_namespace__()
# `self._then_value` is a scalar and can't be converted to an expression
value_series = plx._create_series_from_scalar(
self._then_value, reference_series=condition
)
Expand Down
15 changes: 9 additions & 6 deletions narwhals/_dask/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,16 +402,19 @@ def __call__(self, df: DaskLazyFrame) -> Sequence[dx.Series]:
try:
then_expr = parse_into_expr(self._then_value, namespace=plx)
value_series = then_expr(df)[0]

# literal or reduction case
if then_expr._returns_scalar: # type: ignore[attr-defined]
_df = condition.to_frame("a")
_df["tmp"] = value_series[0]
value_series = _df["tmp"]
is_scalar = then_expr._returns_scalar # type: ignore[attr-defined]
except TypeError:
# `self._otherwise_value` is a scalar and can't be converted to an expression
# `self._then_value` is a scalar and can't be converted to an expression
value_series = [self._then_value]
is_scalar = True

if is_scalar:
_df = condition.to_frame("a")
_df["tmp"] = self._then_value
_df["tmp"] = value_series[0]
value_series = _df["tmp"]

value_series = cast("dx.Series", value_series)
validate_comparand(condition, value_series)

Expand Down
40 changes: 18 additions & 22 deletions narwhals/_pandas_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,29 +461,17 @@ def __call__(self, df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]:

plx = df.__narwhals_namespace__()
condition = parse_into_expr(self._condition, namespace=plx)(df)[0]

try:
value_series = parse_into_expr(self._then_value, namespace=plx)(df)[0]
if len(value_series) == 1: # literal or reduction case
value_series = condition.__class__._from_iterable(
[value_series[0]] * len(condition),
name="literal",
index=condition._native_series.index,
implementation=self._implementation,
backend_version=self._backend_version,
version=self._version,
)
except TypeError:
# `self._otherwise_value` is a scalar and can't be converted to an expression
value_series = condition.__class__._from_iterable(
[self._then_value] * len(condition),
name="literal",
index=condition._native_series.index,
implementation=self._implementation,
backend_version=self._backend_version,
version=self._version,
# `self._then_value` is a scalar and can't be converted to an expression
value_series = plx._create_series_from_scalar(
self._then_value, reference_series=condition
)
value_series_native, condition_native = broadcast_align_and_extract_native(
value_series, condition

condition_native, value_series_native = broadcast_align_and_extract_native(
condition, value_series
)

if self._otherwise_value is None:
Expand All @@ -493,7 +481,9 @@ def __call__(self, df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]:
)
]
try:
otherwise_expr = parse_into_expr(self._otherwise_value, namespace=plx)
otherwise_series = parse_into_expr(self._otherwise_value, namespace=plx)(df)[
0
]
except TypeError:
# `self._otherwise_value` is a scalar and can't be converted to an expression
return [
Expand All @@ -502,8 +492,14 @@ def __call__(self, df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]:
)
]
else:
otherwise_series = otherwise_expr(df)[0]
return [value_series.zip_with(condition, otherwise_series)]
_, otherwise_native = broadcast_align_and_extract_native(
condition, otherwise_series
)
return [
value_series._from_native_series(
value_series_native.where(condition_native, otherwise_native)
)
]

def then(self, value: PandasLikeExpr | PandasLikeSeries | Any) -> PandasThen:
self._then_value = value
Expand Down

0 comments on commit 0283df1

Please sign in to comment.