feat: pyspark group by n_unique and no aggregation #1819

Merged · 3 commits · Jan 19, 2025

narwhals/_spark_like/group_by.py (25 changes: 19 additions & 6 deletions)

@@ -18,18 +18,16 @@
 if TYPE_CHECKING:
     from pyspark.sql import Column
     from pyspark.sql import GroupedData
+    from typing_extensions import Self

     from narwhals._spark_like.dataframe import SparkLikeLazyFrame
     from narwhals._spark_like.typing import IntoSparkLikeExpr
     from narwhals.typing import CompliantExpr


-POLARS_TO_PYSPARK_AGGREGATIONS = {"len": "count"}

Member Author commented on this removal: "It was unused"

-
-
 class SparkLikeLazyGroupBy:
     def __init__(
-        self,
+        self: Self,
         df: SparkLikeLazyFrame,
         keys: list[str],
         drop_null_keys: bool,  # noqa: FBT001
@@ -44,7 +42,7 @@ def __init__(
        self._grouped = self._df._native_frame.groupBy(*self._keys)

     def agg(
-        self,
+        self: Self,
         *aggs: IntoSparkLikeExpr,
         **named_aggs: IntoSparkLikeExpr,
     ) -> SparkLikeLazyFrame:
@@ -62,13 +60,14 @@ def agg(
             output_names.extend(expr._output_names)

         return agg_pyspark(
+            self._df,
             self._grouped,
             exprs,
             self._keys,
             self._from_native_frame,
         )

-    def _from_native_frame(self, df: SparkLikeLazyFrame) -> SparkLikeLazyFrame:
+    def _from_native_frame(self: Self, df: SparkLikeLazyFrame) -> SparkLikeLazyFrame:
         from narwhals._spark_like.dataframe import SparkLikeLazyFrame

         return SparkLikeLazyFrame(
@@ -87,23 +86,37 @@ def get_spark_function(function_name: str, **kwargs: Any) -> Column:
             ddof=kwargs["ddof"],
             np_version=parse_version(np.__version__),
         )
+
     elif function_name == "len":
         # Use count(*) to count all rows including nulls
         def _count(*_args: Any, **_kwargs: Any) -> Column:
             return F.count("*")

         return _count

+    elif function_name == "n_unique":
+        from pyspark.sql.types import IntegerType
+
+        def _n_unique(_input: Column) -> Column:
+            return F.count_distinct(_input) + F.max(F.isnull(_input).cast(IntegerType()))
+
+        return _n_unique
+
     else:
         return getattr(F, function_name)


 def agg_pyspark(
+    df: SparkLikeLazyFrame,
     grouped: GroupedData,
     exprs: Sequence[CompliantExpr[Column]],
     keys: list[str],
     from_dataframe: Callable[[Any], SparkLikeLazyFrame],
 ) -> SparkLikeLazyFrame:
+    if not exprs:
+        # No aggregation provided
+        return from_dataframe(df._native_frame.select(*keys).dropDuplicates(subset=keys))
+
     for expr in exprs:
         if not is_simple_aggregation(expr):  # pragma: no cover
             msg = (
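
For illustration, here is a minimal standalone PySpark sketch (not part of the diff) of the n_unique trick added above: F.count_distinct counts distinct non-null values, and F.max(F.isnull(...).cast(IntegerType())) adds 1 when the group contains at least one null, following Polars' convention that null counts as a distinct value. The session, column names, and sample data are made up for the example.

from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType

spark = SparkSession.builder.getOrCreate()

# Hypothetical sample data: group 1 holds {4, None}, group 2 holds {5}.
df = spark.createDataFrame([(1, 4), (1, None), (2, 5)], schema=["key", "val"])

# Same expression as _n_unique in the diff, applied to the "val" column.
n_unique = F.count_distinct("val") + F.max(F.isnull("val").cast(IntegerType()))

df.groupBy("key").agg(n_unique.alias("val_n_unique")).show()
# key=1 -> 2 (one distinct non-null value, plus one for the null)
# key=2 -> 1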

tests/group_by_test.py (14 changes: 3 additions & 11 deletions)
@@ -119,8 +119,6 @@ def test_group_by_depth_1_agg(
     expected: dict[str, list[int | float]],
     request: pytest.FixtureRequest,
 ) -> None:
-    if "pyspark" in str(constructor) and attr == "n_unique":
-        request.applymarker(pytest.mark.xfail)
     if "pandas_pyarrow" in str(constructor) and attr == "var" and PANDAS_VERSION < (2, 1):
         # Known issue with variance calculation in pandas 2.0.x with pyarrow backend in groupby operations"
         request.applymarker(pytest.mark.xfail)
@@ -170,11 +168,7 @@ def test_group_by_median(constructor: Constructor) -> None:
     assert_equal_data(result, expected)


-def test_group_by_n_unique_w_missing(
-    constructor: Constructor, request: pytest.FixtureRequest
-) -> None:
-    if "pyspark" in str(constructor):
-        request.applymarker(pytest.mark.xfail)
+def test_group_by_n_unique_w_missing(constructor: Constructor) -> None:
     data = {"a": [1, 1, 2], "b": [4, None, 5], "c": [None, None, 7], "d": [1, 1, 3]}
     result = (
         nw.from_native(constructor(data))
@@ -344,9 +338,7 @@ def test_key_with_nulls_iter(
     assert len(result) == 4


-def test_no_agg(request: pytest.FixtureRequest, constructor: Constructor) -> None:
-    if "pyspark" in str(constructor):
-        request.applymarker(pytest.mark.xfail)
+def test_no_agg(constructor: Constructor) -> None:
     result = nw.from_native(constructor(data)).group_by(["a", "b"]).agg().sort("a", "b")

     expected = {"a": [1, 3], "b": [4, 6]}
@@ -426,7 +418,7 @@ def test_all_kind_of_aggs(
         # and modin lol https://github.com/modin-project/modin/issues/7414
         # and cudf https://github.com/rapidsai/cudf/issues/17649
         request.applymarker(pytest.mark.xfail)
-    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
+    if "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
     if "pandas" in str(constructor) and PANDAS_VERSION < (1, 4):
         # Bug in old pandas, can't do DataFrameGroupBy[['b', 'b']]
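
Likewise, a minimal sketch (assumed sample data, not the fixture used in the test file) of the new no-aggregation path: when group_by(...).agg() receives no expressions, the PySpark backend now just selects the key columns and drops duplicate key combinations, which is what test_no_agg exercises.

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Hypothetical data; the distinct keys mirror test_no_agg's expected
# output {"a": [1, 3], "b": [4, 6]}.
df = spark.createDataFrame(
    [(1, 4, 7.0), (1, 4, 8.0), (3, 6, 9.0)], schema=["a", "b", "z"]
)

keys = ["a", "b"]
# Equivalent of the new branch in agg_pyspark:
#     df._native_frame.select(*keys).dropDuplicates(subset=keys)
df.select(*keys).dropDuplicates(subset=keys).sort(*keys).show()
# Rows: (1, 4) and (3, 6)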