Skip to content

Commit

Permalink
Basic 2023.12 coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
honno committed May 3, 2024
1 parent e1fe6fb commit 4974c68
Show file tree
Hide file tree
Showing 5 changed files with 203 additions and 1 deletion.
43 changes: 43 additions & 0 deletions array_api_tests/test_inspection_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import pytest
from hypothesis import given, strategies as st

from . import xp

pytestmark = pytest.mark.min_version("2023.12")


def test_array_namespace_info():
out = xp.__array_namespace_info__()

capabilities = out.capabilities()
assert isinstance(capabilities, dict)

out.default_device()

default_dtypes = out.default_dtypes()
assert isinstance(default_dtypes, dict)
assert {"real floating", "complex floating", "integral", "indexing"}.issubset(set(default_dtypes.keys()))

devices = out.devices()
assert isinstance(devices, list)


atomic_kinds = [
"bool",
"signed integer",
"unsigned integer",
"real floating",
"complex floating",
]


@given(
st.one_of(
st.none(),
st.sampled_from(atomic_kinds + ["integral", "numeric"]),
st.lists(st.sampled_from(atomic_kinds), unique=True, min_size=1).map(tuple),
)
)
def test_array_namespace_info_dtypes(kind):
out = xp.__array_namespace_info__().dtypes(kind=kind)
assert isinstance(out, dict)
58 changes: 58 additions & 0 deletions array_api_tests/test_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,31 @@ def test_expand_dims(x, axis):
)


@pytest.mark.min_version("2023.12")
@given(x=hh.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_dims=1)), data=st.data())
def test_moveaxis(x, data):
source = data.draw(
st.integers(-x.ndim, x.ndim - 1) | xps.valid_tuple_axes(x.ndim), label="source"
)
if isinstance(source, int):
destination = data.draw(st.integers(-x.ndim, x.ndim - 1), label="destination")
else:
assert isinstance(source, tuple) # sanity check
destination = data.draw(
st.lists(
st.integers(-x.ndim, x.ndim - 1),
min_size=len(source),
max_size=len(source),
unique_by=lambda n: n if n >= 0 else x.ndim + n,
).map(tuple),
label="destination"
)

out = xp.moveaxis(x, source, destination)

ph.assert_dtype("moveaxis", in_dtype=x.dtype, out_dtype=out.dtype)
# TODO: shape and values testing

@pytest.mark.unvectorized
@given(
x=hh.arrays(
Expand Down Expand Up @@ -253,6 +278,20 @@ def reshape_shapes(draw, shape):
return tuple(rshape)


@pytest.mark.min_version("2023.12")
@given(
x=hh.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_dims=1)),
repeats=st.integers(1, 4),
)
def test_repeat(x, repeats):
# TODO: test array repeats and non-None axis, adjust shape and value testing accordingly
out = xp.repeat(x, repeats)
ph.assert_dtype("repeat", in_dtype=x.dtype, out_dtype=out.dtype)
expected_shape = (math.prod(x.shape) * repeats,)
ph.assert_shape("repeat", out_shape=out.shape, expected=expected_shape)
# TODO: values testing


@pytest.mark.unvectorized
@pytest.mark.skip("flaky") # TODO: fix!
@given(
Expand Down Expand Up @@ -371,3 +410,22 @@ def test_stack(shape, dtypes, kw, data):
out_val=out[out_idx],
kw=kw,
)


@pytest.mark.min_version("2023.12")
@given(x=hh.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()), data=st.data())
def test_tile(x, data):
repetitions = data.draw(st.lists(st.integers(1, 4), min_size=1, max_size=x.ndim + 1).map(tuple), label="repetitions")
out = xp.tile(x, repetitions)
ph.assert_dtype("tile", in_dtype=x.dtype, out_dtype=out.dtype)
# TODO: shapes and values testing


@pytest.mark.min_version("2023.12")
@given(x=hh.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_dims=1)), data=st.data())
def test_unstack(x, data):
axis = data.draw(st.integers(min_value=-x.ndim, max_value=x.ndim - 1), label="axis")
kw = data.draw(hh.specified_kwargs(("axis", axis, 0)), label="kw")
out = xp.asarray(xp.unstack(x, **kw), dtype=x.dtype)
ph.assert_dtype("unstack", in_dtype=x.dtype, out_dtype=out.dtype)
# TODO: shapes and values testing
55 changes: 55 additions & 0 deletions array_api_tests/test_operators_and_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -933,6 +933,16 @@ def test_ceil(x):
unary_assert_against_refimpl("ceil", x, out, math.ceil, strict_check=True)


