Skip to content

Commit

Permalink
unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi committed Jan 17, 2025
1 parent 9d9b735 commit aedb28c
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 28 deletions.
2 changes: 1 addition & 1 deletion narwhals/_dask/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def datetime(
self: Self,
time_unit: TimeUnit | Collection[TimeUnit] | None,
time_zone: str | timezone | Collection[str | timezone | None] | None,
) -> DaskSelector:
) -> DaskSelector: # pragma: no cover
from narwhals.utils import _parse_datetime_selector_to_datetimes

datetime_dtypes = _parse_datetime_selector_to_datetimes(
Expand Down
2 changes: 2 additions & 0 deletions narwhals/stable/v1/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from narwhals.selectors import boolean
from narwhals.selectors import by_dtype
from narwhals.selectors import categorical
from narwhals.selectors import datetime
from narwhals.selectors import numeric
from narwhals.selectors import string

Expand All @@ -12,6 +13,7 @@
"boolean",
"by_dtype",
"categorical",
"datetime",
"numeric",
"string",
]
153 changes: 126 additions & 27 deletions tests/selectors_test.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from __future__ import annotations

from datetime import datetime
from datetime import timezone
from typing import Literal

import pandas as pd
import pyarrow as pa
import pytest
from zoneinfo import ZoneInfo

import narwhals.stable.v1 as nw
from narwhals.stable.v1.selectors import all
from narwhals.stable.v1.selectors import boolean
from narwhals.stable.v1.selectors import by_dtype
from narwhals.stable.v1.selectors import categorical
from narwhals.stable.v1.selectors import numeric
from narwhals.stable.v1.selectors import string
import narwhals.stable.v1.selectors as ncs
from tests.utils import PYARROW_VERSION
from tests.utils import Constructor
from tests.utils import assert_equal_data
Expand All @@ -27,34 +27,34 @@ def test_selectors(constructor: Constructor, request: pytest.FixtureRequest) ->
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor(data))
result = df.select(by_dtype([nw.Int64, nw.Float64]) + 1)
result = df.select(ncs.by_dtype([nw.Int64, nw.Float64]) + 1)
expected = {"a": [2, 2, 3], "c": [5.1, 6.0, 7.0]}
assert_equal_data(result, expected)


def test_numeric(constructor: Constructor, request: pytest.FixtureRequest) -> None:
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
if "pyspark" in str(constructor) or "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor(data))
result = df.select(numeric() + 1)
result = df.select(ncs.numeric() + 1)
expected = {"a": [2, 2, 3], "c": [5.1, 6.0, 7.0]}
assert_equal_data(result, expected)


def test_boolean(constructor: Constructor, request: pytest.FixtureRequest) -> None:
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
if "pyspark" in str(constructor) or "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor(data))
result = df.select(boolean())
result = df.select(ncs.boolean())
expected = {"d": [True, False, True]}
assert_equal_data(result, expected)


def test_string(constructor: Constructor, request: pytest.FixtureRequest) -> None:
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
if "pyspark" in str(constructor) or "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor(data))
result = df.select(string())
result = df.select(ncs.string())
expected = {"b": ["a", "b", "c"]}
assert_equal_data(result, expected)

