From aedb28c2759225711d522059902c2aaf88fecb45 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Sat, 18 Jan 2025 00:30:18 +0100 Subject: [PATCH] unit test --- narwhals/_dask/selectors.py | 2 +- narwhals/stable/v1/selectors.py | 2 + tests/selectors_test.py | 153 ++++++++++++++++++++++++++------ 3 files changed, 129 insertions(+), 28 deletions(-) diff --git a/narwhals/_dask/selectors.py b/narwhals/_dask/selectors.py index 960c919fa..b42a610f6 100644 --- a/narwhals/_dask/selectors.py +++ b/narwhals/_dask/selectors.py @@ -99,7 +99,7 @@ def datetime( self: Self, time_unit: TimeUnit | Collection[TimeUnit] | None, time_zone: str | timezone | Collection[str | timezone | None] | None, - ) -> DaskSelector: + ) -> DaskSelector: # pragma: no cover from narwhals.utils import _parse_datetime_selector_to_datetimes datetime_dtypes = _parse_datetime_selector_to_datetimes( diff --git a/narwhals/stable/v1/selectors.py b/narwhals/stable/v1/selectors.py index 0d82484e9..5bd2ac938 100644 --- a/narwhals/stable/v1/selectors.py +++ b/narwhals/stable/v1/selectors.py @@ -4,6 +4,7 @@ from narwhals.selectors import boolean from narwhals.selectors import by_dtype from narwhals.selectors import categorical +from narwhals.selectors import datetime from narwhals.selectors import numeric from narwhals.selectors import string @@ -12,6 +13,7 @@ "boolean", "by_dtype", "categorical", + "datetime", "numeric", "string", ] diff --git a/tests/selectors_test.py b/tests/selectors_test.py index 80aa64803..fb331f0e2 100644 --- a/tests/selectors_test.py +++ b/tests/selectors_test.py @@ -1,16 +1,16 @@ from __future__ import annotations +from datetime import datetime +from datetime import timezone +from typing import Literal + import pandas as pd import pyarrow as pa import pytest +from zoneinfo import ZoneInfo import narwhals.stable.v1 as nw -from narwhals.stable.v1.selectors import all -from narwhals.stable.v1.selectors import boolean -from narwhals.stable.v1.selectors import by_dtype -from narwhals.stable.v1.selectors import categorical -from narwhals.stable.v1.selectors import numeric -from narwhals.stable.v1.selectors import string +import narwhals.stable.v1.selectors as ncs from tests.utils import PYARROW_VERSION from tests.utils import Constructor from tests.utils import assert_equal_data @@ -27,34 +27,34 @@ def test_selectors(constructor: Constructor, request: pytest.FixtureRequest) -> if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) - result = df.select(by_dtype([nw.Int64, nw.Float64]) + 1) + result = df.select(ncs.by_dtype([nw.Int64, nw.Float64]) + 1) expected = {"a": [2, 2, 3], "c": [5.1, 6.0, 7.0]} assert_equal_data(result, expected) def test_numeric(constructor: Constructor, request: pytest.FixtureRequest) -> None: - if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): + if "pyspark" in str(constructor) or "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) - result = df.select(numeric() + 1) + result = df.select(ncs.numeric() + 1) expected = {"a": [2, 2, 3], "c": [5.1, 6.0, 7.0]} assert_equal_data(result, expected) def test_boolean(constructor: Constructor, request: pytest.FixtureRequest) -> None: - if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): + if "pyspark" in str(constructor) or "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) - result = df.select(boolean()) + result = df.select(ncs.boolean()) expected = {"d": [True, False, True]} assert_equal_data(result, expected) def test_string(constructor: Constructor, request: pytest.FixtureRequest) -> None: - if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): + if "pyspark" in str(constructor) or "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) - result = df.select(string()) + result = df.select(ncs.string()) expected = {"b": ["a", "b", "c"]} assert_equal_data(result, expected) @@ -72,22 +72,121 @@ def test_categorical( expected = {"b": ["a", "b", "c"]} df = nw.from_native(constructor(data)).with_columns(nw.col("b").cast(nw.Categorical)) - result = df.select(categorical()) + result = df.select(ncs.categorical()) assert_equal_data(result, expected) +def test_datetime(constructor: Constructor, request: pytest.FixtureRequest) -> None: + if ( + "pyspark" in str(constructor) + or "duckdb" in str(constructor) + or "dask" in str(constructor) + ): + request.applymarker(pytest.mark.xfail) + + ts1 = datetime(2000, 11, 20, 18, 12, 16, 600000) + ts2 = datetime(2020, 10, 30, 10, 20, 25, 123000) + + utc_tz = timezone.utc + berlin_tz = ZoneInfo("Europe/Berlin") + + data = { + "numeric": [3.14, 6.28], + "ts": [ts1, ts2], + "ts_utc": [ts1.astimezone(utc_tz), ts2.astimezone(utc_tz)], + "ts_berlin": [ts1.astimezone(berlin_tz), ts2.astimezone(berlin_tz)], + } + time_units: list[Literal["ns", "us", "ms", "s"]] = ["ms", "us", "ns"] + + df = nw.from_native(constructor(data)).select( + nw.col("numeric"), + *[ + nw.col("ts").cast(nw.Datetime(time_unit=tu)).alias(f"ts_{tu}") + for tu in time_units + ], + *[ + nw.col("ts_utc") + .cast(nw.Datetime(time_zone="UTC", time_unit=tu)) + .alias(f"ts_utc_{tu}") + for tu in time_units + ], + *[ + nw.col("ts_berlin") + .cast(nw.Datetime(time_zone="Europe/Berlin", time_unit=tu)) + .alias(f"ts_berlin_{tu}") + for tu in time_units + ], + ) + + assert df.select(ncs.datetime()).collect_schema().names() == [ + "ts_ms", + "ts_us", + "ts_ns", + "ts_utc_ms", + "ts_utc_us", + "ts_utc_ns", + "ts_berlin_ms", + "ts_berlin_us", + "ts_berlin_ns", + ] + assert df.select(ncs.datetime(time_unit="ms")).collect_schema().names() == [ + "ts_ms", + "ts_utc_ms", + "ts_berlin_ms", + ] + assert df.select(ncs.datetime(time_unit=["us", "ns"])).collect_schema().names() == [ + "ts_us", + "ts_ns", + "ts_utc_us", + "ts_utc_ns", + "ts_berlin_us", + "ts_berlin_ns", + ] + + assert df.select(ncs.datetime(time_zone=None)).collect_schema().names() == [ + "ts_ms", + "ts_us", + "ts_ns", + ] + assert df.select(ncs.datetime(time_zone="*")).collect_schema().names() == [ + "ts_utc_ms", + "ts_utc_us", + "ts_utc_ns", + "ts_berlin_ms", + "ts_berlin_us", + "ts_berlin_ns", + ] + assert df.select( + ncs.datetime(time_zone=[None, "Europe/Berlin"]) + ).collect_schema().names() == [ + "ts_ms", + "ts_us", + "ts_ns", + "ts_berlin_ms", + "ts_berlin_us", + "ts_berlin_ns", + ] + + assert df.select( + ncs.datetime(time_unit="ns", time_zone=[None, "Europe/Berlin"]) + ).collect_schema().names() == ["ts_ns", "ts_berlin_ns"] + assert df.select( + ncs.datetime(time_unit=["ms", "us"], time_zone=[None, "Europe/Berlin"]) + ).collect_schema().names() == ["ts_ms", "ts_us", "ts_berlin_ms", "ts_berlin_us"] + + @pytest.mark.parametrize( ("selector", "expected"), [ - (numeric() | boolean(), ["a", "c", "d"]), - (numeric() & boolean(), []), - (numeric() & by_dtype(nw.Int64), ["a"]), - (numeric() | by_dtype(nw.Int64), ["a", "c"]), - (~numeric(), ["b", "d"]), - (boolean() & True, ["d"]), - (boolean() | True, ["d"]), - (numeric() - 1, ["a", "c"]), - (all(), ["a", "b", "c", "d"]), + (ncs.numeric() | ncs.boolean(), ["a", "c", "d"]), + (ncs.numeric() & ncs.boolean(), []), + (ncs.numeric() & ncs.by_dtype(nw.Int64), ["a"]), + (ncs.numeric() | ncs.by_dtype(nw.Int64), ["a", "c"]), + (~ncs.numeric(), ["b", "d"]), + (ncs.boolean() & True, ["d"]), + (ncs.boolean() | True, ["d"]), + (ncs.numeric() - 1, ["a", "c"]), + (ncs.all(), ["a", "b", "c", "d"]), ], ) def test_set_ops( @@ -96,7 +195,7 @@ def test_set_ops( expected: list[str], request: pytest.FixtureRequest, ) -> None: - if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): + if "pyspark" in str(constructor) or "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = df.select(selector).collect_schema().names() @@ -111,8 +210,8 @@ def test_set_ops_invalid( request.applymarker(pytest.mark.xfail) df = nw.from_native(invalid_constructor(data)) with pytest.raises((NotImplementedError, ValueError)): - df.select(1 - numeric()) + df.select(1 - ncs.numeric()) with pytest.raises((NotImplementedError, ValueError)): - df.select(1 | numeric()) + df.select(1 | ncs.numeric()) with pytest.raises((NotImplementedError, ValueError)): - df.select(1 & numeric()) + df.select(1 & ncs.numeric())