@pytest.mark.min_version("2023.12")
@given(hh.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
def test_clip(x):
# TODO: test min/max kwargs, adjust values testing accordingly
out = xp.clip(x)
ph.assert_dtype("clip", in_dtype=x.dtype, out_dtype=out.dtype)
ph.assert_shape("clip", out_shape=out.shape, expected=x.shape)
ph.assert_array_elements("clip", out=out, expected=x)


if api_version >= "2022.12":

@given(hh.arrays(dtype=xps.complex_dtypes(), shape=hh.shapes()))
Expand All @@ -943,6 +953,15 @@ def test_conj(x):
unary_assert_against_refimpl("conj", x, out, operator.methodcaller("conjugate"))


@pytest.mark.min_version("2023.12")
@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
def test_copysign(x1, x2):
out = xp.copysign(x1, x2)
ph.assert_dtype("copysign", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
ph.assert_result_shape("copysign", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
# TODO: values testing


@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
def test_cos(x):
out = xp.cos(x)
Expand Down Expand Up @@ -1095,6 +1114,15 @@ def test_greater_equal(ctx, data):
)


@pytest.mark.min_version("2023.12")
@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
def test_hypot(x1, x2):
out = xp.hypot(x1, x2)
ph.assert_dtype("hypot", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
ph.assert_result_shape("hypot", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
binary_assert_against_refimpl("hypot", x1, x2, out, math.hypot)


if api_version >= "2022.12":

@given(hh.arrays(dtype=xps.complex_dtypes(), shape=hh.shapes()))
Expand Down Expand Up @@ -1261,6 +1289,24 @@ def test_logical_xor(x1, x2):
)


@pytest.mark.min_version("2023.12")
@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
def test_maximum(x1, x2):
out = xp.maximum(x1, x2)
ph.assert_dtype("maximum", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
ph.assert_result_shape("maximum", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
binary_assert_against_refimpl("maximum", x1, x2, out, max, strict_check=True)


@pytest.mark.min_version("2023.12")
@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
def test_minimum(x1, x2):
out = xp.minimum(x1, x2)
ph.assert_dtype("minimum", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
ph.assert_result_shape("minimum", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
binary_assert_against_refimpl("minimum", x1, x2, out, min, strict_check=True)


@pytest.mark.parametrize("ctx", make_binary_params("multiply", dh.numeric_dtypes))
@given(data=st.data())
def test_multiply(ctx, data):
Expand Down Expand Up @@ -1380,6 +1426,15 @@ def test_round(x):
unary_assert_against_refimpl("round", x, out, round, strict_check=True)


@pytest.mark.min_version("2023.12")
@given(hh.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
def test_signbit(x):
out = xp.signbit(x)
ph.assert_dtype("signbit", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool)
ph.assert_shape("signbit", out_shape=out.shape, expected=x.shape)
# TODO: values testing


@given(hh.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(), elements=finite_kw))
def test_sign(x):
out = xp.sign(x)
Expand Down
38 changes: 37 additions & 1 deletion array_api_tests/test_searching_functions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import math

import pytest
from hypothesis import given
from hypothesis import given, note
from hypothesis import strategies as st

from . import _array_module as xp
Expand Down Expand Up @@ -167,3 +167,39 @@ def test_where(shapes, dtypes, data):
out_repr=f"out[{idx}]",
out_val=out[idx]
)


@pytest.mark.min_version("2023.12")
@given(data=st.data())
def test_searchsorted(data):
# TODO: test side="right"
_x1 = data.draw(
st.lists(xps.from_dtype(dh.default_float), min_size=1, unique=True),
label="_x1",
)
x1 = xp.asarray(_x1, dtype=dh.default_float)
if data.draw(st.booleans(), label="use sorter?"):
sorter = data.draw(
st.permutations(_x1).map(lambda o: xp.asarray(o, dtype=dh.default_float)),
label="sorter",
)
else:
sorter = None
x1 = xp.sort(x1)
note(f"{x1=}")
x2 = data.draw(
st.lists(st.sampled_from(_x1), unique=True, min_size=1).map(
lambda o: xp.asarray(o, dtype=dh.default_float)
),
label="x2",
)

out = xp.searchsorted(x1, x2, sorter=sorter)

ph.assert_dtype(
"searchsorted",
in_dtype=[x1.dtype, x2.dtype],
out_dtype=out.dtype,
expected=xp.__array_namespace_info__().default_dtypes()["indexing"],
)
# TODO: shapes and values testing
10 changes: 10 additions & 0 deletions array_api_tests/test_statistical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,16 @@
from .typing import DataType


@pytest.mark.min_version("2023.12")
@given(hh.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_dims=1, max_dims=1)))
def test_cumulative_sum(x):
# TODO: test kwargs + diff shapes, adjust shape and values testing accordingly
out = xp.cumulative_sum(x)
# TODO: assert dtype
ph.assert_shape("cumulative_sum", out_shape=out.shape, expected=x.shape)
# TODO: assert values


def kwarg_dtypes(dtype: DataType) -> st.SearchStrategy[Optional[DataType]]:
dtypes = [d2 for d1, d2 in dh.promotion_table if d1 == dtype]
dtypes = [d for d in dtypes if not isinstance(d, _UndefinedStub)]
Expand Down

0 comments on commit 4974c68

Please sign in to comment.