Skip to content

Commit

Permalink
Merge pull request #1959 from IntelPython/tensor-asarray-support-for-…
Browse files Browse the repository at this point in the history
…usm-ndarray-protocol

Tensor asarray support for usm ndarray protocol
  • Loading branch information
oleksandr-pavlyk authored Jan 14, 2025
2 parents c354cd8 + e8fe0e0 commit 0f3536b
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* Extended `dpctl.SyclTimer` with `device_timer` keyword, implementing different methods of collecting device times [gh-1872](https://github.com/IntelPython/dpctl/pull/1872)
* Improved performance of `tensor.cumulative_sum`, `tensor.cumulative_prod`, `tensor.cumulative_logsumexp` as well as performance of boolean indexing [gh-1923](https://github.com/IntelPython/dpctl/pull/1923)
* Improved performance of `tensor.min`, `tensor.max`, `tensor.logsumexp`, `tensor.reduce_hypot` for floating point type arrays by at least 2x [gh-1932](https://github.com/IntelPython/dpctl/pull/1932)
* Extended `tensor.asarray` to support objects that implement `__usm_ndarray__` property to be interpreted as `usm_ndarray` objects [gh-1959](https://github.com/IntelPython/dpctl/pull/1959)

### Fixed

Expand Down
30 changes: 30 additions & 0 deletions dpctl/tensor/_ctors.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ def _array_info_dispatch(obj):
if _is_object_with_buffer_protocol(obj):
np_obj = np.array(obj)
return np_obj.shape, np_obj.dtype, _host_set
if hasattr(obj, "__usm_ndarray__"):
usm_ar = getattr(obj, "__usm_ndarray__")
if isinstance(usm_ar, dpt.usm_ndarray):
return usm_ar.shape, usm_ar.dtype, frozenset([usm_ar.sycl_queue])
if hasattr(obj, "__sycl_usm_array_interface__"):
usm_ar = _usm_ndarray_from_suai(obj)
return usm_ar.shape, usm_ar.dtype, frozenset([usm_ar.sycl_queue])
Expand Down Expand Up @@ -306,6 +310,11 @@ def _usm_types_walker(o, usm_types_list):
if isinstance(o, dpt.usm_ndarray):
usm_types_list.append(o.usm_type)
return
if hasattr(o, "__usm_ndarray__"):
usm_arr = getattr(o, "__usm_ndarray__")
if isinstance(usm_arr, dpt.usm_ndarray):
usm_types_list.append(usm_arr.usm_type)
return
if hasattr(o, "__sycl_usm_array_interface__"):
usm_ar = _usm_ndarray_from_suai(o)
usm_types_list.append(usm_ar.usm_type)
Expand All @@ -330,6 +339,11 @@ def _device_copy_walker(seq_o, res, _manager):
)
_manager.add_event_pair(ht_ev, cpy_ev)
return
if hasattr(seq_o, "__usm_ndarray__"):
usm_arr = getattr(seq_o, "__usm_ndarray__")
if isinstance(usm_arr, dpt.usm_ndarray):
_device_copy_walker(usm_arr, res, _manager)
return
if hasattr(seq_o, "__sycl_usm_array_interface__"):
usm_ar = _usm_ndarray_from_suai(seq_o)
exec_q = res.sycl_queue
Expand Down Expand Up @@ -361,6 +375,11 @@ def _copy_through_host_walker(seq_o, usm_res):
return
else:
usm_res[...] = seq_o
if hasattr(seq_o, "__usm_ndarray__"):
usm_arr = getattr(seq_o, "__usm_ndarray__")
if isinstance(usm_arr, dpt.usm_ndarray):
_copy_through_host_walker(usm_arr, usm_res)
return
if hasattr(seq_o, "__sycl_usm_array_interface__"):
usm_ar = _usm_ndarray_from_suai(seq_o)
if (
Expand Down Expand Up @@ -564,6 +583,17 @@ def asarray(
sycl_queue=sycl_queue,
order=order,
)
if hasattr(obj, "__usm_ndarray__"):
usm_arr = getattr(obj, "__usm_ndarray__")
if isinstance(usm_arr, dpt.usm_ndarray):
return _asarray_from_usm_ndarray(
usm_arr,
dtype=dtype,
copy=copy,
usm_type=usm_type,
sycl_queue=sycl_queue,
order=order,
)
if hasattr(obj, "__sycl_usm_array_interface__"):
ary = _usm_ndarray_from_suai(obj)
return _asarray_from_usm_ndarray(
Expand Down
67 changes: 67 additions & 0 deletions dpctl/tests/test_tensor_asarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,3 +556,70 @@ def test_as_f_contig_square(dt):
x3 = dpt.flip(x, axis=1)
y3 = dpt.asarray(x3, order="F")
assert dpt.all(x3 == y3)


class MockArrayWithBothProtocols:
"""
Object that implements both __sycl_usm_array_interface__
and __usm_ndarray__ properties.
"""

def __init__(self, usm_ar):
if not isinstance(usm_ar, dpt.usm_ndarray):
raise TypeError
self._arr = usm_ar

@property
def __usm_ndarray__(self):
return self._arr

@property
def __sycl_usm_array_interface__(self):
return self._arr.__sycl_usm_array_interface__


class MockArrayWithSUAIOnly:
"""
Object that implements only the
__sycl_usm_array_interface__ property.
"""

def __init__(self, usm_ar):
if not isinstance(usm_ar, dpt.usm_ndarray):
raise TypeError
self._arr = usm_ar

@property
def __sycl_usm_array_interface__(self):
return self._arr.__sycl_usm_array_interface__


@pytest.mark.parametrize("usm_type", ["shared", "device", "host"])
def test_asarray_support_for_usm_ndarray_protocol(usm_type):
get_queue_or_skip()

x = dpt.arange(256, dtype="i4", usm_type=usm_type)

o1 = MockArrayWithBothProtocols(x)
o2 = MockArrayWithSUAIOnly(x)

y1 = dpt.asarray(o1)
assert x.sycl_queue == y1.sycl_queue
assert x.usm_type == y1.usm_type
assert x.dtype == y1.dtype
assert y1.usm_data.reference_obj is None
assert dpt.all(x == y1)

y2 = dpt.asarray(o2)
assert x.sycl_queue == y2.sycl_queue
assert x.usm_type == y2.usm_type
assert x.dtype == y2.dtype
assert not (y2.usm_data.reference_obj is None)
assert dpt.all(x == y2)

y3 = dpt.asarray([o1, o2])
assert x.sycl_queue == y3.sycl_queue
assert x.usm_type == y3.usm_type
assert x.dtype == y3.dtype
assert y3.usm_data.reference_obj is None
assert dpt.all(x[dpt.newaxis, :] == y3)

0 comments on commit 0f3536b

Please sign in to comment.