Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tensor asarray support for usm ndarray protocol #1959

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* Extended `dpctl.SyclTimer` with `device_timer` keyword, implementing different methods of collecting device times [gh-1872](https://github.com/IntelPython/dpctl/pull/1872)
* Improved performance of `tensor.cumulative_sum`, `tensor.cumulative_prod`, `tensor.cumulative_logsumexp` as well as performance of boolean indexing [gh-1923](https://github.com/IntelPython/dpctl/pull/1923)
* Improved performance of `tensor.min`, `tensor.max`, `tensor.logsumexp`, `tensor.reduce_hypot` for floating point type arrays by at least 2x [gh-1932](https://github.com/IntelPython/dpctl/pull/1932)
* Extended `tensor.asarray` to support objects that implement `__usm_ndarray__` property to be interpreted as `usm_ndarray` objects [gh-1959](https://github.com/IntelPython/dpctl/pull/1959)

### Fixed

Expand Down
30 changes: 30 additions & 0 deletions dpctl/tensor/_ctors.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ def _array_info_dispatch(obj):
if _is_object_with_buffer_protocol(obj):
np_obj = np.array(obj)
return np_obj.shape, np_obj.dtype, _host_set
if hasattr(obj, "__usm_ndarray__"):
usm_ar = getattr(obj, "__usm_ndarray__")
if isinstance(usm_ar, dpt.usm_ndarray):
return usm_ar.shape, usm_ar.dtype, frozenset([usm_ar.sycl_queue])
if hasattr(obj, "__sycl_usm_array_interface__"):
usm_ar = _usm_ndarray_from_suai(obj)
return usm_ar.shape, usm_ar.dtype, frozenset([usm_ar.sycl_queue])
Expand Down Expand Up @@ -306,6 +310,11 @@ def _usm_types_walker(o, usm_types_list):
if isinstance(o, dpt.usm_ndarray):
usm_types_list.append(o.usm_type)
return
if hasattr(o, "__usm_ndarray__"):
usm_arr = getattr(o, "__usm_ndarray__")
if isinstance(usm_arr, dpt.usm_ndarray):
usm_types_list.append(usm_arr.usm_type)
return
if hasattr(o, "__sycl_usm_array_interface__"):
usm_ar = _usm_ndarray_from_suai(o)
usm_types_list.append(usm_ar.usm_type)
Expand All @@ -330,6 +339,11 @@ def _device_copy_walker(seq_o, res, _manager):
)
_manager.add_event_pair(ht_ev, cpy_ev)
return
if hasattr(seq_o, "__usm_ndarray__"):
usm_arr = getattr(seq_o, "__usm_ndarray__")
if isinstance(usm_arr, dpt.usm_ndarray):
_device_copy_walker(usm_arr, res, _manager)
return
if hasattr(seq_o, "__sycl_usm_array_interface__"):
usm_ar = _usm_ndarray_from_suai(seq_o)
exec_q = res.sycl_queue
Expand Down Expand Up @@ -361,6 +375,11 @@ def _copy_through_host_walker(seq_o, usm_res):
return
else:
usm_res[...] = seq_o
if hasattr(seq_o, "__usm_ndarray__"):
usm_arr = getattr(seq_o, "__usm_ndarray__")
if isinstance(usm_arr, dpt.usm_ndarray):
_copy_through_host_walker(usm_arr, usm_res)
return
if hasattr(seq_o, "__sycl_usm_array_interface__"):
usm_ar = _usm_ndarray_from_suai(seq_o)
if (
Expand Down Expand Up @@ -564,6 +583,17 @@ def asarray(
sycl_queue=sycl_queue,
order=order,
)
if hasattr(obj, "__usm_ndarray__"):
usm_arr = getattr(obj, "__usm_ndarray__")
if isinstance(usm_arr, dpt.usm_ndarray):
return _asarray_from_usm_ndarray(
usm_arr,
dtype=dtype,
copy=copy,
usm_type=usm_type,
sycl_queue=sycl_queue,
order=order,
)
if hasattr(obj, "__sycl_usm_array_interface__"):
ary = _usm_ndarray_from_suai(obj)
return _asarray_from_usm_ndarray(
Expand Down
67 changes: 67 additions & 0 deletions dpctl/tests/test_tensor_asarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,3 +556,70 @@ def test_as_f_contig_square(dt):
x3 = dpt.flip(x, axis=1)
y3 = dpt.asarray(x3, order="F")
assert dpt.all(x3 == y3)


class MockArrayWithBothProtocols:
"""
Object that implements both __sycl_usm_array_interface__
and __usm_ndarray__ properties.
"""

def __init__(self, usm_ar):
if not isinstance(usm_ar, dpt.usm_ndarray):
raise TypeError
self._arr = usm_ar

@property
def __usm_ndarray__(self):
return self._arr

@property
def __sycl_usm_array_interface__(self):
return self._arr.__sycl_usm_array_interface__


class MockArrayWithSUAIOnly:
"""
Object that implements only the
__sycl_usm_array_interface__ property.
"""

def __init__(self, usm_ar):
if not isinstance(usm_ar, dpt.usm_ndarray):
raise TypeError
self._arr = usm_ar

@property
def __sycl_usm_array_interface__(self):
return self._arr.__sycl_usm_array_interface__


@pytest.mark.parametrize("usm_type", ["shared", "device", "host"])
def test_asarray_support_for_usm_ndarray_protocol(usm_type):
get_queue_or_skip()

x = dpt.arange(256, dtype="i4", usm_type=usm_type)

o1 = MockArrayWithBothProtocols(x)
o2 = MockArrayWithSUAIOnly(x)

y1 = dpt.asarray(o1)
assert x.sycl_queue == y1.sycl_queue
assert x.usm_type == y1.usm_type
assert x.dtype == y1.dtype
assert y1.usm_data.reference_obj is None
assert dpt.all(x == y1)

y2 = dpt.asarray(o2)
assert x.sycl_queue == y2.sycl_queue
assert x.usm_type == y2.usm_type
assert x.dtype == y2.dtype
assert not (y2.usm_data.reference_obj is None)
assert dpt.all(x == y2)

y3 = dpt.asarray([o1, o2])
assert x.sycl_queue == y3.sycl_queue
assert x.usm_type == y3.usm_type
assert x.dtype == y3.dtype
assert y3.usm_data.reference_obj is None
assert dpt.all(x[dpt.newaxis, :] == y3)
Loading