diff --git a/CHANGELOG.md b/CHANGELOG.md index 12f20edeff..f1c0133e63 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * Extended `dpctl.SyclTimer` with `device_timer` keyword, implementing different methods of collecting device times [gh-1872](https://github.com/IntelPython/dpctl/pull/1872) * Improved performance of `tensor.cumulative_sum`, `tensor.cumulative_prod`, `tensor.cumulative_logsumexp` as well as performance of boolean indexing [gh-1923](https://github.com/IntelPython/dpctl/pull/1923) * Improved performance of `tensor.min`, `tensor.max`, `tensor.logsumexp`, `tensor.reduce_hypot` for floating point type arrays by at least 2x [gh-1932](https://github.com/IntelPython/dpctl/pull/1932) +* Extended `tensor.asarray` to support objects that implement `__usm_ndarray__` property to be interpreted as `usm_ndarray` objects [gh-1959](https://github.com/IntelPython/dpctl/pull/1959) ### Fixed diff --git a/dpctl/tensor/_ctors.py b/dpctl/tensor/_ctors.py index 33ebe1be90..ecdba971e2 100644 --- a/dpctl/tensor/_ctors.py +++ b/dpctl/tensor/_ctors.py @@ -61,6 +61,10 @@ def _array_info_dispatch(obj): if _is_object_with_buffer_protocol(obj): np_obj = np.array(obj) return np_obj.shape, np_obj.dtype, _host_set + if hasattr(obj, "__usm_ndarray__"): + usm_ar = getattr(obj, "__usm_ndarray__") + if isinstance(usm_ar, dpt.usm_ndarray): + return usm_ar.shape, usm_ar.dtype, frozenset([usm_ar.sycl_queue]) if hasattr(obj, "__sycl_usm_array_interface__"): usm_ar = _usm_ndarray_from_suai(obj) return usm_ar.shape, usm_ar.dtype, frozenset([usm_ar.sycl_queue]) @@ -306,6 +310,11 @@ def _usm_types_walker(o, usm_types_list): if isinstance(o, dpt.usm_ndarray): usm_types_list.append(o.usm_type) return + if hasattr(o, "__usm_ndarray__"): + usm_arr = getattr(o, "__usm_ndarray__") + if isinstance(usm_arr, dpt.usm_ndarray): + usm_types_list.append(usm_arr.usm_type) + return if hasattr(o, "__sycl_usm_array_interface__"): usm_ar = _usm_ndarray_from_suai(o) usm_types_list.append(usm_ar.usm_type) @@ -330,6 +339,11 @@ def _device_copy_walker(seq_o, res, _manager): ) _manager.add_event_pair(ht_ev, cpy_ev) return + if hasattr(seq_o, "__usm_ndarray__"): + usm_arr = getattr(seq_o, "__usm_ndarray__") + if isinstance(usm_arr, dpt.usm_ndarray): + _device_copy_walker(usm_arr, res, _manager) + return if hasattr(seq_o, "__sycl_usm_array_interface__"): usm_ar = _usm_ndarray_from_suai(seq_o) exec_q = res.sycl_queue @@ -361,6 +375,11 @@ def _copy_through_host_walker(seq_o, usm_res): return else: usm_res[...] = seq_o + if hasattr(seq_o, "__usm_ndarray__"): + usm_arr = getattr(seq_o, "__usm_ndarray__") + if isinstance(usm_arr, dpt.usm_ndarray): + _copy_through_host_walker(usm_arr, usm_res) + return if hasattr(seq_o, "__sycl_usm_array_interface__"): usm_ar = _usm_ndarray_from_suai(seq_o) if ( @@ -564,6 +583,17 @@ def asarray( sycl_queue=sycl_queue, order=order, ) + if hasattr(obj, "__usm_ndarray__"): + usm_arr = getattr(obj, "__usm_ndarray__") + if isinstance(usm_arr, dpt.usm_ndarray): + return _asarray_from_usm_ndarray( + usm_arr, + dtype=dtype, + copy=copy, + usm_type=usm_type, + sycl_queue=sycl_queue, + order=order, + ) if hasattr(obj, "__sycl_usm_array_interface__"): ary = _usm_ndarray_from_suai(obj) return _asarray_from_usm_ndarray( diff --git a/dpctl/tests/test_tensor_asarray.py b/dpctl/tests/test_tensor_asarray.py index 9ac30b404f..20e2ddd704 100644 --- a/dpctl/tests/test_tensor_asarray.py +++ b/dpctl/tests/test_tensor_asarray.py @@ -556,3 +556,70 @@ def test_as_f_contig_square(dt): x3 = dpt.flip(x, axis=1) y3 = dpt.asarray(x3, order="F") assert dpt.all(x3 == y3) + + +class MockArrayWithBothProtocols: + """ + Object that implements both __sycl_usm_array_interface__ + and __usm_ndarray__ properties. + """ + + def __init__(self, usm_ar): + if not isinstance(usm_ar, dpt.usm_ndarray): + raise TypeError + self._arr = usm_ar + + @property + def __usm_ndarray__(self): + return self._arr + + @property + def __sycl_usm_array_interface__(self): + return self._arr.__sycl_usm_array_interface__ + + +class MockArrayWithSUAIOnly: + """ + Object that implements only the + __sycl_usm_array_interface__ property. + """ + + def __init__(self, usm_ar): + if not isinstance(usm_ar, dpt.usm_ndarray): + raise TypeError + self._arr = usm_ar + + @property + def __sycl_usm_array_interface__(self): + return self._arr.__sycl_usm_array_interface__ + + +@pytest.mark.parametrize("usm_type", ["shared", "device", "host"]) +def test_asarray_support_for_usm_ndarray_protocol(usm_type): + get_queue_or_skip() + + x = dpt.arange(256, dtype="i4", usm_type=usm_type) + + o1 = MockArrayWithBothProtocols(x) + o2 = MockArrayWithSUAIOnly(x) + + y1 = dpt.asarray(o1) + assert x.sycl_queue == y1.sycl_queue + assert x.usm_type == y1.usm_type + assert x.dtype == y1.dtype + assert y1.usm_data.reference_obj is None + assert dpt.all(x == y1) + + y2 = dpt.asarray(o2) + assert x.sycl_queue == y2.sycl_queue + assert x.usm_type == y2.usm_type + assert x.dtype == y2.dtype + assert not (y2.usm_data.reference_obj is None) + assert dpt.all(x == y2) + + y3 = dpt.asarray([o1, o2]) + assert x.sycl_queue == y3.sycl_queue + assert x.usm_type == y3.usm_type + assert x.dtype == y3.dtype + assert y3.usm_data.reference_obj is None + assert dpt.all(x[dpt.newaxis, :] == y3)