From 6040d32b7d254a8ed3f2efe2ea433be14db772fb Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Sat, 2 Nov 2024 14:40:39 +0100 Subject: [PATCH] check for polars specifically --- narwhals/_arrow/utils.py | 5 ++++- narwhals/_dask/utils.py | 5 ++++- narwhals/_pandas_like/utils.py | 5 ++++- narwhals/_polars/utils.py | 7 +++++-- narwhals/expr.py | 13 ------------- narwhals/series.py | 10 ---------- tests/expr_and_series/cast_test.py | 7 +++---- 7 files changed, 20 insertions(+), 32 deletions(-) diff --git a/narwhals/_arrow/utils.py b/narwhals/_arrow/utils.py index 1e0a4c0f9..56c8b1e50 100644 --- a/narwhals/_arrow/utils.py +++ b/narwhals/_arrow/utils.py @@ -4,6 +4,7 @@ from typing import Any from typing import Sequence +from narwhals.dependencies import get_polars from narwhals.utils import isinstance_or_issubclass if TYPE_CHECKING: @@ -76,7 +77,9 @@ def native_to_narwhals_dtype(dtype: Any, dtypes: DTypes) -> DType: def narwhals_to_native_dtype(dtype: DType | type[DType], dtypes: DTypes) -> Any: - if "polars" in str(type(dtype)): + if (pl := get_polars()) is not None and isinstance( + dtype, (pl.DataType, pl.DataType.__class__) + ): msg = ( f"Expected Narwhals object, got: {type(dtype)}.\n\n" "Perhaps you:\n" diff --git a/narwhals/_dask/utils.py b/narwhals/_dask/utils.py index 7e91d9b1e..cf8f9a3fc 100644 --- a/narwhals/_dask/utils.py +++ b/narwhals/_dask/utils.py @@ -4,6 +4,7 @@ from typing import Any from narwhals.dependencies import get_pandas +from narwhals.dependencies import get_polars from narwhals.dependencies import get_pyarrow from narwhals.utils import isinstance_or_issubclass from narwhals.utils import parse_version @@ -85,7 +86,9 @@ def validate_comparand(lhs: dask_expr.Series, rhs: dask_expr.Series) -> None: def narwhals_to_native_dtype(dtype: DType | type[DType], dtypes: DTypes) -> Any: - if "polars" in str(type(dtype)): + if (pl := get_polars()) is not None and isinstance( + dtype, (pl.DataType, pl.DataType.__class__) + ): msg = ( f"Expected Narwhals object, got: {type(dtype)}.\n\n" "Perhaps you:\n" diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index 99181bc1e..58123c565 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -10,6 +10,7 @@ from narwhals._arrow.utils import ( native_to_narwhals_dtype as arrow_native_to_narwhals_dtype, ) +from narwhals.dependencies import get_polars from narwhals.utils import Implementation from narwhals.utils import isinstance_or_issubclass @@ -339,7 +340,9 @@ def narwhals_to_native_dtype( # noqa: PLR0915 backend_version: tuple[int, ...], dtypes: DTypes, ) -> Any: - if "polars" in str(type(dtype)): + if (pl := get_polars()) is not None and isinstance( + dtype, (pl.DataType, pl.DataType.__class__) + ): msg = ( f"Expected Narwhals object, got: {type(dtype)}.\n\n" "Perhaps you:\n" diff --git a/narwhals/_polars/utils.py b/narwhals/_polars/utils.py index d5dbf2c2b..295c03bc3 100644 --- a/narwhals/_polars/utils.py +++ b/narwhals/_polars/utils.py @@ -8,6 +8,7 @@ from narwhals.dtypes import DType from narwhals.typing import DTypes +from narwhals.dependencies import get_polars from narwhals.utils import parse_version @@ -94,7 +95,9 @@ def native_to_narwhals_dtype(dtype: Any, dtypes: DTypes) -> DType: def narwhals_to_native_dtype(dtype: DType | type[DType], dtypes: DTypes) -> Any: - if "polars" in str(type(dtype)): + if (pl := get_polars()) is not None and isinstance( + dtype, (pl.DataType, pl.DataType.__class__) + ): msg = ( f"Expected Narwhals object, got: {type(dtype)}.\n\n" "Perhaps you:\n" @@ -141,7 +144,7 @@ def narwhals_to_native_dtype(dtype: DType | type[DType], dtypes: DTypes) -> Any: if dtype == dtypes.Datetime or isinstance(dtype, dtypes.Datetime): dt_time_unit = getattr(dtype, "time_unit", "us") dt_time_zone = getattr(dtype, "time_zone", None) - return pl.Datetime(dt_time_unit, dt_time_zone) # type: ignore[arg-type] + return pl.Datetime(dt_time_unit, dt_time_zone) if dtype == dtypes.Duration or isinstance(dtype, dtypes.Duration): du_time_unit: Literal["us", "ns", "ms"] = getattr(dtype, "time_unit", "us") return pl.Duration(time_unit=du_time_unit) diff --git a/narwhals/expr.py b/narwhals/expr.py index 034e908f7..46d44bee3 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -186,19 +186,6 @@ def cast(self: Self, dtype: DType | type[DType]) -> Self: foo: [[1,2,3]] bar: [[6,7,8]] """ - # from narwhals.dtypes import DType - - # if not ( - # isinstance(dtype, DType) - # or (isinstance(dtype, type) and issubclass(dtype, DType)) - # ): - # msg = ( - # f"Expected Narwhals DType, got: {type(dtype)}.\n\n" - # "Hint: Perhaps you used Polars DataType instance `pl.dtype` instead of " - # "Narwhals DType `nw.dtype`?" - # ) - # raise TypeError(msg) - return self.__class__( lambda plx: self._call(plx).cast(dtype), ) diff --git a/narwhals/series.py b/narwhals/series.py index b33a08dd5..add55897e 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -420,16 +420,6 @@ def cast(self: Self, dtype: DType | type[DType]) -> Self: 1 ] """ - # from narwhals.dtypes import DType - - # if not (isinstance(dtype, DType) or dtype == DType()): - # msg = ( - # f"Expected Narwhals DType, got: {type(dtype)}.\n\n" - # "Hint: Perhaps you used Polars DataType instance `pl.dtype` instead of " - # "Narwhals DType `nw.dtype`?" - # ) - # raise TypeError(msg) - return self._from_compliant_series(self._compliant_series.cast(dtype)) def to_frame(self) -> DataFrame[Any]: diff --git a/tests/expr_and_series/cast_test.py b/tests/expr_and_series/cast_test.py index 20ad5dbcc..14e77d68d 100644 --- a/tests/expr_and_series/cast_test.py +++ b/tests/expr_and_series/cast_test.py @@ -7,7 +7,6 @@ import pandas as pd import polars as pl -import pyarrow as pa import pytest import narwhals.stable.v1 as nw @@ -194,7 +193,7 @@ class Banana: pass with pytest.raises(AssertionError, match=r"Unknown dtype"): - df.select(nw.col("a").cast(Banana)) + df.select(nw.col("a").cast(Banana)) # type: ignore[arg-type] def test_cast_datetime_tz_aware( @@ -227,8 +226,8 @@ def test_cast_datetime_tz_aware( assert_equal_data(result, expected) -@pytest.mark.parametrize("dtype", [pl.String, pl.String(), pa.float64(), str]) -def test_raise_if_not_narwhals_dtype(constructor: Constructor, dtype: Any) -> None: +@pytest.mark.parametrize("dtype", [pl.String, pl.String()]) +def test_raise_if_polars_dtype(constructor: Constructor, dtype: Any) -> None: df = nw.from_native(constructor({"a": [1, 2, 3], "b": [4, 5, 6]})) with pytest.raises(TypeError, match="Expected Narwhals object, got:"): df.select(nw.col("a").cast(dtype))