Skip to content

Commit

Permalink
Implement __usm_ndarray__ protocol (#2261)
Browse files Browse the repository at this point in the history
The PR is intended to adopt to dpctl changes implemented in
[dpctl#1959](IntelPython/dpctl#1959).

It implements support of `__usm_ndarray__` protocol for `dpnp.ndarray`
and returns a property with `dpctl.tensor.usm_ndarray` instance
corresponding to the content of the array object.

This property is intended to speed-up conversion from `dpnp.ndarray` to
`dpt.usm_ndarray` in `x=dpt.asarray(dpnp_array_obj)`.
The input object that implements `__usm_ndarray__` is recognized as
owner of USM allocation that is managed by a smart pointer, and
asynchronous deallocation of `x` need not involve GIL.
  • Loading branch information
antonwolfy authored Jan 20, 2025
1 parent 6cc2348 commit 9ad1bb5
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 0 deletions.
19 changes: 19 additions & 0 deletions dpnp/dpnp_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,25 @@ def __truediv__(self, other):
"""Return ``self/value``."""
return dpnp.true_divide(self, other)

@property
def __usm_ndarray__(self):
"""
Property to support `__usm_ndarray__` protocol.
It assumes to return :class:`dpctl.tensor.usm_ndarray` instance
corresponding to the content of the object.
This property is intended to speed-up conversion from
:class:`dpnp.ndarray` to :class:`dpctl.tensor.usm_ndarray` passed
into `dpctl.tensor.asarray` function. The input object that implements
`__usm_ndarray__` protocol is recognized as owner of USM allocation
that is managed by a smart pointer, and asynchronous deallocation
will not involve GIL.
"""

return self._array_obj

def __xor__(self, other):
"""Return ``self^value``."""
return dpnp.bitwise_xor(self, other)
Expand Down
12 changes: 12 additions & 0 deletions dpnp/tests/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,18 @@ def test_error(self):
ia.item()


class TestUsmNdarrayProtocol:
def test_basic(self):
a = dpnp.arange(256, dtype=dpnp.int64)
usm_a = dpt.asarray(a)

assert a.sycl_queue == usm_a.sycl_queue
assert a.usm_type == usm_a.usm_type
assert a.dtype == usm_a.dtype
assert usm_a.usm_data.reference_obj is None
assert (a == usm_a).all()


def test_print_dpnp_int():
result = repr(dpnp.array([1, 0, 2, -3, -1, 2, 21, -9], dtype="i4"))
expected = "array([ 1, 0, 2, -3, -1, 2, 21, -9], dtype=int32)"
Expand Down

0 comments on commit 9ad1bb5

Please sign in to comment.