feat: add SparkExpr.cast for basic types (#1812)
* add narwhals_to_native_dtype

* add cast

* readability

* pragma no cover

* fix import in utils

* set timezone for spark session to make tests reproducible

* Update narwhals/_spark_like/utils.py

Co-authored-by: Francesco Bruzzesi <[email protected]>

* bye bye datetime

* add UnsupportedDTypeError

---------

Co-authored-by: Francesco Bruzzesi <[email protected]>
Co-authored-by: FBruzzesi <[email protected]>
3 people authored Jan 17, 2025
1 parent 0bcb500 commit 70bc5c6
Showing 7 changed files with 114 additions and 46 deletions.
11 changes: 11 additions & 0 deletions narwhals/_spark_like/expr.py
@@ -11,6 +11,7 @@
 from narwhals._spark_like.expr_str import SparkLikeExprStringNamespace
 from narwhals._spark_like.utils import get_column_name
 from narwhals._spark_like.utils import maybe_evaluate
+from narwhals._spark_like.utils import narwhals_to_native_dtype
 from narwhals.typing import CompliantExpr
 from narwhals.utils import Implementation
 from narwhals.utils import parse_version
@@ -21,6 +22,7 @@
 
     from narwhals._spark_like.dataframe import SparkLikeLazyFrame
     from narwhals._spark_like.namespace import SparkLikeNamespace
+    from narwhals.dtypes import DType
     from narwhals.utils import Version
 
 
@@ -285,6 +287,15 @@ def any(self) -> Self:
 
         return self._from_call(F.bool_or, "any", returns_scalar=True)
 
+    def cast(self: Self, dtype: DType | type[DType]) -> Self:
+        def _cast(_input: Column, dtype: DType | type[DType]) -> Column:
+            spark_dtype = narwhals_to_native_dtype(dtype, self._version)
+            return _input.cast(spark_dtype)
+
+        return self._from_call(
+            _cast, "cast", dtype=dtype, returns_scalar=self._returns_scalar
+        )
+
     def count(self) -> Self:
         from pyspark.sql import functions as F  # noqa: N812
 
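For orientation, here is a minimal sketch (not part of the diff) of how the new method is reached through narwhals' public API. It assumes a local Spark session and that the PySpark backend is available as of this commit; names like native_df are illustrative only:

import narwhals as nw
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
native_df = spark.createDataFrame([(1,), (2,)], ["a"])

# from_native wraps the PySpark DataFrame in a narwhals LazyFrame,
# whose expressions are backed by SparkLikeExpr
lf = nw.from_native(native_df)
result = lf.select(nw.col("a").cast(nw.Float32))
print(result.to_native().schema)  # the "a" column should now be FloatType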
22 changes: 11 additions & 11 deletions narwhals/_spark_like/namespace.py
@@ -36,7 +36,7 @@ def _all(df: SparkLikeLazyFrame) -> list[Column]:
 
             return [F.col(col_name) for col_name in df.columns]
 
-        return SparkLikeExpr(  # type: ignore[abstract]
+        return SparkLikeExpr(
             call=_all,
             depth=0,
             function_name="all",
@@ -63,7 +63,7 @@ def _lit(_: SparkLikeLazyFrame) -> list[Column]:
 
             return [F.lit(value).alias("literal")]
 
-        return SparkLikeExpr(  # type: ignore[abstract]
+        return SparkLikeExpr(
             call=_lit,
             depth=0,
             function_name="lit",
@@ -81,7 +81,7 @@ def func(_: SparkLikeLazyFrame) -> list[Column]:
 
             return [F.count("*").alias("len")]
 
-        return SparkLikeExpr(  # type: ignore[abstract]
+        return SparkLikeExpr(
             func,
             depth=0,
             function_name="len",
@@ -101,7 +101,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]:
             col_name = get_column_name(df, cols[0])
             return [reduce(operator.and_, cols).alias(col_name)]
 
-        return SparkLikeExpr(  # type: ignore[abstract]
+        return SparkLikeExpr(
             call=func,
             depth=max(x._depth for x in parsed_exprs) + 1,
             function_name="all_horizontal",
@@ -121,7 +121,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]:
             col_name = get_column_name(df, cols[0])
             return [reduce(operator.or_, cols).alias(col_name)]
 
-        return SparkLikeExpr(  # type: ignore[abstract]
+        return SparkLikeExpr(
             call=func,
             depth=max(x._depth for x in parsed_exprs) + 1,
             function_name="any_horizontal",
@@ -148,7 +148,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]:
                 ).alias(col_name)
             ]
 
-        return SparkLikeExpr(  # type: ignore[abstract]
+        return SparkLikeExpr(
             call=func,
             depth=max(x._depth for x in parsed_exprs) + 1,
             function_name="sum_horizontal",
@@ -179,7 +179,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]:
                 ).alias(col_name)
             ]
 
-        return SparkLikeExpr(  # type: ignore[abstract]
+        return SparkLikeExpr(
             call=func,
             depth=max(x._depth for x in parsed_exprs) + 1,
             function_name="mean_horizontal",
@@ -201,7 +201,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]:
             col_name = get_column_name(df, cols[0])
             return [F.greatest(*cols).alias(col_name)]
 
-        return SparkLikeExpr(  # type: ignore[abstract]
+        return SparkLikeExpr(
             call=func,
             depth=max(x._depth for x in parsed_exprs) + 1,
             function_name="max_horizontal",
@@ -223,7 +223,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]:
             col_name = get_column_name(df, cols[0])
             return [F.least(*cols).alias(col_name)]
 
-        return SparkLikeExpr(  # type: ignore[abstract]
+        return SparkLikeExpr(
             call=func,
             depth=max(x._depth for x in parsed_exprs) + 1,
             function_name="min_horizontal",
@@ -320,7 +320,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]:
 
             return [result]
 
-        return SparkLikeExpr(  # type: ignore[abstract]
+        return SparkLikeExpr(
             call=func,
             depth=max(x._depth for x in parsed_exprs) + 1,
             function_name="concat_str",
@@ -392,7 +392,7 @@ def __call__(self, df: SparkLikeLazyFrame) -> list[Column]:
     def then(self, value: SparkLikeExpr | Any) -> SparkLikeThen:
         self._then_value = value
 
-        return SparkLikeThen(  # type: ignore[abstract]
+        return SparkLikeThen(
             self,
             depth=0,
             function_name="whenthen",
53 changes: 52 additions & 1 deletion narwhals/_spark_like/utils.py
@@ -5,7 +5,9 @@
 from typing import Any
 
 from narwhals.exceptions import InvalidIntoExprError
+from narwhals.exceptions import UnsupportedDTypeError
 from narwhals.utils import import_dtypes_module
+from narwhals.utils import isinstance_or_issubclass
 
 if TYPE_CHECKING:
     from pyspark.sql import Column
@@ -22,9 +24,10 @@ def native_to_narwhals_dtype(
     dtype: pyspark_types.DataType,
     version: Version,
 ) -> DType:  # pragma: no cover
-    dtypes = import_dtypes_module(version=version)
     from pyspark.sql import types as pyspark_types
 
+    dtypes = import_dtypes_module(version=version)
+
     if isinstance(dtype, pyspark_types.DoubleType):
         return dtypes.Float64()
     if isinstance(dtype, pyspark_types.FloatType):
@@ -35,6 +38,8 @@
         return dtypes.Int32()
     if isinstance(dtype, pyspark_types.ShortType):
         return dtypes.Int16()
+    if isinstance(dtype, pyspark_types.ByteType):
+        return dtypes.Int8()
     string_types = [
         pyspark_types.StringType,
         pyspark_types.VarcharType,
@@ -58,6 +63,52 @@
     return dtypes.Unknown()
 
 
+def narwhals_to_native_dtype(
+    dtype: DType | type[DType], version: Version
+) -> pyspark_types.DataType:
+    from pyspark.sql import types as pyspark_types
+
+    dtypes = import_dtypes_module(version)
+
+    if isinstance_or_issubclass(dtype, dtypes.Float64):
+        return pyspark_types.DoubleType()
+    if isinstance_or_issubclass(dtype, dtypes.Float32):
+        return pyspark_types.FloatType()
+    if isinstance_or_issubclass(dtype, dtypes.Int64):
+        return pyspark_types.LongType()
+    if isinstance_or_issubclass(dtype, dtypes.Int32):
+        return pyspark_types.IntegerType()
+    if isinstance_or_issubclass(dtype, dtypes.Int16):
+        return pyspark_types.ShortType()
+    if isinstance_or_issubclass(dtype, dtypes.Int8):
+        return pyspark_types.ByteType()
+    if isinstance_or_issubclass(dtype, dtypes.String):
+        return pyspark_types.StringType()
+    if isinstance_or_issubclass(dtype, dtypes.Boolean):
+        return pyspark_types.BooleanType()
+    if any(isinstance_or_issubclass(dtype, t) for t in [dtypes.Date, dtypes.Datetime]):
+        msg = "Converting to Date or Datetime dtype is not supported yet"
+        raise NotImplementedError(msg)
+    if isinstance_or_issubclass(dtype, dtypes.List):  # pragma: no cover
+        msg = "Converting to List dtype is not supported yet"
+        raise NotImplementedError(msg)
+    if isinstance_or_issubclass(dtype, dtypes.Struct):  # pragma: no cover
+        msg = "Converting to Struct dtype is not supported yet"
+        raise NotImplementedError(msg)
+    if isinstance_or_issubclass(dtype, dtypes.Array):  # pragma: no cover
+        msg = "Converting to Array dtype is not supported yet"
+        raise NotImplementedError(msg)
+    if any(
+        isinstance_or_issubclass(dtype, t)
+        for t in [dtypes.UInt64, dtypes.UInt32, dtypes.UInt16, dtypes.UInt8]
+    ):  # pragma: no cover
+        msg = "Unsigned integer types are not supported by PySpark"
+        raise UnsupportedDTypeError(msg)
+
+    msg = f"Unknown dtype: {dtype}"  # pragma: no cover
+    raise AssertionError(msg)
+
+
 def get_column_name(df: SparkLikeLazyFrame, column: Column) -> str:
     return str(df._native_frame.select(column).columns[0])
 
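A quick illustration of the helper's contract (a sketch, not from the commit; it assumes narwhals' internal Version enum, and that calling the private helper directly mirrors what SparkLikeExpr.cast does above):

import narwhals as nw
from narwhals._spark_like.utils import narwhals_to_native_dtype
from narwhals.utils import Version

narwhals_to_native_dtype(nw.Int64, Version.MAIN)     # LongType()
narwhals_to_native_dtype(nw.Boolean, Version.MAIN)   # BooleanType()
narwhals_to_native_dtype(nw.Datetime, Version.MAIN)  # raises NotImplementedError
narwhals_to_native_dtype(nw.UInt32, Version.MAIN)    # raises UnsupportedDTypeError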
4 changes: 4 additions & 0 deletions narwhals/exceptions.py
@@ -83,5 +83,9 @@ def from_expr_name(cls, expr_name: str) -> AnonymousExprError:
         return AnonymousExprError(message)
 
 
+class UnsupportedDTypeError(ValueError):
+    """Exception raised when trying to convert to a DType which is not supported by the given backend."""
+
+
 class NarwhalsUnstableWarning(UserWarning):
     """Warning issued when a method or function is considered unstable in the stable api."""
2 changes: 2 additions & 0 deletions tests/conftest.py
@@ -152,6 +152,8 @@ def pyspark_lazy_constructor() -> Callable[[Any], IntoFrame]:  # pragma: no cover
         # executing one task at a time makes the tests faster
         .config("spark.default.parallelism", "1")
         .config("spark.sql.shuffle.partitions", "2")
+        # common timezone for all test environments
+        .config("spark.sql.session.timeZone", "UTC")
         .getOrCreate()
     )
 
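Context for the new config line: Spark formats timestamps in the session-local timezone, so assertions on datetime output would otherwise vary with the machine running the tests. A standalone sketch of the effect (hypothetical, not from the commit):

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(0,)], ["epoch"])

spark.conf.set("spark.sql.session.timeZone", "UTC")
df.select(F.from_unixtime("epoch")).show()  # 1970-01-01 00:00:00

spark.conf.set("spark.sql.session.timeZone", "US/Eastern")
df.select(F.from_unixtime("epoch")).show()  # 1969-12-31 19:00:00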
62 changes: 34 additions & 28 deletions tests/expr_and_series/cast_test.py
@@ -17,7 +17,7 @@
 from tests.utils import assert_equal_data
 from tests.utils import is_windows
 
-data = {
+DATA = {
     "a": [1],
     "b": [1],
     "c": [1],
@@ -35,7 +35,7 @@
     "o": ["a"],
     "p": [1],
 }
-schema = {
+SCHEMA = {
     "a": nw.Int64,
     "b": nw.Int32,
     "c": nw.Int16,
@@ -54,13 +54,15 @@
     "p": nw.Int64,
 }
 
+SPARK_INCOMPATIBLE_COLUMNS = {"e", "f", "g", "h", "l", "o", "p"}
+
 
 @pytest.mark.filterwarnings("ignore:casting period[M] values to int64:FutureWarning")
 def test_cast(
     constructor: Constructor,
     request: pytest.FixtureRequest,
 ) -> None:
-    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
+    if "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
     if "pyarrow_table_constructor" in str(constructor) and PYARROW_VERSION <= (
         15,
@@ -69,28 +71,20 @@ def test_cast(
     if "modin_constructor" in str(constructor):
         # TODO(unassigned): in modin, we end up with `'<U0'` dtype
         request.applymarker(pytest.mark.xfail)
+
+    if "pyspark" in str(constructor):
+        incompatible_columns = SPARK_INCOMPATIBLE_COLUMNS  # pragma: no cover
+    else:
+        incompatible_columns = set()
+
+    data = {c: v for c, v in DATA.items() if c not in incompatible_columns}
+    schema = {c: t for c, t in SCHEMA.items() if c not in incompatible_columns}
+
     df = nw.from_native(constructor(data)).select(
-        nw.col(key).cast(value) for key, value in schema.items()
-    )
-    result = df.select(
-        nw.col("a").cast(nw.Int32),
-        nw.col("b").cast(nw.Int16),
-        nw.col("c").cast(nw.Int8),
-        nw.col("d").cast(nw.Int64),
-        nw.col("e").cast(nw.UInt32),
-        nw.col("f").cast(nw.UInt16),
-        nw.col("g").cast(nw.UInt8),
-        nw.col("h").cast(nw.UInt64),
-        nw.col("i").cast(nw.Float32),
-        nw.col("j").cast(nw.Float64),
-        nw.col("k").cast(nw.String),
-        nw.col("l").cast(nw.Datetime),
-        nw.col("m").cast(nw.Int8),
-        nw.col("n").cast(nw.Int8),
-        nw.col("o").cast(nw.String),
-        nw.col("p").cast(nw.Duration),
+        nw.col(col_).cast(dtype) for col_, dtype in schema.items()
     )
-    expected = {
+
+    cast_map = {
         "a": nw.Int32,
         "b": nw.Int16,
         "c": nw.Int8,
@@ -108,7 +102,10 @@ def test_cast(
         "o": nw.String,
         "p": nw.Duration,
     }
-    assert dict(result.collect_schema()) == expected
+    cast_map = {c: t for c, t in cast_map.items() if c not in incompatible_columns}
+
+    result = df.select(*[nw.col(col_).cast(dtype) for col_, dtype in cast_map.items()])
+    assert dict(result.collect_schema()) == cast_map
 
 
 def test_cast_series(
@@ -123,8 +120,8 @@ def test_cast_series(
         # TODO(unassigned): in modin, we end up with `'<U0'` dtype
         request.applymarker(pytest.mark.xfail)
     df = (
-        nw.from_native(constructor_eager(data))
-        .select(nw.col(key).cast(value) for key, value in schema.items())
+        nw.from_native(constructor_eager(DATA))
+        .select(nw.col(key).cast(value) for key, value in SCHEMA.items())
         .lazy()
         .collect()
     )
@@ -180,11 +177,20 @@ def test_cast_string() -> None:
 def test_cast_raises_for_unknown_dtype(
     constructor: Constructor, request: pytest.FixtureRequest
 ) -> None:
-    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
+    if "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
     if "pyarrow_table" in str(constructor) and PYARROW_VERSION < (15,):
         # Unsupported cast from string to dictionary using function cast_dictionary
         request.applymarker(pytest.mark.xfail)
+
+    if "pyspark" in str(constructor):
+        incompatible_columns = SPARK_INCOMPATIBLE_COLUMNS  # pragma: no cover
+    else:
+        incompatible_columns = set()
+
+    data = {k: v for k, v in DATA.items() if k not in incompatible_columns}
+    schema = {k: v for k, v in SCHEMA.items() if k not in incompatible_columns}
+
     df = nw.from_native(constructor(data)).select(
         nw.col(key).cast(value) for key, value in schema.items()
     )
@@ -204,7 +210,7 @@ def test_cast_datetime_tz_aware(
         or "duckdb" in str(constructor)
         or "cudf" in str(constructor)  # https://github.com/rapidsai/cudf/issues/16973
         or ("pyarrow_table" in str(constructor) and is_windows())
-        or ("pyspark" in str(constructor))
+        or "pyspark" in str(constructor)
     ):
         request.applymarker(pytest.mark.xfail)
 
6 changes: 0 additions & 6 deletions tests/group_by_test.py
@@ -280,9 +280,6 @@ def test_key_with_nulls(
         # TODO(unassigned): Modin flaky here?
         request.applymarker(pytest.mark.skip)
 
-    if "pyspark" in str(constructor):
-        request.applymarker(pytest.mark.xfail)
-
     context = (
         pytest.raises(NotImplementedError, match="null values")
         if ("pandas_constructor" in str(constructor) and PANDAS_VERSION < (1, 1, 0))
@@ -307,9 +304,6 @@ def test_key_with_nulls_ignored(
     if "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
 
-    if "pyspark" in str(constructor):
-        request.applymarker(pytest.mark.xfail)
-
     data = {"b": [4, 5, None], "a": [1, 2, 3]}
     result = (
         nw.from_native(constructor(data))
