Skip to content

Commit

Permalink
implementing a wrapper to fix test errors
Browse files Browse the repository at this point in the history
  • Loading branch information
ohrechykha committed Aug 8, 2024
1 parent fc0de45 commit 32eae19
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 10 deletions.
14 changes: 8 additions & 6 deletions src/ragged/_spec_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,16 +413,18 @@ def ceil(x: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.ceil.html
"""
def _wrapper(dtype):
if dtype in [np.int8, np.uint8, np.bool_, np.bool]:

def _wrapper(t: np.dtype, /) -> np.dtype:
if t in [np.int8, np.uint8, np.bool_, bool]:
return np.float16
elif dtype in [np.int16, np.uint16]:
elif t in [np.int16, np.uint16]:
return np.float32
elif dtype in [np.int32, np.uint32, np.int64, np.uint64]:
elif t in [np.int32, np.uint32, np.int64, np.uint64]:
return np.float64
else:
return dtype
return _box(type(x), np.ceil(*_unbox(x)), _wrapper(x.dtype))
return t

return _box(type(x), np.ceil(*_unbox(x)), dtype=_wrapper(x.dtype))


def conj(x: array, /) -> array:
Expand Down
19 changes: 15 additions & 4 deletions tests/test_spec_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,17 @@ def first(x: ragged.array) -> Any:
return xp.asarray(out.item(), dtype=x.dtype)


def _wrapper(t: np.dtype, /) -> np.dtype:
if t in [np.int8, np.uint8, np.bool_, bool]:
return np.float16
elif t in [np.int16, np.uint16]:
return np.float32
elif t in [np.int32, np.uint32, np.int64, np.uint64]:
return np.float64
else:
return t


def test_existence():
assert ragged.abs is not None
assert ragged.acos is not None
Expand Down Expand Up @@ -394,8 +405,8 @@ def test_ceil_int(device, x_int):
result = ragged.ceil(x_int.to_device(device))
assert type(result) is type(x_int)
assert result.shape == x_int.shape
assert xp.ceil(first(x_int)) == first(result)
assert xp.ceil(first(x_int)).dtype == result.dtype
assert xp.ceil(first(x_int)) == first(result).astype(_wrapper(first(result).dtype))
assert xp.ceil(first(x_int)).dtype == _wrapper(result.dtype)


@pytest.mark.skipif(
Expand Down Expand Up @@ -507,8 +518,8 @@ def test_floor_int(device, x_int):
result = ragged.floor(x_int.to_device(device))
assert type(result) is type(x_int)
assert result.shape == x_int.shape
assert xp.floor(first(x_int)) == first(result)
assert xp.floor(first(x_int)).dtype == result.dtype
assert xp.floor(first(x_int)) == first(result).astype(_wrapper(first(result).dtype))
assert xp.floor(first(x_int)).dtype == _wrapper(result.dtype)


@pytest.mark.parametrize("device", devices)
Expand Down

0 comments on commit 32eae19

Please sign in to comment.