Skip to content

Commit

Permalink
feat: implement when/then/otherwise for DuckDB (#1759)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Jan 8, 2025
1 parent 373320e commit 1f0c718
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 23 deletions.
2 changes: 1 addition & 1 deletion narwhals/_duckdb/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class DuckDBExpr(CompliantExpr["duckdb.Expression"]):

def __init__(
self,
call: Callable[[DuckDBLazyFrame], list[duckdb.Expression]],
call: Callable[[DuckDBLazyFrame], Sequence[duckdb.Expression]],
*,
depth: int,
function_name: str,
Expand Down
109 changes: 109 additions & 0 deletions narwhals/_duckdb/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Any
from typing import Literal
from typing import Sequence
from typing import cast

from narwhals._duckdb.expr import DuckDBExpr
from narwhals._duckdb.utils import narwhals_to_native_dtype
Expand Down Expand Up @@ -157,6 +158,16 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
kwargs={"exprs": exprs},
)

def when(
self,
*predicates: IntoDuckDBExpr,
) -> DuckDBWhen:
plx = self.__class__(backend_version=self._backend_version, version=self._version)
condition = plx.all_horizontal(*predicates)
return DuckDBWhen(
condition, self._backend_version, returns_scalar=False, version=self._version
)

def col(self, *column_names: str) -> DuckDBExpr:
return DuckDBExpr.from_column_names(
*column_names, backend_version=self._backend_version, version=self._version
Expand Down Expand Up @@ -203,3 +214,101 @@ def func(_df: DuckDBLazyFrame) -> list[duckdb.Expression]:
version=self._version,
kwargs={},
)


class DuckDBWhen:
def __init__(
self,
condition: DuckDBExpr,
backend_version: tuple[int, ...],
then_value: Any = None,
otherwise_value: Any = None,
*,
returns_scalar: bool,
version: Version,
) -> None:
self._backend_version = backend_version
self._condition = condition
self._then_value = then_value
self._otherwise_value = otherwise_value
self._returns_scalar = returns_scalar
self._version = version

def __call__(self, df: DuckDBLazyFrame) -> Sequence[duckdb.Expression]:
from duckdb import CaseExpression
from duckdb import ConstantExpression

from narwhals._expression_parsing import parse_into_expr

plx = df.__narwhals_namespace__()
condition = parse_into_expr(self._condition, namespace=plx)(df)[0]
condition = cast("duckdb.Expression", condition)

try:
value = parse_into_expr(self._then_value, namespace=plx)(df)[0]
except TypeError:
# `self._otherwise_value` is a scalar and can't be converted to an expression
value = ConstantExpression(self._then_value)
value = cast("duckdb.Expression", value)

if self._otherwise_value is None:
return [CaseExpression(condition=condition, value=value)]
try:
otherwise_expr = parse_into_expr(self._otherwise_value, namespace=plx)
except TypeError:
# `self._otherwise_value` is a scalar and can't be converted to an expression
return [
CaseExpression(condition=condition, value=value).otherwise(
ConstantExpression(self._otherwise_value)
)
]
otherwise = otherwise_expr(df)[0]
return [CaseExpression(condition=condition, value=value).otherwise(otherwise)]

def then(self, value: DuckDBExpr | Any) -> DuckDBThen:
self._then_value = value

return DuckDBThen(
self,
depth=0,
function_name="whenthen",
root_names=None,
output_names=None,
returns_scalar=self._returns_scalar,
backend_version=self._backend_version,
version=self._version,
kwargs={"value": value},
)


class DuckDBThen(DuckDBExpr):
def __init__(
self,
call: DuckDBWhen,
*,
depth: int,
function_name: str,
root_names: list[str] | None,
output_names: list[str] | None,
returns_scalar: bool,
backend_version: tuple[int, ...],
version: Version,
kwargs: dict[str, Any],
) -> None:
self._backend_version = backend_version
self._version = version
self._call = call
self._depth = depth
self._function_name = function_name
self._root_names = root_names
self._output_names = output_names
self._returns_scalar = returns_scalar
self._kwargs = kwargs

def otherwise(self, value: DuckDBExpr | Any) -> DuckDBExpr:
# type ignore because we are setting the `_call` attribute to a
# callable object of type `DuckDBWhen`, base class has the attribute as
# only a `Callable`
self._call._otherwise_value = value # type: ignore[attr-defined]
self._function_name = "whenotherwise"
return self
3 changes: 2 additions & 1 deletion narwhals/_duckdb/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from functools import lru_cache
from typing import TYPE_CHECKING
from typing import Any
from typing import Sequence

from narwhals.dtypes import DType
from narwhals.exceptions import InvalidIntoExprError
Expand Down Expand Up @@ -76,7 +77,7 @@ def parse_exprs_and_named_exprs(

def _columns_from_expr(
df: DuckDBLazyFrame, expr: IntoDuckDBExpr
) -> list[duckdb.Expression]:
) -> Sequence[duckdb.Expression]:
if isinstance(expr, str): # pragma: no cover
from duckdb import ColumnExpression

Expand Down
26 changes: 5 additions & 21 deletions tests/expr_and_series/when_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@
}


