feat: add when-then-otherwise expression #588

Merged
merged 92 commits into from
Aug 24, 2024
92 commits
d7446f4
add simple when
aivanoved Jul 17, 2024
6ebc78b
delete unnecessary file
aivanoved Jul 17, 2024
a3fdcc5
lint with ruff
aivanoved Jul 17, 2024
1ad1c94
use lambda expression
aivanoved Jul 17, 2024
55f394f
Merge branch 'main' into add-where-expression
aivanoved Jul 18, 2024
93e7121
remove deleted file
aivanoved Jul 18, 2024
f3770b7
Fix errors from the migration
aivanoved Jul 22, 2024
cf92f80
Merge branch 'main' into add-where-expression
aivanoved Jul 22, 2024
a7f442a
remove unnecessary changes
aivanoved Jul 22, 2024
7f23f05
add back the change in version
aivanoved Jul 22, 2024
7cc3aad
fix rename change
aivanoved Jul 22, 2024
ab85e40
rename test file
aivanoved Jul 22, 2024
4a8ac56
fix forgotten memeber change
aivanoved Jul 22, 2024
8283f24
make api identical
aivanoved Jul 22, 2024
f1c667e
remove unnecessary diff
aivanoved Jul 22, 2024
74937ea
add when documentation
aivanoved Jul 23, 2024
5b030d6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 23, 2024
7390e1a
address mypy issues
aivanoved Jul 23, 2024
279e3ad
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 23, 2024
63048ee
address ruff type-ignore blanket issue
aivanoved Jul 23, 2024
e96af89
support `Iterable[Expr]` in the pandas api
aivanoved Jul 23, 2024
d4f0e9c
move when test file to a better location
aivanoved Jul 23, 2024
99d9899
make when test filename similar to other tests
aivanoved Jul 23, 2024
71e542d
add simple when
aivanoved Jul 17, 2024
8b1355a
lint with ruff
aivanoved Jul 17, 2024
eb36164
use lambda expression
aivanoved Jul 17, 2024
c9b09bf
Fix errors from the migration
aivanoved Jul 22, 2024
2b1eabc
remove unnecessary changes
aivanoved Jul 22, 2024
2ef564e
remove unnecessary diff
aivanoved Jul 22, 2024
0e4773d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 23, 2024
151fe14
fix rebase error
aivanoved Jul 23, 2024
add7b89
remove files left from wrong rebase
aivanoved Jul 23, 2024
504c4ea
Merge remote-tracking branch 'upstream/main' into add-where-expression
aivanoved Jul 23, 2024
fd21c78
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 23, 2024
2f03bd0
chore: remove all wrong rebase leftover code
aivanoved Jul 23, 2024
37cc634
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 25, 2024
0454ac4
misc: keep api the same
aivanoved Jul 25, 2024
4ad28b7
test: add test for multiple predicates
aivanoved Jul 25, 2024
0ded393
misc: make when stable
aivanoved Jul 29, 2024
3280a3c
bug: make stable v1 `Then` a stable expr `Expr`
aivanoved Jul 29, 2024
5c6deed
bug: fix when constraints pandas implementation
aivanoved Jul 29, 2024
27d17d9
Merge remote-tracking branch 'upstream/main' into add-where-expression
aivanoved Jul 29, 2024
8688491
test: stabalise all paths and test error on no arg
aivanoved Jul 29, 2024
81039bf
misc: add when to main api
aivanoved Jul 29, 2024
1196fab
misc: remove constraints
aivanoved Jul 30, 2024
beba175
docs: remove wrong import in stable
aivanoved Jul 30, 2024
606feed
Merge remote-tracking branch 'upstream/main' into add-where-expression
aivanoved Jul 30, 2024
45684d4
docs: remove wrong import in main docstring
aivanoved Jul 30, 2024
21e3384
feat: add series and expression support for when then
aivanoved Jul 30, 2024
e2e0b92
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 30, 2024
ee2d9ba
misc: simplify condition
aivanoved Jul 30, 2024
e7f31bf
test: skip lazy test on series when
aivanoved Jul 30, 2024
c79a24c
misc: fix ruff check
aivanoved Jul 30, 2024
f1db4c4
misc: remove dead code
aivanoved Jul 30, 2024
06a3a97
misc: remove unused code
aivanoved Jul 31, 2024
bba70ea
misc: use available impl/backend version in object
aivanoved Jul 31, 2024
b02906d
Merge remote-tracking branch 'upstream/main' into add-where-expression
aivanoved Jul 31, 2024
561856f
Merge remote-tracking branch 'upstream/main' into add-where-expression
aivanoved Aug 1, 2024
e21243a
test: fix python 3.8 failing test
aivanoved Aug 6, 2024
ada9588
test: ignore dask
aivanoved Aug 7, 2024
ead6649
Update namespace.py
aivanoved Aug 8, 2024
bc129ee
simplify a bit
MarcoGorelli Aug 9, 2024
ba3d158
tests: clean up tests
aivanoved Aug 16, 2024
7cbc759
bug: keep `pandas` impl of `zip_with` type safe
aivanoved Aug 16, 2024
28f55fb
misc: add additional type safety in `pandas` impl of `when`
aivanoved Aug 16, 2024
47d6a35
Merge branch 'main' into add-where-expression
aivanoved Aug 16, 2024
ee835c2
misc: fix typos
aivanoved Aug 16, 2024
08e689a
misc: remove unneeded assignment
aivanoved Aug 16, 2024
d8bc8b7
bug: integer type casting is harder, revert 28f55fb
aivanoved Aug 16, 2024
affcd7f
tests: add otherwise iterable tests
aivanoved Aug 16, 2024
3b085cf
tests: disable failing modin
aivanoved Aug 16, 2024
1983a94
misc: remove unnecesary `zip_with` fix
aivanoved Aug 16, 2024
a2c6661
feat: `maybe_convert_dtypes` now can take as `Series`
aivanoved Aug 16, 2024
0201428
docs: update stable api
aivanoved Aug 16, 2024
e6d6117
misc: resolve old version issue
aivanoved Aug 16, 2024
dbf5b6e
mics: improve coverage, slow perf on maybe convert
aivanoved Aug 16, 2024
fde2e9a
misc: reformat to not use importlib
aivanoved Aug 16, 2024
3b37c5a
Merge branch 'main' into add-where-expression
aivanoved Aug 20, 2024
6f0db73
Merge branch 'main' into add-where-expression
aivanoved Aug 21, 2024
8efb260
misc: ignore coverage
aivanoved Aug 21, 2024
d738400
Merge remote-tracking branch 'upstream/main' into add-where-expression
MarcoGorelli Aug 23, 2024
0599200
wip
MarcoGorelli Aug 23, 2024
a90fbde
wip
MarcoGorelli Aug 23, 2024
e9b7b28
Merge remote-tracking branch 'upstream/main' into add-where-expression
MarcoGorelli Aug 23, 2024
3bb2629
wip
MarcoGorelli Aug 23, 2024
733effc
Merge remote-tracking branch 'upstream/main' into add-where-expression
MarcoGorelli Aug 23, 2024
eb81758
wip
MarcoGorelli Aug 23, 2024
9654147
fixup tests
MarcoGorelli Aug 23, 2024
2bc5cc0
nw.when
MarcoGorelli Aug 23, 2024
4b8a238
Merge remote-tracking branch 'upstream/main' into add-where-expression
MarcoGorelli Aug 24, 2024
54d34d7
drive-by: remove python_version from pyproject.toml for optional depe…
MarcoGorelli Aug 24, 2024
de4abfa
add test which covers into_expr
MarcoGorelli Aug 24, 2024
1 change: 1 addition & 0 deletions docs/api-reference/narwhals.md
@@ -29,6 +29,7 @@ Here are the top-level functions available in Narwhals.
- new_series
- sum
- sum_horizontal
- when
- show_versions
- to_native
show_source: false
2 changes: 2 additions & 0 deletions narwhals/__init__.py
@@ -35,6 +35,7 @@
from narwhals.expr import min
from narwhals.expr import sum
from narwhals.expr import sum_horizontal
from narwhals.expr import when
from narwhals.functions import concat
from narwhals.functions import from_dict
from narwhals.functions import get_level
@@ -79,6 +80,7 @@
"mean_horizontal",
"sum",
"sum_horizontal",
"when",
"DataFrame",
"LazyFrame",
"Series",
119 changes: 119 additions & 0 deletions narwhals/_pandas_like/namespace.py
@@ -5,6 +5,7 @@
from typing import Any
from typing import Callable
from typing import Iterable
from typing import cast

from narwhals import dtypes
from narwhals._expression_parsing import parse_into_exprs
@@ -249,3 +250,121 @@ def concat(
backend_version=self._backend_version,
)
raise NotImplementedError

    def when(
        self,
        *predicates: IntoPandasLikeExpr,
    ) -> PandasWhen:
        plx = self.__class__(self._implementation, self._backend_version)
        if predicates:
            condition = plx.all_horizontal(*predicates)
        else:
            msg = "at least one predicate needs to be provided"
            raise TypeError(msg)

        return PandasWhen(condition, self._implementation, self._backend_version)
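
The `when` method above reduces all predicates to one condition via `all_horizontal` and raises `TypeError` when called with no predicates. A minimal pure-Python sketch of that behavior follows; plain lists of booleans stand in for columns, and `combine_predicates` is a hypothetical helper for illustration, not part of Narwhals.

```python
from functools import reduce
from typing import List


def combine_predicates(*predicates: List[bool]) -> List[bool]:
    """Row-wise AND of boolean columns (hypothetical stand-in for all_horizontal)."""
    if not predicates:
        # Mirrors the TypeError raised in the diff when no predicate is given.
        msg = "at least one predicate needs to be provided"
        raise TypeError(msg)
    return reduce(
        lambda left, right: [l and r for l, r in zip(left, right)],
        predicates,
    )


a_lt_3 = [True, True, False]
b_gt_5 = [False, True, True]
combined = combine_predicates(a_lt_3, b_gt_5)
print(combined)  # [False, True, False]
```

A row passes only if every predicate is true for it, which is why several predicates behave as if joined with `&`.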


class PandasWhen:
    def __init__(
        self,
        condition: PandasLikeExpr,
        implementation: Implementation,
        backend_version: tuple[int, ...],
        then_value: Any = None,
        otherwise_value: Any = None,
    ) -> None:
        self._implementation = implementation
        self._backend_version = backend_version
        self._condition = condition
        self._then_value = then_value
        self._otherwise_value = otherwise_value

    def __call__(self, df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
        from narwhals._expression_parsing import parse_into_expr
        from narwhals._pandas_like.namespace import PandasLikeNamespace
        from narwhals._pandas_like.utils import validate_column_comparand

        plx = PandasLikeNamespace(
            implementation=self._implementation, backend_version=self._backend_version
        )

        condition = parse_into_expr(self._condition, namespace=plx)._call(df)[0]  # type: ignore[arg-type]
        try:
            value_series = parse_into_expr(self._then_value, namespace=plx)._call(df)[0]  # type: ignore[arg-type]
        except TypeError:
            # `self._then_value` is a scalar and can't be converted to an expression
            value_series = condition.__class__._from_iterable(  # type: ignore[call-arg]
                [self._then_value] * len(condition),
                name="literal",
                index=condition._native_series.index,
                implementation=self._implementation,
                backend_version=self._backend_version,
            )
        value_series = cast(PandasLikeSeries, value_series)

        value_series_native = value_series._native_series
        condition_native = validate_column_comparand(value_series_native.index, condition)

        if self._otherwise_value is None:
            return [
                value_series._from_native_series(
                    value_series_native.where(condition_native)
                )
            ]
        try:
            otherwise_series = parse_into_expr(
                self._otherwise_value, namespace=plx
            )._call(df)[0]  # type: ignore[arg-type]
        except TypeError:
            # `self._otherwise_value` is a scalar and can't be converted to an expression
            return [
                value_series._from_native_series(
                    value_series_native.where(condition_native, self._otherwise_value)
                )
            ]
        else:
            return [value_series.zip_with(condition, otherwise_series)]
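
The `__call__` method above has three outcomes: no `otherwise` (unmatched rows become missing), a scalar `otherwise`, and a column-like `otherwise`. The following is a rough pure-Python model of those branches, not the pandas implementation. Lists stand in for Series, `None` stands in for the missing marker that pandas would produce, and `select` is a hypothetical name.

```python
from typing import Any, List, Sequence


def select(
    condition: Sequence[bool],
    then_value: Any,
    otherwise_value: Any = None,
) -> List[Any]:
    """Pick `then_value` where the condition holds, else `otherwise_value`."""
    n = len(condition)

    def broadcast(value: Any) -> List[Any]:
        # Scalars are repeated to the length of the condition, similar to
        # the scalar fallback used for the `then` value in the diff.
        if isinstance(value, (list, tuple)):
            return list(value)
        return [value] * n

    then_col = broadcast(then_value)
    # With no `otherwise`, unmatched rows stay missing (None in this model).
    otherwise_col = broadcast(otherwise_value)
    return [t if c else o for c, t, o in zip(condition, then_col, otherwise_col)]


cond = [True, True, False]
print(select(cond, 5))                         # [5, 5, None]
print(select(cond, 5, 6))                      # [5, 5, 6]
print(select(cond, [10, 20, 30], [1, 2, 3]))   # [10, 20, 3]
```

In real pandas the missing value can change the result dtype (for example, integers becoming floats), which this list-based model does not capture.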

    def then(self, value: PandasLikeExpr | PandasLikeSeries | Any) -> PandasThen:
        self._then_value = value

        return PandasThen(
            self,
            depth=0,
            function_name="whenthen",
            root_names=None,
            output_names=None,
            implementation=self._implementation,
            backend_version=self._backend_version,
        )


class PandasThen(PandasLikeExpr):
    def __init__(
        self,
        call: PandasWhen,
        *,
        depth: int,
        function_name: str,
        root_names: list[str] | None,
        output_names: list[str] | None,
        implementation: Implementation,
        backend_version: tuple[int, ...],
    ) -> None:
        self._implementation = implementation
        self._backend_version = backend_version

        self._call = call
        self._depth = depth
        self._function_name = function_name
        self._root_names = root_names
        self._output_names = output_names

    def otherwise(self, value: PandasLikeExpr | PandasLikeSeries | Any) -> PandasLikeExpr:
        # type ignore because we are setting the `_call` attribute to a
        # callable object of type `PandasWhen`, base class has the attribute as
        # only a `Callable`
        self._call._otherwise_value = value  # type: ignore[attr-defined]
        self._function_name = "whenotherwise"
        return self
71 changes: 71 additions & 0 deletions narwhals/expr.py
@@ -3954,6 +3954,77 @@ def sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr:
)


class When:
    def __init__(self, *predicates: IntoExpr | Iterable[IntoExpr]) -> None:
        self._predicates = flatten([predicates])

    def _extract_predicates(self, plx: Any) -> Any:
        return [extract_compliant(plx, v) for v in self._predicates]

    def then(self, value: Any) -> Then:
        return Then(
            lambda plx: plx.when(*self._extract_predicates(plx)).then(
                extract_compliant(plx, value)
            )
        )


class Then(Expr):
    def __init__(self, call: Callable[[Any], Any]) -> None:
        self._call = call

    def otherwise(self, value: Any) -> Expr:
        return Expr(lambda plx: self._call(plx).otherwise(extract_compliant(plx, value)))
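
The `When` and `Then` classes above do no work themselves: each stores a lambda that is run only once a backend namespace (`plx`) is supplied. The sketch below illustrates that deferred pattern in pure Python. `FakeNamespace`, `FakeWhen`, and `FakeThen` are hypothetical stand-ins for a backend, and `otherwise` here returns `Then` rather than `Expr` to keep the example self-contained.

```python
from typing import Any, Callable, List


class FakeThen:
    """Backend-side result of `.then(...)` (hypothetical, list-based)."""

    def __init__(self, condition: List[bool], value: Any) -> None:
        self.condition = condition
        self.value = value
        self.otherwise_value: Any = None

    def otherwise(self, value: Any) -> "FakeThen":
        self.otherwise_value = value
        return self

    def evaluate(self) -> List[Any]:
        return [self.value if c else self.otherwise_value for c in self.condition]


class FakeWhen:
    def __init__(self, condition: List[bool]) -> None:
        self.condition = condition

    def then(self, value: Any) -> FakeThen:
        return FakeThen(self.condition, value)


class FakeNamespace:
    """Hypothetical backend namespace: ANDs the predicates row by row."""

    def when(self, *predicates: List[bool]) -> FakeWhen:
        return FakeWhen([all(row) for row in zip(*predicates)])


class Then:
    """Stores a recipe; nothing runs until a namespace is passed to `_call`."""

    def __init__(self, call: Callable[[Any], Any]) -> None:
        self._call = call

    def otherwise(self, value: Any) -> "Then":
        return Then(lambda plx: self._call(plx).otherwise(value))


class When:
    def __init__(self, *predicates: List[bool]) -> None:
        self._predicates = predicates

    def then(self, value: Any) -> Then:
        return Then(lambda plx: plx.when(*self._predicates).then(value))


expr = When([True, False, True]).then("yes").otherwise("no")
result = expr._call(FakeNamespace()).evaluate()
print(result)  # ['yes', 'no', 'yes']
```

Because the backend is only chosen when `_call` runs, the same `expr` object could be evaluated against any namespace that offers a compatible `when`.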


def when(*predicates: IntoExpr | Iterable[IntoExpr]) -> When:
    """
    Start a `when-then-otherwise` expression.

    Expression similar to an `if-else` statement in Python. Always initiated by a `nw.when(<condition>).then(<value if condition>)`, and optionally followed by chaining one or more `.when(<condition>).then(<value>)` statements.
    Chained when-then operations should be read as Python `if, elif, ... elif` blocks, not as `if, if, ... if`, i.e. the first condition that evaluates to `True` will be picked.
    If none of the conditions are `True`, an optional `.otherwise(<value if all statements are false>)` can be appended at the end. If not appended, and none of the conditions are `True`, `None` will be returned.

    Arguments:
        predicates: Condition(s) that must be met in order to apply the subsequent statement. Accepts one or more boolean expressions, which are implicitly combined with `&`. String input is parsed as a column name.

    Examples:
        >>> import pandas as pd
        >>> import polars as pl
        >>> import narwhals as nw
        >>> df_pl = pl.DataFrame({"a": [1, 2, 3], "b": [5, 10, 15]})
        >>> df_pd = pd.DataFrame({"a": [1, 2, 3], "b": [5, 10, 15]})

        We define a dataframe-agnostic function:

        >>> @nw.narwhalify
        ... def func(df_any):
        ...     return df_any.with_columns(
        ...         nw.when(nw.col("a") < 3).then(5).otherwise(6).alias("a_when")
        ...     )

        We can then pass either pandas or polars to `func`:

        >>> func(df_pd)
           a   b  a_when
        0  1   5       5
        1  2  10       5
        2  3  15       6
        >>> func(df_pl)
        shape: (3, 3)
        ┌─────┬─────┬────────┐
        │ a   ┆ b   ┆ a_when │
        │ --- ┆ --- ┆ ---    │
        │ i64 ┆ i64 ┆ i32    │
        ╞═════╪═════╪════════╡
        │ 1   ┆ 5   ┆ 5      │
        │ 2   ┆ 10  ┆ 5      │
        │ 3   ┆ 15  ┆ 6      │
        └─────┴─────┴────────┘
    """
    return When(*predicates)


def all_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr:
r"""
Compute the bitwise AND horizontally across columns.
72 changes: 71 additions & 1 deletion narwhals/stable/v1.py
@@ -34,6 +34,9 @@
from narwhals.dtypes import UInt64
from narwhals.dtypes import Unknown
from narwhals.expr import Expr as NwExpr
from narwhals.expr import Then as NwThen
from narwhals.expr import When as NwWhen
from narwhals.expr import when as nw_when
from narwhals.functions import concat
from narwhals.functions import show_versions
from narwhals.schema import Schema as NwSchema
@@ -1404,7 +1407,7 @@ def maybe_align_index(lhs: T, rhs: Series | DataFrame[Any] | LazyFrame[Any]) ->

def maybe_convert_dtypes(df: T, *args: bool, **kwargs: bool | str) -> T:
"""
- Convert columns to the best possible dtypes using dtypes supporting ``pd.NA``, if df is pandas-like.
+ Convert columns or series to the best possible dtypes using dtypes supporting ``pd.NA``, if df is pandas-like.

Arguments:
obj: DataFrame or Series.
@@ -1493,6 +1496,72 @@ def get_level(
return nw.get_level(obj)


class When(NwWhen):
    @classmethod
    def from_when(cls, when: NwWhen) -> Self:
        return cls(*when._predicates)

    def then(self, value: Any) -> Then:
        return Then.from_then(super().then(value))


class Then(NwThen, Expr):
    @classmethod
    def from_then(cls, then: NwThen) -> Self:
        return cls(then._call)

    def otherwise(self, value: Any) -> Expr:
        return _stableify(super().otherwise(value))
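
The stable `When` and `Then` classes above re-wrap objects from the main namespace so that chained calls keep returning stable-API types. A minimal sketch of that re-wrapping pattern follows; the `Base*` and `Stable*` names are hypothetical and the stored callable is a placeholder, so this illustrates the idea rather than the Narwhals implementation.

```python
from typing import Any, Callable, Tuple


class BaseThen:
    def __init__(self, call: Callable[[], Any]) -> None:
        self._call = call


class BaseWhen:
    def __init__(self, *predicates: Any) -> None:
        self._predicates: Tuple[Any, ...] = predicates

    def then(self, value: Any) -> BaseThen:
        # Placeholder recipe: just records what would be evaluated later.
        return BaseThen(lambda: ("then", self._predicates, value))


class StableThen(BaseThen):
    @classmethod
    def from_then(cls, then: BaseThen) -> "StableThen":
        # Reuse the stored callable, but under the stable class.
        return cls(then._call)


class StableWhen(BaseWhen):
    @classmethod
    def from_when(cls, when: BaseWhen) -> "StableWhen":
        return cls(*when._predicates)

    def then(self, value: Any) -> StableThen:
        # Delegate to the base logic, then convert the result's type.
        return StableThen.from_then(super().then(value))


stable = StableWhen.from_when(BaseWhen("p1", "p2"))
result = stable.then(5)
print(type(result).__name__)  # StableThen
```

Overriding only the methods that return new objects keeps the logic in one place while ensuring callers of the stable API never receive an unstable type.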


def when(*predicates: IntoExpr | Iterable[IntoExpr]) -> When:
    """
    Start a `when-then-otherwise` expression.

    Expression similar to an `if-else` statement in Python. Always initiated by a `nw.when(<condition>).then(<value if condition>)`, and optionally followed by chaining one or more `.when(<condition>).then(<value>)` statements.
    Chained when-then operations should be read as Python `if, elif, ... elif` blocks, not as `if, if, ... if`, i.e. the first condition that evaluates to `True` will be picked.
    If none of the conditions are `True`, an optional `.otherwise(<value if all statements are false>)` can be appended at the end. If not appended, and none of the conditions are `True`, `None` will be returned.

    Arguments:
        predicates: Condition(s) that must be met in order to apply the subsequent statement. Accepts one or more boolean expressions, which are implicitly combined with `&`. String input is parsed as a column name.

    Examples:
        >>> import pandas as pd
        >>> import polars as pl
        >>> import narwhals.stable.v1 as nw
        >>> df_pl = pl.DataFrame({"a": [1, 2, 3], "b": [5, 10, 15]})
        >>> df_pd = pd.DataFrame({"a": [1, 2, 3], "b": [5, 10, 15]})

        We define a dataframe-agnostic function:

        >>> @nw.narwhalify
        ... def func(df_any):
        ...     return df_any.with_columns(
        ...         nw.when(nw.col("a") < 3).then(5).otherwise(6).alias("a_when")
        ...     )

        We can then pass either pandas or polars to `func`:

        >>> func(df_pd)
           a   b  a_when
        0  1   5       5
        1  2  10       5
        2  3  15       6
        >>> func(df_pl)
        shape: (3, 3)
        ┌─────┬─────┬────────┐
        │ a   ┆ b   ┆ a_when │
        │ --- ┆ --- ┆ ---    │
        │ i64 ┆ i64 ┆ i32    │
        ╞═════╪═════╪════════╡
        │ 1   ┆ 5   ┆ 5      │
        │ 2   ┆ 10  ┆ 5      │
        │ 3   ┆ 15  ┆ 6      │
        └─────┴─────┴────────┘
    """
    return When.from_when(nw_when(*predicates))


def new_series(
name: str,
values: Any,
@@ -1624,6 +1693,7 @@ def from_dict(
"mean_horizontal",
"sum",
"sum_horizontal",
"when",
"DataFrame",
"LazyFrame",
"Series",
2 changes: 1 addition & 1 deletion narwhals/utils.py
@@ -276,7 +276,7 @@ def maybe_set_index(df: T, column_names: str | list[str]) -> T:

def maybe_convert_dtypes(obj: T, *args: bool, **kwargs: bool | str) -> T:
"""
- Convert columns to the best possible dtypes using dtypes supporting ``pd.NA``, if df is pandas-like.
+ Convert columns or series to the best possible dtypes using dtypes supporting ``pd.NA``, if df is pandas-like.

Arguments:
obj: DataFrame or Series.
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -26,12 +26,12 @@ exclude = [
]

[project.optional-dependencies]
- cudf = ["cudf>=23.08.00; python_version >= '3.9'"]
+ cudf = ["cudf>=23.08.00"]
modin = ["modin"]
pandas = ["pandas>=0.25.3"]
polars = ["polars>=0.20.3"]
pyarrow = ["pyarrow>=11.0.0"]
- dask = ["dask[dataframe]>=2024.7; python_version >= '3.9'"]
+ dask = ["dask[dataframe]>=2024.7"]

[project.urls]
"Homepage" = "https://github.com/narwhals-dev/narwhals"