Skip to content

Commit

Permalink
feat: add when-then-otherwise expression (#588)
Browse files Browse the repository at this point in the history
  • Loading branch information
aivanoved authored Aug 24, 2024
1 parent c072b5d commit 83dd6e1
Show file tree
Hide file tree
Showing 8 changed files with 432 additions and 4 deletions.
1 change: 1 addition & 0 deletions docs/api-reference/narwhals.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ Here are the top-level functions available in Narwhals.
- new_series
- sum
- sum_horizontal
- when
- show_versions
- to_native
show_source: false
2 changes: 2 additions & 0 deletions narwhals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from narwhals.expr import min
from narwhals.expr import sum
from narwhals.expr import sum_horizontal
from narwhals.expr import when
from narwhals.functions import concat
from narwhals.functions import from_dict
from narwhals.functions import get_level
Expand Down Expand Up @@ -79,6 +80,7 @@
"mean_horizontal",
"sum",
"sum_horizontal",
"when",
"DataFrame",
"LazyFrame",
"Series",
Expand Down
119 changes: 119 additions & 0 deletions narwhals/_pandas_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any
from typing import Callable
from typing import Iterable
from typing import cast

from narwhals import dtypes
from narwhals._expression_parsing import parse_into_exprs
Expand Down Expand Up @@ -249,3 +250,121 @@ def concat(
backend_version=self._backend_version,
)
raise NotImplementedError

def when(
self,
*predicates: IntoPandasLikeExpr,
) -> PandasWhen:
plx = self.__class__(self._implementation, self._backend_version)
if predicates:
condition = plx.all_horizontal(*predicates)
else:
msg = "at least one predicate needs to be provided"
raise TypeError(msg)

return PandasWhen(condition, self._implementation, self._backend_version)


class PandasWhen:
def __init__(
self,
condition: PandasLikeExpr,
implementation: Implementation,
backend_version: tuple[int, ...],
then_value: Any = None,
otherwise_value: Any = None,
) -> None:
self._implementation = implementation
self._backend_version = backend_version
self._condition = condition
self._then_value = then_value
self._otherwise_value = otherwise_value

def __call__(self, df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
from narwhals._expression_parsing import parse_into_expr
from narwhals._pandas_like.namespace import PandasLikeNamespace
from narwhals._pandas_like.utils import validate_column_comparand

plx = PandasLikeNamespace(
implementation=self._implementation, backend_version=self._backend_version
)

condition = parse_into_expr(self._condition, namespace=plx)._call(df)[0] # type: ignore[arg-type]
try:
value_series = parse_into_expr(self._then_value, namespace=plx)._call(df)[0] # type: ignore[arg-type]
except TypeError:
# `self._otherwise_value` is a scalar and can't be converted to an expression
value_series = condition.__class__._from_iterable( # type: ignore[call-arg]
[self._then_value] * len(condition),
name="literal",
index=condition._native_series.index,
implementation=self._implementation,
backend_version=self._backend_version,
)
value_series = cast(PandasLikeSeries, value_series)

value_series_native = value_series._native_series
condition_native = validate_column_comparand(value_series_native.index, condition)

if self._otherwise_value is None:
return [
value_series._from_native_series(
value_series_native.where(condition_native)
)
]
try:
otherwise_series = parse_into_expr(
self._otherwise_value, namespace=plx
)._call(df)[0] # type: ignore[arg-type]
except TypeError:
# `self._otherwise_value` is a scalar and can't be converted to an expression
return [
value_series._from_native_series(
value_series_native.where(condition_native, self._otherwise_value)
)
]
else:
return [value_series.zip_with(condition, otherwise_series)]

def then(self, value: PandasLikeExpr | PandasLikeSeries | Any) -> PandasThen:
self._then_value = value

return PandasThen(
self,
depth=0,
function_name="whenthen",
root_names=None,
output_names=None,
implementation=self._implementation,
backend_version=self._backend_version,
)


class PandasThen(PandasLikeExpr):
def __init__(
self,
call: PandasWhen,
*,
depth: int,
function_name: str,
root_names: list[str] | None,
output_names: list[str] | None,
implementation: Implementation,
backend_version: tuple[int, ...],
) -> None:
self._implementation = implementation
self._backend_version = backend_version

self._call = call
self._depth = depth
self._function_name = function_name
self._root_names = root_names
self._output_names = output_names

def otherwise(self, value: PandasLikeExpr | PandasLikeSeries | Any) -> PandasLikeExpr:
# type ignore because we are setting the `_call` attribute to a
# callable object of type `PandasWhen`, base class has the attribute as
# only a `Callable`
self._call._otherwise_value = value # type: ignore[attr-defined]
self._function_name = "whenotherwise"
return self
71 changes: 71 additions & 0 deletions narwhals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3954,6 +3954,77 @@ def sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr:
)


class When:
def __init__(self, *predicates: IntoExpr | Iterable[IntoExpr]) -> None:
self._predicates = flatten([predicates])

def _extract_predicates(self, plx: Any) -> Any:
return [extract_compliant(plx, v) for v in self._predicates]

def then(self, value: Any) -> Then:
return Then(
lambda plx: plx.when(*self._extract_predicates(plx)).then(
extract_compliant(plx, value)
)
)


class Then(Expr):
def __init__(self, call: Callable[[Any], Any]) -> None:
self._call = call

def otherwise(self, value: Any) -> Expr:
return Expr(lambda plx: self._call(plx).otherwise(extract_compliant(plx, value)))


def when(*predicates: IntoExpr | Iterable[IntoExpr]) -> When:
"""
Start a `when-then-otherwise` expression.
Expression similar to an `if-else` statement in Python. Always initiated by a `pl.when(<condition>).then(<value if condition>)`., and optionally followed by chaining one or more `.when(<condition>).then(<value>)` statements.
Chained when-then operations should be read as Python `if, elif, ... elif` blocks, not as `if, if, ... if`, i.e. the first condition that evaluates to `True` will be picked.
If none of the conditions are `True`, an optional `.otherwise(<value if all statements are false>)` can be appended at the end. If not appended, and none of the conditions are `True`, `None` will be returned.
Arguments:
predicates: Condition(s) that must be met in order to apply the subsequent statement. Accepts one or more boolean expressions, which are implicitly combined with `&`. String input is parsed as a column name.
Examples:
>>> import pandas as pd
>>> import polars as pl
>>> import narwhals as nw
>>> df_pl = pl.DataFrame({"a": [1, 2, 3], "b": [5, 10, 15]})
>>> df_pd = pd.DataFrame({"a": [1, 2, 3], "b": [5, 10, 15]})
We define a dataframe-agnostic function:
>>> @nw.narwhalify
... def func(df_any):
... return df_any.with_columns(
... nw.when(nw.col("a") < 3).then(5).otherwise(6).alias("a_when")
... )
We can then pass either pandas or polars to `func`:
>>> func(df_pd)
a b a_when
0 1 5 5
1 2 10 5
2 3 15 6
>>> func(df_pl)
shape: (3, 3)
┌─────┬─────┬────────┐
│ a ┆ b ┆ a_when │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i32 │
╞═════╪═════╪════════╡
│ 1 ┆ 5 ┆ 5 │
│ 2 ┆ 10 ┆ 5 │
│ 3 ┆ 15 ┆ 6 │
└─────┴─────┴────────┘
"""
return When(*predicates)


def all_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr:
r"""
Compute the bitwise AND horizontally across columns.
Expand Down
72 changes: 71 additions & 1 deletion narwhals/stable/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
from narwhals.dtypes import UInt64
from narwhals.dtypes import Unknown
from narwhals.expr import Expr as NwExpr
from narwhals.expr import Then as NwThen
from narwhals.expr import When as NwWhen
from narwhals.expr import when as nw_when
from narwhals.functions import concat
from narwhals.functions import show_versions
from narwhals.schema import Schema as NwSchema
Expand Down Expand Up @@ -1404,7 +1407,7 @@ def maybe_align_index(lhs: T, rhs: Series | DataFrame[Any] | LazyFrame[Any]) ->

def maybe_convert_dtypes(df: T, *args: bool, **kwargs: bool | str) -> T:
"""
Convert columns to the best possible dtypes using dtypes supporting ``pd.NA``, if df is pandas-like.
Convert columns or series to the best possible dtypes using dtypes supporting ``pd.NA``, if df is pandas-like.
Arguments:
obj: DataFrame or Series.
Expand Down Expand Up @@ -1493,6 +1496,72 @@ def get_level(
return nw.get_level(obj)


class When(NwWhen):
@classmethod
def from_when(cls, when: NwWhen) -> Self:
return cls(*when._predicates)

def then(self, value: Any) -> Then:
return Then.from_then(super().then(value))


class Then(NwThen, Expr):
@classmethod
def from_then(cls, then: NwThen) -> Self:
return cls(then._call)

def otherwise(self, value: Any) -> Expr:
return _stableify(super().otherwise(value))


def when(*predicates: IntoExpr | Iterable[IntoExpr]) -> When:
"""
Start a `when-then-otherwise` expression.
Expression similar to an `if-else` statement in Python. Always initiated by a `pl.when(<condition>).then(<value if condition>)`., and optionally followed by chaining one or more `.when(<condition>).then(<value>)` statements.
Chained when-then operations should be read as Python `if, elif, ... elif` blocks, not as `if, if, ... if`, i.e. the first condition that evaluates to `True` will be picked.
If none of the conditions are `True`, an optional `.otherwise(<value if all statements are false>)` can be appended at the end. If not appended, and none of the conditions are `True`, `None` will be returned.
Arguments:
predicates: Condition(s) that must be met in order to apply the subsequent statement. Accepts one or more boolean expressions, which are implicitly combined with `&`. String input is parsed as a column name.
Examples:
>>> import pandas as pd
>>> import polars as pl
>>> import narwhals.stable.v1 as nw
>>> df_pl = pl.DataFrame({"a": [1, 2, 3], "b": [5, 10, 15]})
>>> df_pd = pd.DataFrame({"a": [1, 2, 3], "b": [5, 10, 15]})
We define a dataframe-agnostic function:
>>> @nw.narwhalify
... def func(df_any):
... return df_any.with_columns(
... nw.when(nw.col("a") < 3).then(5).otherwise(6).alias("a_when")
... )
We can then pass either pandas or polars to `func`:
>>> func(df_pd)
a b a_when
0 1 5 5
1 2 10 5
2 3 15 6
>>> func(df_pl)
shape: (3, 3)
┌─────┬─────┬────────┐
│ a ┆ b ┆ a_when │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i32 │
╞═════╪═════╪════════╡
│ 1 ┆ 5 ┆ 5 │
│ 2 ┆ 10 ┆ 5 │
│ 3 ┆ 15 ┆ 6 │
└─────┴─────┴────────┘
"""
return When.from_when(nw_when(*predicates))


def new_series(
name: str,
values: Any,
Expand Down Expand Up @@ -1624,6 +1693,7 @@ def from_dict(
"mean_horizontal",
"sum",
"sum_horizontal",
"when",
"DataFrame",
"LazyFrame",
"Series",
Expand Down
2 changes: 1 addition & 1 deletion narwhals/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def maybe_set_index(df: T, column_names: str | list[str]) -> T:

def maybe_convert_dtypes(obj: T, *args: bool, **kwargs: bool | str) -> T:
"""
Convert columns to the best possible dtypes using dtypes supporting ``pd.NA``, if df is pandas-like.
Convert columns or series to the best possible dtypes using dtypes supporting ``pd.NA``, if df is pandas-like.
Arguments:
obj: DataFrame or Series.
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@ exclude = [
]

[project.optional-dependencies]
cudf = ["cudf>=23.08.00; python_version >= '3.9'"]
cudf = ["cudf>=23.08.00"]
modin = ["modin"]
pandas = ["pandas>=0.25.3"]
polars = ["polars>=0.20.3"]
pyarrow = ["pyarrow>=11.0.0"]
dask = ["dask[dataframe]>=2024.7; python_version >= '3.9'"]
dask = ["dask[dataframe]>=2024.7"]

[project.urls]
"Homepage" = "https://github.com/narwhals-dev/narwhals"
Expand Down
Loading

0 comments on commit 83dd6e1

Please sign in to comment.