From 339683c529ccdfcabfa96627607203984061916f Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Fri, 10 Jan 2025 16:42:28 +0000 Subject: [PATCH] feat: implement anti-join, str.len_chars, and null_count for DuckDB (#1777) --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- narwhals/_arrow/dataframe.py | 2 +- narwhals/_dask/dataframe.py | 2 +- narwhals/_duckdb/dataframe.py | 6 +---- narwhals/_duckdb/expr.py | 25 ++++++++++++++++----- narwhals/_pandas_like/dataframe.py | 2 +- tests/expr_and_series/null_count_test.py | 8 +------ tests/expr_and_series/str/len_chars_test.py | 6 +---- tests/frame/join_test.py | 3 --- 8 files changed, 25 insertions(+), 29 deletions(-) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index e6bb6fa65..c36f58938 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -333,7 +333,7 @@ def join( self: Self, other: Self, *, - how: Literal["left", "inner", "outer", "cross", "anti", "semi"], + how: Literal["left", "inner", "cross", "anti", "semi"], left_on: str | list[str] | None, right_on: str | list[str] | None, suffix: str, diff --git a/narwhals/_dask/dataframe.py b/narwhals/_dask/dataframe.py index 16053d69a..35a0d045c 100644 --- a/narwhals/_dask/dataframe.py +++ b/narwhals/_dask/dataframe.py @@ -236,7 +236,7 @@ def join( self: Self, other: Self, *, - how: Literal["left", "inner", "outer", "cross", "anti", "semi"] = "inner", + how: Literal["left", "inner", "cross", "anti", "semi"] = "inner", left_on: str | list[str] | None, right_on: str | list[str] | None, suffix: str, diff --git a/narwhals/_duckdb/dataframe.py b/narwhals/_duckdb/dataframe.py index 33cfc19d2..98eca0bdb 100644 --- a/narwhals/_duckdb/dataframe.py +++ b/narwhals/_duckdb/dataframe.py @@ -215,7 +215,7 @@ def join( self: Self, other: Self, *, - how: Literal["left", "inner", "outer", "cross", "anti", "semi"] = "inner", + how: Literal["left", "inner", "cross", "anti", "semi"] = "inner", left_on: str | list[str] | None, right_on: str | list[str] | None, suffix: str, @@ -226,10 +226,6 @@ def join( right_on = [right_on] original_alias = self._native_frame.alias - if how not in ("inner", "left", "semi", "cross"): - msg = "Only inner and left join is implemented for DuckDB" - raise NotImplementedError(msg) - if how == "cross": if self._backend_version < (1, 1, 4): msg = f"DuckDB>=1.1.4 is required for cross-join, found version: {self._backend_version}" diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index e5e612085..cfd2efdac 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -1,6 +1,5 @@ from __future__ import annotations -import functools from typing import TYPE_CHECKING from typing import Any from typing import Callable @@ -488,6 +487,15 @@ def min(self) -> Self: lambda _input: FunctionExpression("min", _input), "min", returns_scalar=True ) + def null_count(self) -> Self: + from duckdb import FunctionExpression + + return self._from_call( + lambda _input: FunctionExpression("sum", _input.isnull().cast("int")), + "null_count", + returns_scalar=True, + ) + def is_null(self) -> Self: return self._from_call( lambda _input: _input.isnull(), "is_null", returns_scalar=self._returns_scalar @@ -497,11 +505,7 @@ def is_in(self, other: Sequence[Any]) -> Self: from duckdb import ConstantExpression return self._from_call( - lambda _input: functools.reduce( - lambda x, y: x | _input.isin(ConstantExpression(y)), - other[1:], - _input.isin(ConstantExpression(other[0])), - ), + lambda _input: _input.isin(*[ConstantExpression(x) for x in other]), "is_in", returns_scalar=self._returns_scalar, ) @@ -619,6 +623,15 @@ def func(_input: duckdb.Expression) -> duckdb.Expression: func, "slice", returns_scalar=self._compliant_expr._returns_scalar ) + def len_chars(self) -> DuckDBExpr: + from duckdb import FunctionExpression + + return self._compliant_expr._from_call( + lambda _input: FunctionExpression("length", _input), + "len_chars", + returns_scalar=self._compliant_expr._returns_scalar, + ) + def to_lowercase(self) -> DuckDBExpr: from duckdb import FunctionExpression diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index e11c02710..b8b707851 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -512,7 +512,7 @@ def join( self, other: Self, *, - how: Literal["left", "inner", "outer", "cross", "anti", "semi"] = "inner", + how: Literal["left", "inner", "cross", "anti", "semi"] = "inner", left_on: str | list[str] | None, right_on: str | list[str] | None, suffix: str, diff --git a/tests/expr_and_series/null_count_test.py b/tests/expr_and_series/null_count_test.py index a49fd79c8..db162363b 100644 --- a/tests/expr_and_series/null_count_test.py +++ b/tests/expr_and_series/null_count_test.py @@ -1,7 +1,5 @@ from __future__ import annotations -import pytest - import narwhals.stable.v1 as nw from tests.utils import Constructor from tests.utils import ConstructorEager @@ -13,11 +11,7 @@ } -def test_null_count_expr( - constructor: Constructor, request: pytest.FixtureRequest -) -> None: - if "duckdb" in str(constructor): - request.applymarker(pytest.mark.xfail) +def test_null_count_expr(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select(nw.col("a", "b").null_count()) expected = { diff --git a/tests/expr_and_series/str/len_chars_test.py b/tests/expr_and_series/str/len_chars_test.py index 1a318801a..f9c63e01c 100644 --- a/tests/expr_and_series/str/len_chars_test.py +++ b/tests/expr_and_series/str/len_chars_test.py @@ -1,7 +1,5 @@ from __future__ import annotations -import pytest - import narwhals.stable.v1 as nw from tests.utils import Constructor from tests.utils import ConstructorEager @@ -10,9 +8,7 @@ data = {"a": ["foo", "foobar", "Café", "345", "東京"]} -def test_str_len_chars(constructor: Constructor, request: pytest.FixtureRequest) -> None: - if "duckdb" in str(constructor): - request.applymarker(pytest.mark.xfail) +def test_str_len_chars(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select(nw.col("a").str.len_chars()) expected = { diff --git a/tests/frame/join_test.py b/tests/frame/join_test.py index f15a1b79e..5ff112f31 100644 --- a/tests/frame/join_test.py +++ b/tests/frame/join_test.py @@ -166,10 +166,7 @@ def test_anti_join( join_key: list[str], filter_expr: nw.Expr, expected: dict[str, list[Any]], - request: pytest.FixtureRequest, ) -> None: - if "duckdb" in str(constructor): - request.applymarker(pytest.mark.xfail) data = {"antananarivo": [1, 3, 2], "bob": [4, 4, 6], "zor ro": [7.0, 8, 9]} df = nw.from_native(constructor(data)) other = df.filter(filter_expr)