Skip to content

Commit

Permalink
Add tests for dldevice and sycldevice interchange functions
Browse files Browse the repository at this point in the history
  • Loading branch information
ndgrigorian committed Jan 9, 2025
1 parent c529b29 commit cb4fd57
Showing 1 changed file with 35 additions and 0 deletions.
35 changes: 35 additions & 0 deletions dpctl/tests/test_usm_ndarray_dlpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,3 +826,38 @@ def test_generic_container():
assert isinstance(Z, dpt.usm_ndarray)
assert Z._pointer == X._pointer
assert Z.device == X.device


def test_sycldevice_to_dldevice(all_root_devices):
for sycl_dev in all_root_devices:
dev = dpt.sycldevice_to_dldevice(sycl_dev)
assert type(dev) is tuple
assert len(dev) == 2
assert dev[0] == device_oneAPI
assert dev[1] == all_root_devices.index(sycl_dev)


def test_dldevice_to_sycldevice(all_root_devices):
for sycl_dev in all_root_devices:
dldev = dpt.empty(0, device=sycl_dev).__dlpack_device__()
dev = dpt.dldevice_to_sycldevice(dldev)
assert type(dev) is dpctl.SyclDevice
assert dev == all_root_devices[dldev[1]]


def test_dldevice_conversion_arg_validation():
bad_dldevice_type = (dpt.DLDeviceType.kDLCPU, 0)
with pytest.raises(ValueError):
dpt.dldevice_to_sycldevice(bad_dldevice_type)

bad_dldevice_len = bad_dldevice_type + (0,)
with pytest.raises(ValueError):
dpt.dldevice_to_sycldevice(bad_dldevice_len)

bad_dldevice = dict()
with pytest.raises(TypeError):
dpt.dldevice_to_sycldevice(bad_dldevice)

bad_sycldevice = dict()
with pytest.raises(TypeError):
dpt.sycldevice_to_dldevice(bad_sycldevice)

0 comments on commit cb4fd57

Please sign in to comment.