Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add LazyFrame.unpivot for spark and duckdb #1890

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 49 additions & 4 deletions narwhals/_duckdb/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,8 +338,6 @@ def collect_schema(self: Self) -> dict[str, DType]:

def unique(self: Self, subset: Sequence[str] | None, keep: str) -> Self:
if subset is not None:
import duckdb

rel = self._native_frame
# Sanitise input
if any(x not in rel.columns for x in subset):
Expand Down Expand Up @@ -388,10 +386,57 @@ def sort(
return self._from_native_frame(result)

def drop_nulls(self: Self, subset: list[str] | None) -> Self:
import duckdb

rel = self._native_frame
subset_ = subset if subset is not None else rel.columns
keep_condition = " and ".join(f'"{col}" is not null' for col in subset_)
query = f"select * from rel where {keep_condition}" # noqa: S608
return self._from_native_frame(duckdb.sql(query))

def unpivot(
self: Self,
on: str | list[str] | None,
index: str | list[str] | None,
variable_name: str | None,
value_name: str | None,
) -> Self:
on_ = [on] if isinstance(on, str) else on
index_ = (
[index]
if isinstance(index, str)
else index
if isinstance(index, list)
else []
)

if on_ is None:
on_ = [c for c in self.columns if c not in index_]

variable_name = variable_name if variable_name is not None else "variable"
value_name = value_name if value_name is not None else "value"

if variable_name == "":
msg = "`variable_name` cannot be empty string for duckdb backend."
raise NotImplementedError(msg)

if value_name == "":
msg = "`value_name` cannot be empty string for duckdb backend."
raise NotImplementedError(msg)
Comment on lines +417 to +423
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried, but it is not of duckdb liking to support these


cols_to_select = ", ".join(
f'"{col}"' for col in [*index_, variable_name, value_name]
)
unpivot_on = ", ".join(f'"{col}"' for col in on_)

rel = self._native_frame # noqa: F841
query = f"""
with unpivot_cte as (
unpivot rel
on {unpivot_on}
into
name {variable_name}
value {value_name}
)
select {cols_to_select}
from unpivot_cte;
""" # noqa: S608
return self._from_native_frame(duckdb.sql(query))
18 changes: 18 additions & 0 deletions narwhals/_spark_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,3 +312,21 @@ def join(
return self._from_native_frame(
self_native.join(other, on=left_on, how=how).select(col_order)
)

def unpivot(
self: Self,
on: str | list[str] | None,
index: str | list[str] | None,
variable_name: str | None,
value_name: str | None,
) -> Self:
return self._from_native_frame(
self._native_frame.unpivot(
ids=index,
values=on,
variableColumnName=variable_name
if variable_name is not None
else "variable",
valueColumnName=value_name if value_name is not None else "value",
)
)
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:
)
elif "constructor" in metafunc.fixturenames:
if (
any(x in str(metafunc.module) for x in ("unpivot", "from_dict", "from_numpy"))
any(x in str(metafunc.module) for x in ("from_dict", "from_numpy"))
and LAZY_CONSTRUCTORS["duckdb"] in constructors
):
constructors.remove(LAZY_CONSTRUCTORS["duckdb"])
Expand Down
35 changes: 14 additions & 21 deletions tests/frame/unpivot_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from contextlib import nullcontext as does_not_raise
from typing import TYPE_CHECKING
from typing import Any

Expand Down Expand Up @@ -37,14 +38,10 @@
[("b", expected_b_only), (["b", "c"], expected_b_c), (None, expected_b_c)],
)
def test_unpivot_on(
request: pytest.FixtureRequest,
constructor: Constructor,
on: str | list[str] | None,
expected: dict[str, list[float]],
) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)

df = nw.from_native(constructor(data))
result = df.unpivot(on=on, index=["a"]).sort("variable", "a")
assert_equal_data(result, expected)
Expand All @@ -59,28 +56,26 @@ def test_unpivot_on(
],
)
def test_unpivot_var_value_names(
request: pytest.FixtureRequest,
constructor: Constructor,
variable_name: str | None,
value_name: str | None,
) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)

df = nw.from_native(constructor(data))
result = df.unpivot(
on=["b", "c"], index=["a"], variable_name=variable_name, value_name=value_name
context = (
pytest.raises(NotImplementedError)
if ("duckdb" in str(constructor) and any([variable_name == "", value_name == ""]))
else does_not_raise()
)

assert result.collect_schema().names()[-2:] == [variable_name, value_name]
with context:
df = nw.from_native(constructor(data))
result = df.unpivot(
on=["b", "c"], index=["a"], variable_name=variable_name, value_name=value_name
)

assert result.collect_schema().names()[-2:] == [variable_name, value_name]

def test_unpivot_default_var_value_names(
request: pytest.FixtureRequest, constructor: Constructor
) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)

def test_unpivot_default_var_value_names(constructor: Constructor) -> None:
df = nw.from_native(constructor(data))
result = df.unpivot(on=["b", "c"], index=["a"])

Expand All @@ -102,10 +97,8 @@ def test_unpivot_mixed_types(
data: dict[str, Any],
expected_dtypes: list[DType],
) -> None:
if (
"cudf" in str(constructor)
or "pyspark" in str(constructor)
or ("pyarrow_table" in str(constructor) and PYARROW_VERSION < (14, 0, 0))
if "cudf" in str(constructor) or (
"pyarrow_table" in str(constructor) and PYARROW_VERSION < (14, 0, 0)
):
request.applymarker(pytest.mark.xfail)

Expand Down
Loading