diff --git a/docs/api-reference/narwhals.md b/docs/api-reference/narwhals.md index b8ec2d793..c4b04a2f4 100644 --- a/docs/api-reference/narwhals.md +++ b/docs/api-reference/narwhals.md @@ -39,4 +39,5 @@ Here are the top-level functions available in Narwhals. - when - show_versions - to_native + - to_py_scalar show_source: false diff --git a/narwhals/__init__.py b/narwhals/__init__.py index aeba3ef5e..8dd76d081 100644 --- a/narwhals/__init__.py +++ b/narwhals/__init__.py @@ -58,6 +58,7 @@ from narwhals.translate import get_native_namespace from narwhals.translate import narwhalify from narwhals.translate import to_native +from narwhals.translate import to_py_scalar from narwhals.utils import is_ordered_categorical from narwhals.utils import maybe_align_index from narwhals.utils import maybe_convert_dtypes @@ -84,6 +85,7 @@ "maybe_reset_index", "maybe_set_index", "get_native_namespace", + "to_py_scalar", "all", "all_horizontal", "any_horizontal", diff --git a/narwhals/dependencies.py b/narwhals/dependencies.py index 144c57c8a..1f9ae19f5 100644 --- a/narwhals/dependencies.py +++ b/narwhals/dependencies.py @@ -46,6 +46,11 @@ def get_cudf() -> Any: return sys.modules.get("cudf", None) +def get_cupy() -> Any: + """Get cupy module (if already imported - else return None).""" + return sys.modules.get("cupy", None) + + def get_pyarrow() -> Any: # pragma: no cover """Get pyarrow module (if already imported - else return None).""" return sys.modules.get("pyarrow", None) diff --git a/narwhals/stable/v1/__init__.py b/narwhals/stable/v1/__init__.py index 7bcd6146e..c09b0f2b3 100644 --- a/narwhals/stable/v1/__init__.py +++ b/narwhals/stable/v1/__init__.py @@ -51,6 +51,7 @@ from narwhals.translate import _from_native_impl from narwhals.translate import get_native_namespace as nw_get_native_namespace from narwhals.translate import to_native +from narwhals.translate import to_py_scalar as nw_to_py_scalar from narwhals.typing import IntoDataFrameT from narwhals.typing import IntoFrameT from narwhals.typing import IntoSeriesT @@ -952,6 +953,28 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: return decorator(func) +def to_py_scalar(scalar: Any) -> Any: + """If a scalar is not Python native, converts it to Python native. + + Raises: + ValueError: If the object is not convertible to a scalar. + + Examples: + >>> import narwhals.stable.v1 as nw + >>> import pandas as pd + >>> df = nw.from_native(pd.DataFrame({"a": [1, 2, 3]})) + >>> nw.to_py_scalar(df["a"].item(0)) + 1 + >>> import pyarrow as pa + >>> df = nw.from_native(pa.table({"a": [1, 2, 3]})) + >>> nw.to_py_scalar(df["a"].item(0)) + 1 + >>> nw.to_py_scalar(1) + 1 + """ + return _stableify(nw_to_py_scalar(scalar)) + + def all() -> Expr: """ Instantiate an expression representing all columns. @@ -2306,6 +2329,7 @@ def from_dict( "dependencies", "to_native", "from_native", + "to_py_scalar", "is_ordered_categorical", "maybe_align_index", "maybe_convert_dtypes", diff --git a/narwhals/translate.py b/narwhals/translate.py index 0dc0cd467..331b87d88 100644 --- a/narwhals/translate.py +++ b/narwhals/translate.py @@ -1,5 +1,8 @@ from __future__ import annotations +import numbers +from datetime import datetime +from datetime import timedelta from functools import wraps from typing import TYPE_CHECKING from typing import Any @@ -9,9 +12,11 @@ from typing import overload from narwhals.dependencies import get_cudf +from narwhals.dependencies import get_cupy from narwhals.dependencies import get_dask from narwhals.dependencies import get_dask_expr from narwhals.dependencies import get_modin +from narwhals.dependencies import get_numpy from narwhals.dependencies import get_pandas from narwhals.dependencies import get_polars from narwhals.dependencies import get_pyarrow @@ -776,8 +781,70 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: return decorator(func) +def to_py_scalar(scalar_like: Any) -> Any: + """If a scalar is not Python native, converts it to Python native. + + Raises: + ValueError: If the object is not convertible to a scalar. + + Examples: + >>> import narwhals as nw + >>> import pandas as pd + >>> df = nw.from_native(pd.DataFrame({"a": [1, 2, 3]})) + >>> nw.to_py_scalar(df["a"].item(0)) + 1 + >>> import pyarrow as pa + >>> df = nw.from_native(pa.table({"a": [1, 2, 3]})) + >>> nw.to_py_scalar(df["a"].item(0)) + 1 + >>> nw.to_py_scalar(1) + 1 + """ + + pa = get_pyarrow() + if pa and isinstance(scalar_like, pa.Scalar): + return scalar_like.as_py() + + cupy = get_cupy() + if ( # pragma: no cover + cupy and isinstance(scalar_like, cupy.ndarray) and scalar_like.size == 1 + ): + return scalar_like.item() + + np = get_numpy() + if np and np.isscalar(scalar_like) and hasattr(scalar_like, "item"): + return scalar_like.item() + + pd = get_pandas() + if pd and isinstance(scalar_like, pd.Timestamp): + return scalar_like.to_pydatetime() + if pd and isinstance(scalar_like, pd.Timedelta): + return scalar_like.to_pytimedelta() + + all_scalar_types = ( + int, + float, + complex, + bool, + bytes, + str, + datetime, + timedelta, + numbers.Number, + ) + if isinstance(scalar_like, all_scalar_types): + return scalar_like + + msg = ( + f"Expected object convertible to a scalar, found {type(scalar_like)}. " + "Please report a bug to https://github.com/narwhals-dev/narwhals/issues" + ) + raise ValueError(msg) + + __all__ = [ "get_native_namespace", "to_native", "narwhalify", + "to_py_scalar", ] diff --git a/tests/translate/to_py_scalar_test.py b/tests/translate/to_py_scalar_test.py new file mode 100644 index 000000000..c9aa2749d --- /dev/null +++ b/tests/translate/to_py_scalar_test.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +from datetime import datetime +from datetime import timedelta +from typing import TYPE_CHECKING +from typing import Any + +import numpy as np +import pandas as pd +import pytest + +import narwhals.stable.v1 as nw +from narwhals.dependencies import get_cudf + +if TYPE_CHECKING: + from tests.utils import ConstructorEager + + +@pytest.mark.parametrize( + ("input_value", "expected"), + [ + (1, 1), + (1.0, 1.0), + ("a", "a"), + (True, True), + (b"a", b"a"), + (datetime(2021, 1, 1), datetime(2021, 1, 1)), + (timedelta(days=1), timedelta(days=1)), + ], +) +def test_to_py_scalar( + constructor_eager: ConstructorEager, input_value: Any, expected: Any +) -> None: + df = nw.from_native(constructor_eager({"a": [input_value]})) + output = nw.to_py_scalar(df["a"].item(0)) + if expected == 1 and constructor_eager.__name__.startswith("pandas"): + assert not isinstance(output, np.int64) + elif isinstance(expected, datetime) and constructor_eager.__name__.startswith( + "pandas" + ): + assert not isinstance(output, pd.Timestamp) + elif isinstance(expected, timedelta) and constructor_eager.__name__.startswith( + "pandas" + ): + assert not isinstance(output, pd.Timedelta) + assert output == expected + + +@pytest.mark.parametrize( + "input_value", + [np.array([1, 2]), [1, 2, 3], {"a": [1, 2, 3]}], +) +def test_to_py_scalar_value_error(input_value: Any) -> None: + with pytest.raises(ValueError, match="Expected object convertible to a scalar"): + nw.to_py_scalar(input_value) + + +def test_to_py_scalar_value_error_cudf() -> None: + if cudf := get_cudf(): # pragma: no cover + df = nw.from_native(cudf.DataFrame({"a": [1, 2, 3]})) + + with pytest.raises(ValueError, match="Expected object convertible to a scalar"): + nw.to_py_scalar(df["a"])