def test_when(constructor: Constructor, request: pytest.FixtureRequest) -> None:
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_when(constructor: Constructor) -> None:
df = nw.from_native(constructor(data))
result = df.select(nw.when(nw.col("a") == 1).then(value=3).alias("a_when"))
expected = {
Expand All @@ -28,9 +26,7 @@ def test_when(constructor: Constructor, request: pytest.FixtureRequest) -> None:
assert_equal_data(result, expected)


def test_when_otherwise(constructor: Constructor, request: pytest.FixtureRequest) -> None:
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_when_otherwise(constructor: Constructor) -> None:
df = nw.from_native(constructor(data))
result = df.select(nw.when(nw.col("a") == 1).then(3).otherwise(6).alias("a_when"))
expected = {
Expand All @@ -39,11 +35,7 @@ def test_when_otherwise(constructor: Constructor, request: pytest.FixtureRequest
assert_equal_data(result, expected)


def test_multiple_conditions(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_multiple_conditions(constructor: Constructor) -> None:
df = nw.from_native(constructor(data))
result = df.select(
nw.when(nw.col("a") < 3, nw.col("c") < 5.0).then(3).alias("a_when")
Expand Down Expand Up @@ -85,11 +77,7 @@ def test_value_series(constructor_eager: ConstructorEager) -> None:
assert_equal_data(result, expected)


def test_value_expression(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_value_expression(constructor: Constructor) -> None:
df = nw.from_native(constructor(data))
result = df.select(nw.when(nw.col("a") == 1).then(nw.col("a") + 9).alias("a_when"))
expected = {
Expand Down Expand Up @@ -122,11 +110,7 @@ def test_otherwise_series(constructor_eager: ConstructorEager) -> None:
assert_equal_data(result, expected)


def test_otherwise_expression(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_otherwise_expression(constructor: Constructor) -> None:
df = nw.from_native(constructor(data))
result = df.select(
nw.when(nw.col("a") == 1).then(-1).otherwise(nw.col("a") + 7).alias("a_when")
Expand Down
9 changes: 9 additions & 0 deletions tpch/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pathlib import Path

import dask.dataframe as dd
import duckdb
import pandas as pd
import polars as pl
import pyarrow as pa
Expand All @@ -29,14 +30,18 @@
"pandas[pyarrow]": (pd, {"engine": "pyarrow", "dtype_backend": "pyarrow"}),
"polars[lazy]": (pl, {}),
"pyarrow": (pa, {}),
"duckdb": (duckdb, {}),
"dask": (dd, {"engine": "pyarrow", "dtype_backend": "pyarrow"}),
}

BACKEND_COLLECT_FUNC_MAP = {
"polars[lazy]": lambda x: x.collect(),
"duckdb": lambda x: x.pl(),
"dask": lambda x: x.compute(),
}

DUCKDB_XFAILS = ["q11", "q14", "q15", "q16", "q18", "q22"]

QUERY_DATA_PATH_MAP = {
"q1": (LINEITEM_PATH,),
"q2": (REGION_PATH, NATION_PATH, SUPPLIER_PATH, PART_PATH, PARTSUPP_PATH),
Expand Down Expand Up @@ -90,6 +95,10 @@ def execute_query(query_id: str) -> None:
data_paths = QUERY_DATA_PATH_MAP[query_id]

for backend, (native_namespace, kwargs) in BACKEND_NAMESPACE_KWARGS_MAP.items():
if backend == "duckdb" and query_id in DUCKDB_XFAILS:
print(f"\nSkipping {query_id} for DuckDB") # noqa: T201
continue

print(f"\nRunning {query_id} with {backend=}") # noqa: T201
result = query_module.query(
*(
Expand Down

0 comments on commit 1f0c718

Please sign in to comment.