feat: add SparkExpr.cast for basic types #1812

Merged
12 commits merged on Jan 17, 2025
11 changes: 11 additions & 0 deletions narwhals/_spark_like/expr.py
@@ -10,6 +10,7 @@
from narwhals._spark_like.expr_str import SparkLikeExprStringNamespace
from narwhals._spark_like.utils import get_column_name
from narwhals._spark_like.utils import maybe_evaluate
from narwhals._spark_like.utils import narwhals_to_native_dtype
from narwhals.typing import CompliantExpr
from narwhals.utils import Implementation
from narwhals.utils import parse_version
@@ -20,6 +21,7 @@

from narwhals._spark_like.dataframe import SparkLikeLazyFrame
from narwhals._spark_like.namespace import SparkLikeNamespace
from narwhals.dtypes import DType
from narwhals.utils import Version


@@ -284,6 +286,15 @@ def any(self) -> Self:

return self._from_call(F.bool_or, "any", returns_scalar=True)

def cast(self: Self, dtype: DType | type[DType]) -> Self:
def _cast(_input: Column, dtype: DType | type[DType]) -> Column:
spark_dtype = narwhals_to_native_dtype(dtype, self._version)
return _input.cast(spark_dtype)

return self._from_call(
_cast, "cast", dtype=dtype, returns_scalar=self._returns_scalar
)

def count(self) -> Self:
from pyspark.sql import functions as F # noqa: N812

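For orientation, here is a minimal sketch of how the new `cast` method is exercised through the public narwhals API, mirroring the pattern used in the tests below. The SparkSession setup and the sample data are illustrative assumptions, not part of this diff.

```python
import narwhals as nw
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
native_df = spark.createDataFrame([(1, "2")], ["a", "b"])

# `cast` resolves the narwhals dtype to a pyspark type via
# narwhals_to_native_dtype (see utils.py below) and calls Column.cast.
result = nw.from_native(native_df).select(
    nw.col("a").cast(nw.Int8),
    nw.col("b").cast(nw.Int64),
)
print(dict(result.collect_schema()))  # {'a': Int8, 'b': Int64}
```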
22 changes: 11 additions & 11 deletions narwhals/_spark_like/namespace.py
@@ -36,7 +36,7 @@ def _all(df: SparkLikeLazyFrame) -> list[Column]:

return [F.col(col_name) for col_name in df.columns]

return SparkLikeExpr( # type: ignore[abstract]
return SparkLikeExpr(
call=_all,
depth=0,
function_name="all",
@@ -63,7 +63,7 @@ def _lit(_: SparkLikeLazyFrame) -> list[Column]:

return [F.lit(value).alias("literal")]

return SparkLikeExpr( # type: ignore[abstract]
return SparkLikeExpr(
call=_lit,
depth=0,
function_name="lit",
@@ -81,7 +81,7 @@ def func(_: SparkLikeLazyFrame) -> list[Column]:

return [F.count("*").alias("len")]

return SparkLikeExpr( # type: ignore[abstract]
return SparkLikeExpr(
func,
depth=0,
function_name="len",
@@ -101,7 +101,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]:
col_name = get_column_name(df, cols[0])
return [reduce(operator.and_, cols).alias(col_name)]

return SparkLikeExpr( # type: ignore[abstract]
return SparkLikeExpr(
call=func,
depth=max(x._depth for x in parsed_exprs) + 1,
function_name="all_horizontal",
@@ -121,7 +121,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]:
col_name = get_column_name(df, cols[0])
return [reduce(operator.or_, cols).alias(col_name)]

return SparkLikeExpr( # type: ignore[abstract]
return SparkLikeExpr(
call=func,
depth=max(x._depth for x in parsed_exprs) + 1,
function_name="any_horizontal",
@@ -148,7 +148,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]:
).alias(col_name)
]

return SparkLikeExpr( # type: ignore[abstract]
return SparkLikeExpr(
call=func,
depth=max(x._depth for x in parsed_exprs) + 1,
function_name="sum_horizontal",
@@ -179,7 +179,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]:
).alias(col_name)
]

return SparkLikeExpr( # type: ignore[abstract]
return SparkLikeExpr(
call=func,
depth=max(x._depth for x in parsed_exprs) + 1,
function_name="mean_horizontal",
@@ -201,7 +201,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]:
col_name = get_column_name(df, cols[0])
return [F.greatest(*cols).alias(col_name)]

return SparkLikeExpr( # type: ignore[abstract]
return SparkLikeExpr(
call=func,
depth=max(x._depth for x in parsed_exprs) + 1,
function_name="max_horizontal",
@@ -223,7 +223,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]:
col_name = get_column_name(df, cols[0])
return [F.least(*cols).alias(col_name)]

return SparkLikeExpr( # type: ignore[abstract]
return SparkLikeExpr(
call=func,
depth=max(x._depth for x in parsed_exprs) + 1,
function_name="min_horizontal",
@@ -320,7 +320,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]:

return [result]

return SparkLikeExpr( # type: ignore[abstract]
return SparkLikeExpr(
call=func,
depth=max(x._depth for x in parsed_exprs) + 1,
function_name="concat_str",
@@ -392,7 +392,7 @@ def __call__(self, df: SparkLikeLazyFrame) -> list[Column]:
def then(self, value: SparkLikeExpr | Any) -> SparkLikeThen:
self._then_value = value

return SparkLikeThen( # type: ignore[abstract]
return SparkLikeThen(
self,
depth=0,
function_name="whenthen",
42 changes: 41 additions & 1 deletion narwhals/_spark_like/utils.py
@@ -4,12 +4,14 @@
from typing import TYPE_CHECKING
from typing import Any

from pyspark.sql import types as pyspark_types

from narwhals.exceptions import InvalidIntoExprError
from narwhals.utils import import_dtypes_module
from narwhals.utils import isinstance_or_issubclass

if TYPE_CHECKING:
from pyspark.sql import Column
from pyspark.sql import types as pyspark_types

from narwhals._spark_like.dataframe import SparkLikeLazyFrame
from narwhals._spark_like.typing import IntoSparkLikeExpr
@@ -35,6 +37,8 @@ def native_to_narwhals_dtype(
return dtypes.Int32()
if isinstance(dtype, pyspark_types.ShortType):
return dtypes.Int16()
if isinstance(dtype, pyspark_types.ByteType):
return dtypes.Int8()
Comment on lines +41 to +42
@EdAbati (Collaborator, Author), Jan 15, 2025:

I remember this came up already, @MarcoGorelli. I found the source (https://spark.apache.org/docs/latest/sql-ref-datatypes.html): "ByteType is a 1-byte signed integer (within the range of -128 to 127)." πŸ˜… The same mapping is used in the spark.pandas code.
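A quick sanity check of that statement, using only pyspark (an illustrative snippet, not part of the diff):

```python
from pyspark.sql import types as pyspark_types

# ByteType is Spark's "tinyint": a 1-byte signed integer, matching the
# -128 to 127 range of narwhals' Int8.
assert pyspark_types.ByteType().simpleString() == "tinyint"
```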

string_types = [
pyspark_types.StringType,
pyspark_types.VarcharType,
@@ -58,6 +62,42 @@ def native_to_narwhals_dtype(
return dtypes.Unknown()


def narwhals_to_native_dtype(
dtype: DType | type[DType], version: Version
) -> pyspark_types.DataType:
dtypes = import_dtypes_module(version)
if isinstance_or_issubclass(dtype, dtypes.Float64):
return pyspark_types.DoubleType()
if isinstance_or_issubclass(dtype, dtypes.Float32):
return pyspark_types.FloatType()
if isinstance_or_issubclass(dtype, dtypes.Int64):
return pyspark_types.LongType()
if isinstance_or_issubclass(dtype, dtypes.Int32):
return pyspark_types.IntegerType()
if isinstance_or_issubclass(dtype, dtypes.Int16):
return pyspark_types.ShortType()
if isinstance_or_issubclass(dtype, dtypes.Int8):
return pyspark_types.ByteType()
if isinstance_or_issubclass(dtype, dtypes.String):
return pyspark_types.StringType()
if isinstance_or_issubclass(dtype, dtypes.Boolean):
return pyspark_types.BooleanType()
if isinstance_or_issubclass(dtype, dtypes.Datetime):
return pyspark_types.TimestampType()
if isinstance_or_issubclass(dtype, dtypes.List): # pragma: no cover
msg = "Converting to List dtype is not supported yet"
raise NotImplementedError(msg)
if isinstance_or_issubclass(dtype, dtypes.Struct): # pragma: no cover
msg = "Converting to Struct dtype is not supported yet"
raise NotImplementedError(msg)
if isinstance_or_issubclass(dtype, dtypes.Array): # pragma: no cover
msg = "Converting to Array dtype is not supported yet"
raise NotImplementedError(msg)

msg = f"Unknown dtype: {dtype}" # pragma: no cover
raise AssertionError(msg)


def get_column_name(df: SparkLikeLazyFrame, column: Column) -> str:
return str(df._native_frame.select(column).columns[0])

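For reference, a sketch of the new translation helper called in isolation. The `Version` argument is an assumption: the diff imports `Version` only for typing, so the enum member used below (`Version.MAIN`) is inferred from how the rest of the codebase calls this helper rather than stated in this PR.

```python
import narwhals as nw
from narwhals._spark_like.utils import narwhals_to_native_dtype
from narwhals.utils import Version

# SparkLikeExpr.cast passes self._version here; Version.MAIN is an assumed member.
print(narwhals_to_native_dtype(nw.Int8, Version.MAIN))      # ByteType()
print(narwhals_to_native_dtype(nw.Datetime, Version.MAIN))  # TimestampType()
print(narwhals_to_native_dtype(nw.Float32, Version.MAIN))   # FloatType()
```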
59 changes: 32 additions & 27 deletions tests/expr_and_series/cast_test.py
@@ -17,7 +17,7 @@
from tests.utils import assert_equal_data
from tests.utils import is_windows

data = {
DATA = {
"a": [1],
"b": [1],
"c": [1],
@@ -35,7 +35,7 @@
"o": ["a"],
"p": [1],
}
schema = {
SCHEMA = {
"a": nw.Int64,
"b": nw.Int32,
"c": nw.Int16,
@@ -54,13 +54,15 @@
"p": nw.Int64,
}

SPARK_INCOMPATIBLE_COLUMNS = {"e", "f", "g", "h", "o", "p"}


@pytest.mark.filterwarnings("ignore:casting period[M] values to int64:FutureWarning")
def test_cast(
constructor: Constructor,
request: pytest.FixtureRequest,
) -> None:
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
if "pyarrow_table_constructor" in str(constructor) and PYARROW_VERSION <= (
15,
@@ -69,28 +71,20 @@ def test_cast(
if "modin_constructor" in str(constructor):
# TODO(unassigned): in modin, we end up with `'<U0'` dtype
request.applymarker(pytest.mark.xfail)

if "pyspark" in str(constructor):
incompatible_columns = SPARK_INCOMPATIBLE_COLUMNS
else:
incompatible_columns = set()

data = {k: v for k, v in DATA.items() if k not in incompatible_columns}
schema = {k: v for k, v in SCHEMA.items() if k not in incompatible_columns}

df = nw.from_native(constructor(data)).select(
nw.col(key).cast(value) for key, value in schema.items()
)
result = df.select(
nw.col("a").cast(nw.Int32),
nw.col("b").cast(nw.Int16),
nw.col("c").cast(nw.Int8),
nw.col("d").cast(nw.Int64),
nw.col("e").cast(nw.UInt32),
nw.col("f").cast(nw.UInt16),
nw.col("g").cast(nw.UInt8),
nw.col("h").cast(nw.UInt64),
nw.col("i").cast(nw.Float32),
nw.col("j").cast(nw.Float64),
nw.col("k").cast(nw.String),
nw.col("l").cast(nw.Datetime),
nw.col("m").cast(nw.Int8),
nw.col("n").cast(nw.Int8),
nw.col("o").cast(nw.String),
nw.col("p").cast(nw.Duration),
)
expected = {

cast_map = {
"a": nw.Int32,
"b": nw.Int16,
"c": nw.Int8,
@@ -108,7 +102,10 @@
"o": nw.String,
"p": nw.Duration,
}
assert dict(result.collect_schema()) == expected
cast_map = {k: v for k, v in cast_map.items() if k not in incompatible_columns}

result = df.select(*[nw.col(key).cast(value) for key, value in cast_map.items()])
assert dict(result.collect_schema()) == cast_map


def test_cast_series(
@@ -123,8 +120,8 @@ def test_cast_series(
# TODO(unassigned): in modin, we end up with `'<U0'` dtype
request.applymarker(pytest.mark.xfail)
df = (
nw.from_native(constructor_eager(data))
.select(nw.col(key).cast(value) for key, value in schema.items())
nw.from_native(constructor_eager(DATA))
.select(nw.col(key).cast(value) for key, value in SCHEMA.items())
.lazy()
.collect()
)
@@ -180,11 +177,20 @@ def test_cast_string() -> None:
def test_cast_raises_for_unknown_dtype(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
if "pyarrow_table" in str(constructor) and PYARROW_VERSION < (15,):
# Unsupported cast from string to dictionary using function cast_dictionary
request.applymarker(pytest.mark.xfail)

if "pyspark" in str(constructor):
incompatible_columns = SPARK_INCOMPATIBLE_COLUMNS
else:
incompatible_columns = set()

data = {k: v for k, v in DATA.items() if k not in incompatible_columns}
schema = {k: v for k, v in SCHEMA.items() if k not in incompatible_columns}

df = nw.from_native(constructor(data)).select(
nw.col(key).cast(value) for key, value in schema.items()
)
@@ -204,7 +210,6 @@ def test_cast_datetime_tz_aware(
or "duckdb" in str(constructor)
or "cudf" in str(constructor) # https://github.com/rapidsai/cudf/issues/16973
or ("pyarrow_table" in str(constructor) and is_windows())
or ("pyspark" in str(constructor))
):
request.applymarker(pytest.mark.xfail)

6 changes: 0 additions & 6 deletions tests/group_by_test.py
@@ -276,9 +276,6 @@ def test_key_with_nulls(
# TODO(unassigned): Modin flaky here?
request.applymarker(pytest.mark.skip)

if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)

context = (
pytest.raises(NotImplementedError, match="null values")
if ("pandas_constructor" in str(constructor) and PANDAS_VERSION < (1, 1, 0))
@@ -303,9 +300,6 @@ def test_key_with_nulls_ignored(
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)

if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)

data = {"b": [4, 5, None], "a": [1, 2, 3]}
result = (
nw.from_native(constructor(data))