From 4974c6869c98ff1a69a2a94e0b5f3dfdaed462bb Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Fri, 3 May 2024 17:36:08 +0100 Subject: [PATCH] Basic `2023.12` coverage --- array_api_tests/test_inspection_functions.py | 43 ++++++++++++++ .../test_manipulation_functions.py | 58 +++++++++++++++++++ ...est_operators_and_elementwise_functions.py | 55 ++++++++++++++++++ array_api_tests/test_searching_functions.py | 38 +++++++++++- array_api_tests/test_statistical_functions.py | 10 ++++ 5 files changed, 203 insertions(+), 1 deletion(-) create mode 100644 array_api_tests/test_inspection_functions.py diff --git a/array_api_tests/test_inspection_functions.py b/array_api_tests/test_inspection_functions.py new file mode 100644 index 00000000..eaac0276 --- /dev/null +++ b/array_api_tests/test_inspection_functions.py @@ -0,0 +1,43 @@ +import pytest +from hypothesis import given, strategies as st + +from . import xp + +pytestmark = pytest.mark.min_version("2023.12") + + +def test_array_namespace_info(): + out = xp.__array_namespace_info__() + + capabilities = out.capabilities() + assert isinstance(capabilities, dict) + + out.default_device() + + default_dtypes = out.default_dtypes() + assert isinstance(default_dtypes, dict) + assert {"real floating", "complex floating", "integral", "indexing"}.issubset(set(default_dtypes.keys())) + + devices = out.devices() + assert isinstance(devices, list) + + +atomic_kinds = [ + "bool", + "signed integer", + "unsigned integer", + "real floating", + "complex floating", +] + + +@given( + st.one_of( + st.none(), + st.sampled_from(atomic_kinds + ["integral", "numeric"]), + st.lists(st.sampled_from(atomic_kinds), unique=True, min_size=1).map(tuple), + ) +) +def test_array_namespace_info_dtypes(kind): + out = xp.__array_namespace_info__().dtypes(kind=kind) + assert isinstance(out, dict) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index cb16de95..16e48632 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -149,6 +149,31 @@ def test_expand_dims(x, axis): ) +@pytest.mark.min_version("2023.12") +@given(x=hh.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_dims=1)), data=st.data()) +def test_moveaxis(x, data): + source = data.draw( + st.integers(-x.ndim, x.ndim - 1) | xps.valid_tuple_axes(x.ndim), label="source" + ) + if isinstance(source, int): + destination = data.draw(st.integers(-x.ndim, x.ndim - 1), label="destination") + else: + assert isinstance(source, tuple) # sanity check + destination = data.draw( + st.lists( + st.integers(-x.ndim, x.ndim - 1), + min_size=len(source), + max_size=len(source), + unique_by=lambda n: n if n >= 0 else x.ndim + n, + ).map(tuple), + label="destination" + ) + + out = xp.moveaxis(x, source, destination) + + ph.assert_dtype("moveaxis", in_dtype=x.dtype, out_dtype=out.dtype) + # TODO: shape and values testing + @pytest.mark.unvectorized @given( x=hh.arrays( @@ -253,6 +278,20 @@ def reshape_shapes(draw, shape): return tuple(rshape) +@pytest.mark.min_version("2023.12") +@given( + x=hh.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_dims=1)), + repeats=st.integers(1, 4), +) +def test_repeat(x, repeats): + # TODO: test array repeats and non-None axis, adjust shape and value testing accordingly + out = xp.repeat(x, repeats) + ph.assert_dtype("repeat", in_dtype=x.dtype, out_dtype=out.dtype) + expected_shape = (math.prod(x.shape) * repeats,) + ph.assert_shape("repeat", out_shape=out.shape, expected=expected_shape) + # TODO: values testing + + @pytest.mark.unvectorized @pytest.mark.skip("flaky") # TODO: fix! @given( @@ -371,3 +410,22 @@ def test_stack(shape, dtypes, kw, data): out_val=out[out_idx], kw=kw, ) + + +@pytest.mark.min_version("2023.12") +@given(x=hh.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()), data=st.data()) +def test_tile(x, data): + repetitions = data.draw(st.lists(st.integers(1, 4), min_size=1, max_size=x.ndim + 1).map(tuple), label="repetitions") + out = xp.tile(x, repetitions) + ph.assert_dtype("tile", in_dtype=x.dtype, out_dtype=out.dtype) + # TODO: shapes and values testing + + +@pytest.mark.min_version("2023.12") +@given(x=hh.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_dims=1)), data=st.data()) +def test_unstack(x, data): + axis = data.draw(st.integers(min_value=-x.ndim, max_value=x.ndim - 1), label="axis") + kw = data.draw(hh.specified_kwargs(("axis", axis, 0)), label="kw") + out = xp.asarray(xp.unstack(x, **kw), dtype=x.dtype) + ph.assert_dtype("unstack", in_dtype=x.dtype, out_dtype=out.dtype) + # TODO: shapes and values testing \ No newline at end of file diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 27144847..fe0ffc5d 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -933,6 +933,16 @@ def test_ceil(x): unary_assert_against_refimpl("ceil", x, out, math.ceil, strict_check=True) +@pytest.mark.min_version("2023.12") +@given(hh.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +def test_clip(x): + # TODO: test min/max kwargs, adjust values testing accordingly + out = xp.clip(x) + ph.assert_dtype("clip", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("clip", out_shape=out.shape, expected=x.shape) + ph.assert_array_elements("clip", out=out, expected=x) + + if api_version >= "2022.12": @given(hh.arrays(dtype=xps.complex_dtypes(), shape=hh.shapes())) @@ -943,6 +953,15 @@ def test_conj(x): unary_assert_against_refimpl("conj", x, out, operator.methodcaller("conjugate")) +@pytest.mark.min_version("2023.12") +@given(*hh.two_mutual_arrays(dh.real_float_dtypes)) +def test_copysign(x1, x2): + out = xp.copysign(x1, x2) + ph.assert_dtype("copysign", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype) + ph.assert_result_shape("copysign", in_shapes=[x1.shape, x2.shape], out_shape=out.shape) + # TODO: values testing + + @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_cos(x): out = xp.cos(x) @@ -1095,6 +1114,15 @@ def test_greater_equal(ctx, data): ) +@pytest.mark.min_version("2023.12") +@given(*hh.two_mutual_arrays(dh.real_float_dtypes)) +def test_hypot(x1, x2): + out = xp.hypot(x1, x2) + ph.assert_dtype("hypot", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype) + ph.assert_result_shape("hypot", in_shapes=[x1.shape, x2.shape], out_shape=out.shape) + binary_assert_against_refimpl("hypot", x1, x2, out, math.hypot) + + if api_version >= "2022.12": @given(hh.arrays(dtype=xps.complex_dtypes(), shape=hh.shapes())) @@ -1261,6 +1289,24 @@ def test_logical_xor(x1, x2): ) +@pytest.mark.min_version("2023.12") +@given(*hh.two_mutual_arrays(dh.real_float_dtypes)) +def test_maximum(x1, x2): + out = xp.maximum(x1, x2) + ph.assert_dtype("maximum", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype) + ph.assert_result_shape("maximum", in_shapes=[x1.shape, x2.shape], out_shape=out.shape) + binary_assert_against_refimpl("maximum", x1, x2, out, max, strict_check=True) + + +@pytest.mark.min_version("2023.12") +@given(*hh.two_mutual_arrays(dh.real_float_dtypes)) +def test_minimum(x1, x2): + out = xp.minimum(x1, x2) + ph.assert_dtype("minimum", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype) + ph.assert_result_shape("minimum", in_shapes=[x1.shape, x2.shape], out_shape=out.shape) + binary_assert_against_refimpl("minimum", x1, x2, out, min, strict_check=True) + + @pytest.mark.parametrize("ctx", make_binary_params("multiply", dh.numeric_dtypes)) @given(data=st.data()) def test_multiply(ctx, data): @@ -1380,6 +1426,15 @@ def test_round(x): unary_assert_against_refimpl("round", x, out, round, strict_check=True) +@pytest.mark.min_version("2023.12") +@given(hh.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +def test_signbit(x): + out = xp.signbit(x) + ph.assert_dtype("signbit", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool) + ph.assert_shape("signbit", out_shape=out.shape, expected=x.shape) + # TODO: values testing + + @given(hh.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(), elements=finite_kw)) def test_sign(x): out = xp.sign(x) diff --git a/array_api_tests/test_searching_functions.py b/array_api_tests/test_searching_functions.py index ee7d4e9b..608547ec 100644 --- a/array_api_tests/test_searching_functions.py +++ b/array_api_tests/test_searching_functions.py @@ -1,7 +1,7 @@ import math import pytest -from hypothesis import given +from hypothesis import given, note from hypothesis import strategies as st from . import _array_module as xp @@ -167,3 +167,39 @@ def test_where(shapes, dtypes, data): out_repr=f"out[{idx}]", out_val=out[idx] ) + + +@pytest.mark.min_version("2023.12") +@given(data=st.data()) +def test_searchsorted(data): + # TODO: test side="right" + _x1 = data.draw( + st.lists(xps.from_dtype(dh.default_float), min_size=1, unique=True), + label="_x1", + ) + x1 = xp.asarray(_x1, dtype=dh.default_float) + if data.draw(st.booleans(), label="use sorter?"): + sorter = data.draw( + st.permutations(_x1).map(lambda o: xp.asarray(o, dtype=dh.default_float)), + label="sorter", + ) + else: + sorter = None + x1 = xp.sort(x1) + note(f"{x1=}") + x2 = data.draw( + st.lists(st.sampled_from(_x1), unique=True, min_size=1).map( + lambda o: xp.asarray(o, dtype=dh.default_float) + ), + label="x2", + ) + + out = xp.searchsorted(x1, x2, sorter=sorter) + + ph.assert_dtype( + "searchsorted", + in_dtype=[x1.dtype, x2.dtype], + out_dtype=out.dtype, + expected=xp.__array_namespace_info__().default_dtypes()["indexing"], + ) + # TODO: shapes and values testing \ No newline at end of file diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index 3cce37e0..f1ba2e1d 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -16,6 +16,16 @@ from .typing import DataType +@pytest.mark.min_version("2023.12") +@given(hh.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_dims=1, max_dims=1))) +def test_cumulative_sum(x): + # TODO: test kwargs + diff shapes, adjust shape and values testing accordingly + out = xp.cumulative_sum(x) + # TODO: assert dtype + ph.assert_shape("cumulative_sum", out_shape=out.shape, expected=x.shape) + # TODO: assert values + + def kwarg_dtypes(dtype: DataType) -> st.SearchStrategy[Optional[DataType]]: dtypes = [d2 for d1, d2 in dh.promotion_table if d1 == dtype] dtypes = [d for d in dtypes if not isinstance(d, _UndefinedStub)]