Skip to content

Commit

Permalink
feat: add to_py_scalar (#1194)
Browse files Browse the repository at this point in the history
* add to_py_scalar

* fix tests

* more fixes pragma and doctsting

* fix test_to_py_scalar_cudf_series

* convert numpy scalars

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove parse version

* simplify test_to_py_scalar_arrays_series

* add conversion for datetime and timedelta

* stricter to_py_scalar

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
EdAbati and pre-commit-ci[bot] authored Oct 19, 2024
1 parent 59aa483 commit 0c1650c
Show file tree
Hide file tree
Showing 6 changed files with 162 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/api-reference/narwhals.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,5 @@ Here are the top-level functions available in Narwhals.
- when
- show_versions
- to_native
- to_py_scalar
show_source: false
2 changes: 2 additions & 0 deletions narwhals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
from narwhals.translate import get_native_namespace
from narwhals.translate import narwhalify
from narwhals.translate import to_native
from narwhals.translate import to_py_scalar
from narwhals.utils import is_ordered_categorical
from narwhals.utils import maybe_align_index
from narwhals.utils import maybe_convert_dtypes
Expand All @@ -84,6 +85,7 @@
"maybe_reset_index",
"maybe_set_index",
"get_native_namespace",
"to_py_scalar",
"all",
"all_horizontal",
"any_horizontal",
Expand Down
5 changes: 5 additions & 0 deletions narwhals/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ def get_cudf() -> Any:
return sys.modules.get("cudf", None)


def get_cupy() -> Any:
"""Get cupy module (if already imported - else return None)."""
return sys.modules.get("cupy", None)


def get_pyarrow() -> Any: # pragma: no cover
"""Get pyarrow module (if already imported - else return None)."""
return sys.modules.get("pyarrow", None)
Expand Down
24 changes: 24 additions & 0 deletions narwhals/stable/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from narwhals.translate import _from_native_impl
from narwhals.translate import get_native_namespace as nw_get_native_namespace
from narwhals.translate import to_native
from narwhals.translate import to_py_scalar as nw_to_py_scalar
from narwhals.typing import IntoDataFrameT
from narwhals.typing import IntoFrameT
from narwhals.typing import IntoSeriesT
Expand Down Expand Up @@ -952,6 +953,28 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
return decorator(func)


def to_py_scalar(scalar: Any) -> Any:
"""If a scalar is not Python native, converts it to Python native.
Raises:
ValueError: If the object is not convertible to a scalar.
Examples:
>>> import narwhals.stable.v1 as nw
>>> import pandas as pd
>>> df = nw.from_native(pd.DataFrame({"a": [1, 2, 3]}))
>>> nw.to_py_scalar(df["a"].item(0))
1
>>> import pyarrow as pa
>>> df = nw.from_native(pa.table({"a": [1, 2, 3]}))
>>> nw.to_py_scalar(df["a"].item(0))
1
>>> nw.to_py_scalar(1)
1
"""
return _stableify(nw_to_py_scalar(scalar))


def all() -> Expr:
"""
Instantiate an expression representing all columns.
Expand Down Expand Up @@ -2306,6 +2329,7 @@ def from_dict(
"dependencies",
"to_native",
"from_native",
"to_py_scalar",
"is_ordered_categorical",
"maybe_align_index",
"maybe_convert_dtypes",
Expand Down
67 changes: 67 additions & 0 deletions narwhals/translate.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from __future__ import annotations

import numbers
from datetime import datetime
from datetime import timedelta
from functools import wraps
from typing import TYPE_CHECKING
from typing import Any
Expand All @@ -9,9 +12,11 @@
from typing import overload

from narwhals.dependencies import get_cudf
from narwhals.dependencies import get_cupy
from narwhals.dependencies import get_dask
from narwhals.dependencies import get_dask_expr
from narwhals.dependencies import get_modin
from narwhals.dependencies import get_numpy
from narwhals.dependencies import get_pandas
from narwhals.dependencies import get_polars
from narwhals.dependencies import get_pyarrow
Expand Down Expand Up @@ -776,8 +781,70 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
return decorator(func)


def to_py_scalar(scalar_like: Any) -> Any:
"""If a scalar is not Python native, converts it to Python native.
Raises:
ValueError: If the object is not convertible to a scalar.
Examples:
>>> import narwhals as nw
>>> import pandas as pd
>>> df = nw.from_native(pd.DataFrame({"a": [1, 2, 3]}))
>>> nw.to_py_scalar(df["a"].item(0))
1
>>> import pyarrow as pa
>>> df = nw.from_native(pa.table({"a": [1, 2, 3]}))
>>> nw.to_py_scalar(df["a"].item(0))
1
>>> nw.to_py_scalar(1)
1
"""

pa = get_pyarrow()
if pa and isinstance(scalar_like, pa.Scalar):
return scalar_like.as_py()

cupy = get_cupy()
if ( # pragma: no cover
cupy and isinstance(scalar_like, cupy.ndarray) and scalar_like.size == 1
):
return scalar_like.item()

np = get_numpy()
if np and np.isscalar(scalar_like) and hasattr(scalar_like, "item"):
return scalar_like.item()

pd = get_pandas()
if pd and isinstance(scalar_like, pd.Timestamp):
return scalar_like.to_pydatetime()
if pd and isinstance(scalar_like, pd.Timedelta):
return scalar_like.to_pytimedelta()

all_scalar_types = (
int,
float,
complex,
bool,
bytes,
str,
datetime,
timedelta,
numbers.Number,
)
if isinstance(scalar_like, all_scalar_types):
return scalar_like

msg = (
f"Expected object convertible to a scalar, found {type(scalar_like)}. "
"Please report a bug to https://github.com/narwhals-dev/narwhals/issues"
)
raise ValueError(msg)


__all__ = [
"get_native_namespace",
"to_native",
"narwhalify",
"to_py_scalar",
]
63 changes: 63 additions & 0 deletions tests/translate/to_py_scalar_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from __future__ import annotations

from datetime import datetime
from datetime import timedelta
from typing import TYPE_CHECKING
from typing import Any

import numpy as np
import pandas as pd
import pytest

import narwhals.stable.v1 as nw
from narwhals.dependencies import get_cudf

if TYPE_CHECKING:
from tests.utils import ConstructorEager


@pytest.mark.parametrize(
("input_value", "expected"),
[
(1, 1),
(1.0, 1.0),
("a", "a"),
(True, True),
(b"a", b"a"),
(datetime(2021, 1, 1), datetime(2021, 1, 1)),
(timedelta(days=1), timedelta(days=1)),
],
)
def test_to_py_scalar(
constructor_eager: ConstructorEager, input_value: Any, expected: Any
) -> None:
df = nw.from_native(constructor_eager({"a": [input_value]}))
output = nw.to_py_scalar(df["a"].item(0))
if expected == 1 and constructor_eager.__name__.startswith("pandas"):
assert not isinstance(output, np.int64)
elif isinstance(expected, datetime) and constructor_eager.__name__.startswith(
"pandas"
):
assert not isinstance(output, pd.Timestamp)
elif isinstance(expected, timedelta) and constructor_eager.__name__.startswith(
"pandas"
):
assert not isinstance(output, pd.Timedelta)
assert output == expected


@pytest.mark.parametrize(
"input_value",
[np.array([1, 2]), [1, 2, 3], {"a": [1, 2, 3]}],
)
def test_to_py_scalar_value_error(input_value: Any) -> None:
with pytest.raises(ValueError, match="Expected object convertible to a scalar"):
nw.to_py_scalar(input_value)


def test_to_py_scalar_value_error_cudf() -> None:
if cudf := get_cudf(): # pragma: no cover
df = nw.from_native(cudf.DataFrame({"a": [1, 2, 3]}))

with pytest.raises(ValueError, match="Expected object convertible to a scalar"):
nw.to_py_scalar(df["a"])

0 comments on commit 0c1650c

Please sign in to comment.