Expand All @@ -72,22 +72,121 @@ def test_categorical(
expected = {"b": ["a", "b", "c"]}

df = nw.from_native(constructor(data)).with_columns(nw.col("b").cast(nw.Categorical))
result = df.select(categorical())
result = df.select(ncs.categorical())
assert_equal_data(result, expected)


def test_datetime(constructor: Constructor, request: pytest.FixtureRequest) -> None:
    """Check which columns ``ncs.datetime`` selects for unit/zone filters.

    A frame is built with one float column plus nine datetime columns: the
    naive, UTC and Europe/Berlin timestamps, each cast to ``ms``, ``us`` and
    ``ns``. Only the selected column names (and their order) are asserted;
    the timestamp values themselves are never compared.
    """
    backend = str(constructor)
    if any(name in backend for name in ("pyspark", "duckdb", "dask")):
        # These backends are marked as expected failures for this test.
        request.applymarker(pytest.mark.xfail)

    first_ts = datetime(2000, 11, 20, 18, 12, 16, 600000)
    second_ts = datetime(2020, 10, 30, 10, 20, 25, 123000)

    utc = timezone.utc
    berlin = ZoneInfo("Europe/Berlin")

    raw = {
        "numeric": [3.14, 6.28],
        "ts": [first_ts, second_ts],
        "ts_utc": [first_ts.astimezone(utc), second_ts.astimezone(utc)],
        "ts_berlin": [first_ts.astimezone(berlin), second_ts.astimezone(berlin)],
    }
    units: list[Literal["ns", "us", "ms", "s"]] = ["ms", "us", "ns"]

    frame = nw.from_native(constructor(raw))

    # One cast per (source column, time unit); output names carry the unit.
    naive_cols = [
        nw.col("ts").cast(nw.Datetime(time_unit=unit)).alias(f"ts_{unit}")
        for unit in units
    ]
    utc_cols = [
        nw.col("ts_utc")
        .cast(nw.Datetime(time_zone="UTC", time_unit=unit))
        .alias(f"ts_utc_{unit}")
        for unit in units
    ]
    berlin_cols = [
        nw.col("ts_berlin")
        .cast(nw.Datetime(time_zone="Europe/Berlin", time_unit=unit))
        .alias(f"ts_berlin_{unit}")
        for unit in units
    ]
    df = frame.select(nw.col("numeric"), *naive_cols, *utc_cols, *berlin_cols)

    # No filters: every datetime column, but not the float column.
    names = df.select(ncs.datetime()).collect_schema().names()
    assert names == [
        "ts_ms",
        "ts_us",
        "ts_ns",
        "ts_utc_ms",
        "ts_utc_us",
        "ts_utc_ns",
        "ts_berlin_ms",
        "ts_berlin_us",
        "ts_berlin_ns",
    ]

    # Filter on time unit only (single unit, then a collection of units).
    names = df.select(ncs.datetime(time_unit="ms")).collect_schema().names()
    assert names == ["ts_ms", "ts_utc_ms", "ts_berlin_ms"]

    names = df.select(ncs.datetime(time_unit=["us", "ns"])).collect_schema().names()
    assert names == [
        "ts_us",
        "ts_ns",
        "ts_utc_us",
        "ts_utc_ns",
        "ts_berlin_us",
        "ts_berlin_ns",
    ]

    # Filter on time zone only: naive, any zone ("*"), and a mixed collection.
    names = df.select(ncs.datetime(time_zone=None)).collect_schema().names()
    assert names == ["ts_ms", "ts_us", "ts_ns"]

    names = df.select(ncs.datetime(time_zone="*")).collect_schema().names()
    assert names == [
        "ts_utc_ms",
        "ts_utc_us",
        "ts_utc_ns",
        "ts_berlin_ms",
        "ts_berlin_us",
        "ts_berlin_ns",
    ]

    names = (
        df.select(ncs.datetime(time_zone=[None, "Europe/Berlin"]))
        .collect_schema()
        .names()
    )
    assert names == [
        "ts_ms",
        "ts_us",
        "ts_ns",
        "ts_berlin_ms",
        "ts_berlin_us",
        "ts_berlin_ns",
    ]

    # Filter on time unit and time zone together.
    names = (
        df.select(ncs.datetime(time_unit="ns", time_zone=[None, "Europe/Berlin"]))
        .collect_schema()
        .names()
    )
    assert names == ["ts_ns", "ts_berlin_ns"]

    names = (
        df.select(
            ncs.datetime(time_unit=["ms", "us"], time_zone=[None, "Europe/Berlin"])
        )
        .collect_schema()
        .names()
    )
    assert names == ["ts_ms", "ts_us", "ts_berlin_ms", "ts_berlin_us"]


@pytest.mark.parametrize(
("selector", "expected"),
[
(numeric() | boolean(), ["a", "c", "d"]),
(numeric() & boolean(), []),
(numeric() & by_dtype(nw.Int64), ["a"]),
(numeric() | by_dtype(nw.Int64), ["a", "c"]),
(~numeric(), ["b", "d"]),
(boolean() & True, ["d"]),
(boolean() | True, ["d"]),
(numeric() - 1, ["a", "c"]),
(all(), ["a", "b", "c", "d"]),
(ncs.numeric() | ncs.boolean(), ["a", "c", "d"]),
(ncs.numeric() & ncs.boolean(), []),
(ncs.numeric() & ncs.by_dtype(nw.Int64), ["a"]),
(ncs.numeric() | ncs.by_dtype(nw.Int64), ["a", "c"]),
(~ncs.numeric(), ["b", "d"]),
(ncs.boolean() & True, ["d"]),
(ncs.boolean() | True, ["d"]),
(ncs.numeric() - 1, ["a", "c"]),
(ncs.all(), ["a", "b", "c", "d"]),
],
)
def test_set_ops(
Expand All @@ -96,7 +195,7 @@ def test_set_ops(
expected: list[str],
request: pytest.FixtureRequest,
) -> None:
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
if "pyspark" in str(constructor) or "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor(data))
result = df.select(selector).collect_schema().names()
Expand All @@ -111,8 +210,8 @@ def test_set_ops_invalid(
request.applymarker(pytest.mark.xfail)
df = nw.from_native(invalid_constructor(data))
with pytest.raises((NotImplementedError, ValueError)):
df.select(1 - numeric())
df.select(1 - ncs.numeric())
with pytest.raises((NotImplementedError, ValueError)):
df.select(1 | numeric())
df.select(1 | ncs.numeric())
with pytest.raises((NotImplementedError, ValueError)):
df.select(1 & numeric())
df.select(1 & ncs.numeric())

0 comments on commit aedb28c

Please sign in to comment.