feat: pyspark group by n_unique and no aggregation (#1819)
FBruzzesi authored Jan 19, 2025
1 parent 7e25d00 commit 0ec3c90
Showing 2 changed files with 22 additions and 17 deletions.
25 changes: 19 additions & 6 deletions narwhals/_spark_like/group_by.py
@@ -18,18 +18,16 @@
if TYPE_CHECKING:
from pyspark.sql import Column
from pyspark.sql import GroupedData
from typing_extensions import Self

from narwhals._spark_like.dataframe import SparkLikeLazyFrame
from narwhals._spark_like.typing import IntoSparkLikeExpr
from narwhals.typing import CompliantExpr


POLARS_TO_PYSPARK_AGGREGATIONS = {"len": "count"}


class SparkLikeLazyGroupBy:
def __init__(
self,
self: Self,
df: SparkLikeLazyFrame,
keys: list[str],
drop_null_keys: bool, # noqa: FBT001
@@ -44,7 +42,7 @@ def __init__(
self._grouped = self._df._native_frame.groupBy(*self._keys)

def agg(
self,
self: Self,
*aggs: IntoSparkLikeExpr,
**named_aggs: IntoSparkLikeExpr,
) -> SparkLikeLazyFrame:
@@ -62,13 +60,14 @@ def agg(
output_names.extend(expr._output_names)

return agg_pyspark(
self._df,
self._grouped,
exprs,
self._keys,
self._from_native_frame,
)

def _from_native_frame(self, df: SparkLikeLazyFrame) -> SparkLikeLazyFrame:
def _from_native_frame(self: Self, df: SparkLikeLazyFrame) -> SparkLikeLazyFrame:
from narwhals._spark_like.dataframe import SparkLikeLazyFrame

return SparkLikeLazyFrame(
@@ -87,23 +86,37 @@ def get_spark_function(function_name: str, **kwargs: Any) -> Column:
ddof=kwargs["ddof"],
np_version=parse_version(np.__version__),
)

elif function_name == "len":
# Use count(*) to count all rows including nulls
def _count(*_args: Any, **_kwargs: Any) -> Column:
return F.count("*")

return _count

elif function_name == "n_unique":
from pyspark.sql.types import IntegerType

def _n_unique(_input: Column) -> Column:
return F.count_distinct(_input) + F.max(F.isnull(_input).cast(IntegerType()))

return _n_unique

else:
return getattr(F, function_name)


def agg_pyspark(
df: SparkLikeLazyFrame,
grouped: GroupedData,
exprs: Sequence[CompliantExpr[Column]],
keys: list[str],
from_dataframe: Callable[[Any], SparkLikeLazyFrame],
) -> SparkLikeLazyFrame:
if not exprs:
# No aggregation provided
return from_dataframe(df._native_frame.select(*keys).dropDuplicates(subset=keys))

for expr in exprs:
if not is_simple_aggregation(expr): # pragma: no cover
msg = (
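The "len" and "n_unique" branches above exist because Spark's aggregate functions skip nulls: F.count(col) ignores null rows and F.count_distinct(col) ignores null values, whereas Polars' len counts every row and n_unique treats null as one more distinct value. The extra F.max(F.isnull(...).cast(IntegerType())) term contributes 1 whenever the group contains at least one null. A minimal standalone sketch of that behaviour (the SparkSession setup, data and column names are illustrative assumptions, not part of the commit):

# Sketch of the null handling behind the "len" and "n_unique" branches.
# Session, data and column names are illustrative assumptions.
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, 4), (1, None), (2, 5)], schema="a INT, b INT")

result = df.groupBy("a").agg(
    # count("*") counts every row, nulls included (Polars-style `len`).
    F.count("*").alias("len"),
    # count_distinct skips nulls, so add 1 when the group has any null,
    # matching Polars' convention that null counts as a distinct value.
    (F.count_distinct("b") + F.max(F.isnull("b").cast(IntegerType()))).alias("b_n_unique"),
)
result.show()
# a=1 -> len=2, b_n_unique=2 (values 4 and null); a=2 -> len=1, b_n_unique=1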
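The other change in agg_pyspark is the early return when no aggregations are passed: group_by(...).agg() with an empty expression list now reduces to selecting the key columns and dropping duplicate key combinations. Roughly, in plain PySpark (again with an illustrative session and data):

# Sketch of the no-aggregation path: only distinct key combinations survive.
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [(1, 4, "x"), (1, 4, "y"), (3, 6, "z")], schema="a INT, b INT, c STRING"
)
keys = ["a", "b"]
df.select(*keys).dropDuplicates(subset=keys).show()
# Two rows remain: (1, 4) and (3, 6).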
14 changes: 3 additions & 11 deletions tests/group_by_test.py
@@ -118,8 +118,6 @@ def test_group_by_depth_1_agg(
expected: dict[str, list[int | float]],
request: pytest.FixtureRequest,
) -> None:
if "pyspark" in str(constructor) and attr == "n_unique":
request.applymarker(pytest.mark.xfail)
if "pandas_pyarrow" in str(constructor) and attr == "var" and PANDAS_VERSION < (2, 1):
# Known issue with variance calculation in pandas 2.0.x with pyarrow backend in groupby operations
request.applymarker(pytest.mark.xfail)
@@ -169,11 +167,7 @@ def test_group_by_median(constructor: Constructor) -> None:
assert_equal_data(result, expected)


def test_group_by_n_unique_w_missing(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_group_by_n_unique_w_missing(constructor: Constructor) -> None:
data = {"a": [1, 1, 2], "b": [4, None, 5], "c": [None, None, 7], "d": [1, 1, 3]}
result = (
nw.from_native(constructor(data))
@@ -343,9 +337,7 @@ def test_key_with_nulls_iter(
assert len(result) == 4


def test_no_agg(request: pytest.FixtureRequest, constructor: Constructor) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_no_agg(constructor: Constructor) -> None:
result = nw.from_native(constructor(data)).group_by(["a", "b"]).agg().sort("a", "b")

expected = {"a": [1, 3], "b": [4, 6]}
@@ -425,7 +417,7 @@ def test_all_kind_of_aggs(
# and modin lol https://github.com/modin-project/modin/issues/7414
# and cudf https://github.com/rapidsai/cudf/issues/17649
request.applymarker(pytest.mark.xfail)
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
if "pandas" in str(constructor) and PANDAS_VERSION < (1, 4):
# Bug in old pandas, can't do DataFrameGroupBy[['b', 'b']]
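At the narwhals level, the two tests un-xfailed above now run against the PySpark backend. A minimal usage sketch mirroring test_group_by_n_unique_w_missing and test_no_agg (the SparkSession setup is an assumption, and any other supported backend behaves the same):

# Illustrative sketch of the two behaviours this commit enables for PySpark.
# The SparkSession setup is an assumption; the data mirrors the updated tests.
import narwhals as nw
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
native = spark.createDataFrame(
    [(1, 4, None, 1), (1, None, None, 1), (2, 5, 7, 3)],
    schema="a INT, b INT, c INT, d INT",
)
lf = nw.from_native(native)

# n_unique counts null as a distinct value, matching Polars:
# group a=1 has b values {4, null} -> 2, c values {null} -> 1, d values {1} -> 1.
n_unique = lf.group_by("a").agg(nw.col("b", "c", "d").n_unique()).sort("a")
n_unique.to_native().show()

# group_by(...).agg() with no aggregations returns the distinct key combinations.
no_agg = lf.group_by(["a", "d"]).agg().sort("a", "d")
no_agg.to_native().show()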
