diff --git a/dpctl/tensor/_usmarray.pyx b/dpctl/tensor/_usmarray.pyx index ba18600135..5b394d971b 100644 --- a/dpctl/tensor/_usmarray.pyx +++ b/dpctl/tensor/_usmarray.pyx @@ -182,13 +182,20 @@ cdef class usm_ndarray: cdef bint is_fp16 = False self._reset() - if (not isinstance(shape, (list, tuple)) - and not hasattr(shape, 'tolist')): - try: - shape - shape = [shape, ] - except Exception: - raise TypeError("Argument shape must be a list or a tuple.") + if not isinstance(shape, (list, tuple)): + if hasattr(shape, 'tolist'): + fn = getattr(shape, 'tolist') + if callable(fn): + shape = shape.tolist() + if not isinstance(shape, (list, tuple)): + try: + shape + shape = [shape, ] + except Exception as e: + raise TypeError( + "Argument shape must a non-negative integer, " + "or a list/tuple of such integers." + ) from e nd = len(shape) if dtype is None: if isinstance(buffer, (dpmem._memory._Memory, usm_ndarray)): diff --git a/dpctl/tests/test_usm_ndarray_ctor.py b/dpctl/tests/test_usm_ndarray_ctor.py index 72f5aabebb..095bbc5638 100644 --- a/dpctl/tests/test_usm_ndarray_ctor.py +++ b/dpctl/tests/test_usm_ndarray_ctor.py @@ -39,6 +39,7 @@ (2, 5, 2), (2, 2, 2, 2, 2, 2, 2, 2), 5, + np.int32(7), ], ) @pytest.mark.parametrize("usm_type", ["shared", "host", "device"]) diff --git a/dpctl/tests/test_usm_ndarray_reductions.py b/dpctl/tests/test_usm_ndarray_reductions.py index cbfd6baec6..0969822e6d 100644 --- a/dpctl/tests/test_usm_ndarray_reductions.py +++ b/dpctl/tests/test_usm_ndarray_reductions.py @@ -175,9 +175,11 @@ def test_search_reduction_kernels(arg_dtype): q = get_queue_or_skip() skip_if_dtype_not_supported(arg_dtype, q) - x = dpt.ones((24 * 1025), dtype=arg_dtype, sycl_queue=q) + x_shape = (24, 1024) + x_size = np.prod(x_shape) + x = dpt.ones(x_size, dtype=arg_dtype, sycl_queue=q) idx = randrange(x.size) - idx_tup = np.unravel_index(idx, (24, 1025)) + idx_tup = np.unravel_index(idx, x_shape) x[idx] = 2 m = dpt.argmax(x) @@ -194,7 +196,7 @@ def test_search_reduction_kernels(arg_dtype): m = dpt.argmax(y) assert m == 2 * idx - x = dpt.reshape(x, (24, 1025)) + x = dpt.reshape(x, x_shape) x[idx_tup[0], :] = 3 m = dpt.argmax(x, axis=0) @@ -209,15 +211,15 @@ def test_search_reduction_kernels(arg_dtype): m = dpt.argmax(x, axis=1) assert dpt.all(m == idx) - x = dpt.ones((24 * 1025), dtype=arg_dtype, sycl_queue=q) + x = dpt.ones(x_size, dtype=arg_dtype, sycl_queue=q) idx = randrange(x.size) - idx_tup = np.unravel_index(idx, (24, 1025)) + idx_tup = np.unravel_index(idx, x_shape) x[idx] = 0 m = dpt.argmin(x) assert m == idx - x = dpt.reshape(x, (24, 1025)) + x = dpt.reshape(x, x_shape) x[idx_tup[0], :] = -1 m = dpt.argmin(x, axis=0)