Skip to content

Commit

Permalink
Modularize stream validation into a function.
Browse files Browse the repository at this point in the history
Stream keyword validation and deployment was copy-pasted in several
places. Created function _stream_validate_and_use, and used it in
a couple of places.

This brings uniformity of error messages, and should improve coverage
and maintainability.
  • Loading branch information
oleksandr-pavlyk committed Jan 16, 2025
1 parent df8c1f3 commit 74066bb
Showing 1 changed file with 18 additions and 39 deletions.
57 changes: 18 additions & 39 deletions dpctl/tensor/_usmarray.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,19 @@ cdef bint _is_host_cpu(object dl_device):
return (dl_type == DLDeviceType.kDLCPU) and (dl_id == 0)


cdef void _validate_and_use_stream(object stream, c_dpctl.SyclQueue self_queue) except *:
if (stream is None or stream == self_queue):
pass
else:
if not isinstance(stream, dpctl.SyclQueue):
raise TypeError(
"stream argument type was expected to be dpctl.SyclQueue,"
f" got {type(stream)} instead"
)
ev = self_queue.submit_barrier()
stream.submit_barrier(dependent_events=[ev])


cdef class usm_ndarray:
""" usm_ndarray(shape, dtype=None, strides=None, buffer="device", \
offset=0, order="C", buffer_ctor_kwargs=dict(), \
Expand Down Expand Up @@ -1025,16 +1038,7 @@ cdef class usm_ndarray:
cdef c_dpmem._Memory arr_buf
d = Device.create_device(target_device)

if (stream is None or stream == self.sycl_queue):
pass
else:
if not isinstance(stream, dpctl.SyclQueue):
raise TypeError(
"stream argument type was expected to be dpctl.SyclQueue,"
f" got {type(stream)} instead"
)
ev = self.sycl_queue.submit_barrier()
stream.submit_barrier(dependent_events=[ev])
_validate_and_use_stream(stream, self.sycl_queue)

if (d.sycl_context == self.sycl_context):
arr_buf = <c_dpmem._Memory> self.usm_data
Expand Down Expand Up @@ -1207,17 +1211,7 @@ cdef class usm_ndarray:
# legacy path for DLManagedTensor
# copy kwarg ignored because copy flag can't be set
_caps = c_dlpack.to_dlpack_capsule(self)
if (stream is None or stream == self.sycl_queue):
pass
else:
if not isinstance(stream, dpctl.SyclQueue):
raise TypeError(
"stream keyword argument type is expected to "
"be dpctl.SyclQueue, "
f" got {type(stream)} instead"
)
ev = self.sycl_queue.submit_barrier()
stream.submit_barrier(dependent_events=[ev])
_validate_and_use_stream(stream, self.sycl_queue)
return _caps
else:
if not isinstance(max_version, tuple) or len(max_version) != 2:
Expand Down Expand Up @@ -1259,12 +1253,7 @@ cdef class usm_ndarray:
copy = False
# TODO: strategy for handling stream on different device from dl_device
if copy:
if (stream is None or type(stream) is not dpctl.SyclQueue or
stream == self.sycl_queue):
pass
else:
ev = self.sycl_queue.submit_barrier()
stream.submit_barrier(dependent_events=[ev])
_validate_and_use_stream(stream, self.sycl_queue)
nbytes = self.usm_data.nbytes
copy_buffer = type(self.usm_data)(
nbytes, queue=self.sycl_queue
Expand All @@ -1281,22 +1270,12 @@ cdef class usm_ndarray:
_caps = c_dlpack.to_dlpack_versioned_capsule(_copied_arr, copy)
else:
_caps = c_dlpack.to_dlpack_versioned_capsule(self, copy)
if (stream is None or type(stream) is not dpctl.SyclQueue or
stream == self.sycl_queue):
pass
else:
ev = self.sycl_queue.submit_barrier()
stream.submit_barrier(dependent_events=[ev])
_validate_and_use_stream(stream, self.sycl_queue)
return _caps
else:
# legacy path for DLManagedTensor
_caps = c_dlpack.to_dlpack_capsule(self)
if (stream is None or type(stream) is not dpctl.SyclQueue or
stream == self.sycl_queue):
pass
else:
ev = self.sycl_queue.submit_barrier()
stream.submit_barrier(dependent_events=[ev])
_validate_and_use_stream(stream, self.sycl_queue)
return _caps

def __dlpack_device__(self):
Expand Down

0 comments on commit 74066bb

Please sign in to comment.