From 70bc5c687924fa0ddfbb38b58644d92580ba6036 Mon Sep 17 00:00:00 2001
From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com>
Date: Fri, 17 Jan 2025 14:47:02 +0100
Subject: [PATCH] feat: add `SparkExpr.cast` for basic types (#1812)

* add narwhals_to_native_dtype

* add cast

* readability

* pragma no cover

* fix import in utils

* set timezone for spark session to make tests reproducible

* Update narwhals/_spark_like/utils.py

Co-authored-by: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com>

* bye bye datetime

* add UnsupportedDTypeError

---------

Co-authored-by: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com>
Co-authored-by: FBruzzesi
---
 narwhals/_spark_like/expr.py       | 11 ++++++
 narwhals/_spark_like/namespace.py  | 22 +++++------
 narwhals/_spark_like/utils.py      | 53 ++++++++++++++++++++++++-
 narwhals/exceptions.py             |  4 ++
 tests/conftest.py                  |  2 +
 tests/expr_and_series/cast_test.py | 62 ++++++++++++++++--------------
 tests/group_by_test.py             |  6 ---
 7 files changed, 114 insertions(+), 46 deletions(-)

diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py
index 4958b2ba2..d2263c7c9 100644
--- a/narwhals/_spark_like/expr.py
+++ b/narwhals/_spark_like/expr.py
@@ -11,6 +11,7 @@
 from narwhals._spark_like.expr_str import SparkLikeExprStringNamespace
 from narwhals._spark_like.utils import get_column_name
 from narwhals._spark_like.utils import maybe_evaluate
+from narwhals._spark_like.utils import narwhals_to_native_dtype
 from narwhals.typing import CompliantExpr
 from narwhals.utils import Implementation
 from narwhals.utils import parse_version
@@ -21,6 +22,7 @@

     from narwhals._spark_like.dataframe import SparkLikeLazyFrame
     from narwhals._spark_like.namespace import SparkLikeNamespace
+    from narwhals.dtypes import DType
     from narwhals.utils import Version


@@ -285,6 +287,15 @@

         return self._from_call(F.bool_or, "any", returns_scalar=True)

+    def cast(self: Self, dtype: DType | type[DType]) -> Self:
+        def _cast(_input: Column, dtype: DType | type[DType]) -> Column:
+            spark_dtype = narwhals_to_native_dtype(dtype, self._version)
+            return _input.cast(spark_dtype)
+
+        return self._from_call(
+            _cast, "cast", dtype=dtype, returns_scalar=self._returns_scalar
+        )
+
     def count(self) -> Self:
         from pyspark.sql import functions as F  # noqa: N812

diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py
index a8c778b86..bf876a88a 100644
--- a/narwhals/_spark_like/namespace.py
+++ b/narwhals/_spark_like/namespace.py
@@ -36,7 +36,7 @@ def _all(df: SparkLikeLazyFrame) -> list[Column]:

             return [F.col(col_name) for col_name in df.columns]

-        return SparkLikeExpr(  # type: ignore[abstract]
+        return SparkLikeExpr(
             call=_all,
             depth=0,
             function_name="all",
@@ -63,7 +63,7 @@ def _lit(_: SparkLikeLazyFrame) -> list[Column]:

             return [F.lit(value).alias("literal")]

-        return SparkLikeExpr(  # type: ignore[abstract]
+        return SparkLikeExpr(
             call=_lit,
             depth=0,
             function_name="lit",
@@ -81,7 +81,7 @@ def func(_: SparkLikeLazyFrame) -> list[Column]:

             return [F.count("*").alias("len")]

-        return SparkLikeExpr(  # type: ignore[abstract]
+        return SparkLikeExpr(
             func,
             depth=0,
             function_name="len",
@@ -101,7 +101,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]:
             col_name = get_column_name(df, cols[0])
             return [reduce(operator.and_, cols).alias(col_name)]

-        return SparkLikeExpr(  # type: ignore[abstract]
+        return SparkLikeExpr(
             call=func,
             depth=max(x._depth for x in parsed_exprs) + 1,
             function_name="all_horizontal",
@@ -121,7 +121,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]:
             col_name = get_column_name(df, cols[0])
             return [reduce(operator.or_, cols).alias(col_name)]

-        return SparkLikeExpr(  # type: ignore[abstract]
+        return SparkLikeExpr(
             call=func,
             depth=max(x._depth for x in parsed_exprs) + 1,
             function_name="any_horizontal",
@@ -148,7 +148,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]:
                 ).alias(col_name)
             ]

-        return SparkLikeExpr(  # type: ignore[abstract]
+        return SparkLikeExpr(
             call=func,
             depth=max(x._depth for x in parsed_exprs) + 1,
             function_name="sum_horizontal",
@@ -179,7 +179,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]:
                 ).alias(col_name)
             ]

-        return SparkLikeExpr(  # type: ignore[abstract]
+        return SparkLikeExpr(
             call=func,
             depth=max(x._depth for x in parsed_exprs) + 1,
             function_name="mean_horizontal",
@@ -201,7 +201,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]:
             col_name = get_column_name(df, cols[0])
             return [F.greatest(*cols).alias(col_name)]

-        return SparkLikeExpr(  # type: ignore[abstract]
+        return SparkLikeExpr(
             call=func,
             depth=max(x._depth for x in parsed_exprs) + 1,
             function_name="max_horizontal",
@@ -223,7 +223,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]:
             col_name = get_column_name(df, cols[0])
             return [F.least(*cols).alias(col_name)]

-        return SparkLikeExpr(  # type: ignore[abstract]
+        return SparkLikeExpr(
             call=func,
             depth=max(x._depth for x in parsed_exprs) + 1,
             function_name="min_horizontal",
@@ -320,7 +320,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]:

             return [result]

-        return SparkLikeExpr(  # type: ignore[abstract]
+        return SparkLikeExpr(
             call=func,
             depth=max(x._depth for x in parsed_exprs) + 1,
             function_name="concat_str",
@@ -392,7 +392,7 @@ def __call__(self, df: SparkLikeLazyFrame) -> list[Column]:
     def then(self, value: SparkLikeExpr | Any) -> SparkLikeThen:
         self._then_value = value

-        return SparkLikeThen(  # type: ignore[abstract]
+        return SparkLikeThen(
             self,
             depth=0,
             function_name="whenthen",
diff --git a/narwhals/_spark_like/utils.py b/narwhals/_spark_like/utils.py
index fb3a3f3c4..73392e952 100644
--- a/narwhals/_spark_like/utils.py
+++ b/narwhals/_spark_like/utils.py
@@ -5,7 +5,9 @@
 from typing import Any

 from narwhals.exceptions import InvalidIntoExprError
+from narwhals.exceptions import UnsupportedDTypeError
 from narwhals.utils import import_dtypes_module
+from narwhals.utils import isinstance_or_issubclass

 if TYPE_CHECKING:
     from pyspark.sql import Column
@@ -22,9 +24,10 @@ def native_to_narwhals_dtype(
     dtype: pyspark_types.DataType,
     version: Version,
 ) -> DType:  # pragma: no cover
-    dtypes = import_dtypes_module(version=version)
     from pyspark.sql import types as pyspark_types

+    dtypes = import_dtypes_module(version=version)
+
     if isinstance(dtype, pyspark_types.DoubleType):
         return dtypes.Float64()
     if isinstance(dtype, pyspark_types.FloatType):
@@ -35,6 +38,8 @@
         return dtypes.Int32()
     if isinstance(dtype, pyspark_types.ShortType):
         return dtypes.Int16()
+    if isinstance(dtype, pyspark_types.ByteType):
+        return dtypes.Int8()
     string_types = [
         pyspark_types.StringType,
         pyspark_types.VarcharType,
@@ -58,6 +63,52 @@
     return dtypes.Unknown()


+def narwhals_to_native_dtype(
+    dtype: DType | type[DType], version: Version
+) -> pyspark_types.DataType:
+    from pyspark.sql import types as pyspark_types
+
+    dtypes = import_dtypes_module(version)
+
+    if isinstance_or_issubclass(dtype, dtypes.Float64):
+        return pyspark_types.DoubleType()
+    if isinstance_or_issubclass(dtype, dtypes.Float32):
+        return pyspark_types.FloatType()
+    if isinstance_or_issubclass(dtype, dtypes.Int64):
+        return pyspark_types.LongType()
+    if isinstance_or_issubclass(dtype, dtypes.Int32):
+        return pyspark_types.IntegerType()
+    if isinstance_or_issubclass(dtype, dtypes.Int16):
+        return pyspark_types.ShortType()
+    if isinstance_or_issubclass(dtype, dtypes.Int8):
+        return pyspark_types.ByteType()
+    if isinstance_or_issubclass(dtype, dtypes.String):
+        return pyspark_types.StringType()
+    if isinstance_or_issubclass(dtype, dtypes.Boolean):
+        return pyspark_types.BooleanType()
+    if any(isinstance_or_issubclass(dtype, t) for t in [dtypes.Date, dtypes.Datetime]):
+        msg = "Converting to Date or Datetime dtype is not supported yet"
+        raise NotImplementedError(msg)
+    if isinstance_or_issubclass(dtype, dtypes.List):  # pragma: no cover
+        msg = "Converting to List dtype is not supported yet"
+        raise NotImplementedError(msg)
+    if isinstance_or_issubclass(dtype, dtypes.Struct):  # pragma: no cover
+        msg = "Converting to Struct dtype is not supported yet"
+        raise NotImplementedError(msg)
+    if isinstance_or_issubclass(dtype, dtypes.Array):  # pragma: no cover
+        msg = "Converting to Array dtype is not supported yet"
+        raise NotImplementedError(msg)
+    if any(
+        isinstance_or_issubclass(dtype, t)
+        for t in [dtypes.UInt64, dtypes.UInt32, dtypes.UInt16, dtypes.UInt8]
+    ):  # pragma: no cover
+        msg = "Unsigned integer types are not supported by PySpark"
+        raise UnsupportedDTypeError(msg)
+
+    msg = f"Unknown dtype: {dtype}"  # pragma: no cover
+    raise AssertionError(msg)
+
+
 def get_column_name(df: SparkLikeLazyFrame, column: Column) -> str:
     return str(df._native_frame.select(column).columns[0])
diff --git a/narwhals/exceptions.py b/narwhals/exceptions.py
index 18a225c8e..61447e54f 100644
--- a/narwhals/exceptions.py
+++ b/narwhals/exceptions.py
@@ -83,5 +83,9 @@ def from_expr_name(cls, expr_name: str) -> AnonymousExprError:
         return AnonymousExprError(message)


+class UnsupportedDTypeError(ValueError):
+    """Exception raised when trying to convert to a DType which is not supported by the given backend."""
+
+
 class NarwhalsUnstableWarning(UserWarning):
     """Warning issued when a method or function is considered unstable in the stable api."""

diff --git a/tests/conftest.py b/tests/conftest.py
index 60dec8815..95b969e95 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -152,6 +152,8 @@ def pyspark_lazy_constructor() -> Callable[[Any], IntoFrame]:  # pragma: no cover
         # executing one task at a time makes the tests faster
         .config("spark.default.parallelism", "1")
         .config("spark.sql.shuffle.partitions", "2")
+        # common timezone for all test environments
+        .config("spark.sql.session.timeZone", "UTC")
         .getOrCreate()
     )
diff --git a/tests/expr_and_series/cast_test.py b/tests/expr_and_series/cast_test.py
index ba2b82493..6dbda7901 100644
--- a/tests/expr_and_series/cast_test.py
+++ b/tests/expr_and_series/cast_test.py
@@ -17,7 +17,7 @@
 from tests.utils import assert_equal_data
 from tests.utils import is_windows

-data = {
+DATA = {
     "a": [1],
     "b": [1],
     "c": [1],
@@ -35,7 +35,7 @@
     "o": ["a"],
     "p": [1],
 }
-schema = {
+SCHEMA = {
     "a": nw.Int64,
     "b": nw.Int32,
     "c": nw.Int16,
@@ -54,13 +54,15 @@
     "p": nw.Int64,
 }

+SPARK_INCOMPATIBLE_COLUMNS = {"e", "f", "g", "h", "l", "o", "p"}
+

 @pytest.mark.filterwarnings("ignore:casting period[M] values to int64:FutureWarning")
 def test_cast(
     constructor: Constructor,
     request: pytest.FixtureRequest,
 ) -> None:
-    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
+    if "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
     if "pyarrow_table_constructor" in str(constructor) and PYARROW_VERSION <= (
         15,
     ):  # pragma: no cover
         request.applymarker(pytest.mark.xfail)
     if "modin_constructor" in str(constructor):
         # TODO(unassigned): in modin, we end up with `'<U1'`
         request.applymarker(pytest.mark.xfail)

[…]

@@ -…,… +…,… @@
 def test_cast_raises_for_unknown_dtype(
     constructor: Constructor, request: pytest.FixtureRequest
 ) -> None:
-    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
+    if "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
     if "pyarrow_table" in str(constructor) and PYARROW_VERSION < (15,):
         # Unsupported cast from string to dictionary using function cast_dictionary
         request.applymarker(pytest.mark.xfail)
+
+    if "pyspark" in str(constructor):
+        incompatible_columns = SPARK_INCOMPATIBLE_COLUMNS  # pragma: no cover
+    else:
+        incompatible_columns = set()
+
+    data = {k: v for k, v in DATA.items() if k not in incompatible_columns}
+    schema = {k: v for k, v in SCHEMA.items() if k not in incompatible_columns}
+
     df = nw.from_native(constructor(data)).select(
         nw.col(key).cast(value) for key, value in schema.items()
     )
@@ -204,7 +210,7 @@ def test_cast_datetime_tz_aware(
         or "duckdb" in str(constructor)
         or "cudf" in str(constructor)  # https://github.com/rapidsai/cudf/issues/16973
         or ("pyarrow_table" in str(constructor) and is_windows())
-        or ("pyspark" in str(constructor))
+        or "pyspark" in str(constructor)
     ):
         request.applymarker(pytest.mark.xfail)

diff --git a/tests/group_by_test.py b/tests/group_by_test.py
index 40acd142b..d446f8003 100644
--- a/tests/group_by_test.py
+++ b/tests/group_by_test.py
@@ -280,9 +280,6 @@ def test_key_with_nulls(
         # TODO(unassigned): Modin flaky here?
         request.applymarker(pytest.mark.skip)

-    if "pyspark" in str(constructor):
-        request.applymarker(pytest.mark.xfail)
-
     context = (
         pytest.raises(NotImplementedError, match="null values")
         if ("pandas_constructor" in str(constructor) and PANDAS_VERSION < (1, 1, 0))
@@ -307,9 +304,6 @@ def test_key_with_nulls_ignored(
     if "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)

-    if "pyspark" in str(constructor):
-        request.applymarker(pytest.mark.xfail)
-
     data = {"b": [4, 5, None], "a": [1, 2, 3]}
     result = (
         nw.from_native(constructor(data))
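
Usage sketch (not part of the patch): how the new `cast` surfaces through the public narwhals API on the PySpark backend. The session setup, data, and column names below are illustrative assumptions, not taken from the diff.

    from pyspark.sql import SparkSession

    import narwhals as nw

    spark = SparkSession.builder.getOrCreate()
    native_df = spark.createDataFrame([(1, "a"), (2, "b")], ["x", "y"])

    # narwhals wraps the PySpark DataFrame as a LazyFrame; `cast` lowers the
    # narwhals dtype to a PySpark type via `narwhals_to_native_dtype`.
    lf = nw.from_native(native_df)
    result = lf.with_columns(nw.col("x").cast(nw.Int32)).to_native()
    result.printSchema()  # `x` becomes IntegerType, `y` stays StringType

    # Per this patch: Date/Datetime targets raise NotImplementedError for now,
    # and unsigned integer targets raise UnsupportedDTypeError, since PySpark
    # has no unsigned integer types.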