From cb4fd574395a3b4adf91721d8992e24cbbd5e6af Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Thu, 9 Jan 2025 10:25:59 -0800 Subject: [PATCH] Add tests for dldevice and sycldevice interchange functions --- dpctl/tests/test_usm_ndarray_dlpack.py | 35 ++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/dpctl/tests/test_usm_ndarray_dlpack.py b/dpctl/tests/test_usm_ndarray_dlpack.py index 1fea8ec11e..6b15edfdfc 100644 --- a/dpctl/tests/test_usm_ndarray_dlpack.py +++ b/dpctl/tests/test_usm_ndarray_dlpack.py @@ -826,3 +826,38 @@ def test_generic_container(): assert isinstance(Z, dpt.usm_ndarray) assert Z._pointer == X._pointer assert Z.device == X.device + + +def test_sycldevice_to_dldevice(all_root_devices): + for sycl_dev in all_root_devices: + dev = dpt.sycldevice_to_dldevice(sycl_dev) + assert type(dev) is tuple + assert len(dev) == 2 + assert dev[0] == device_oneAPI + assert dev[1] == all_root_devices.index(sycl_dev) + + +def test_dldevice_to_sycldevice(all_root_devices): + for sycl_dev in all_root_devices: + dldev = dpt.empty(0, device=sycl_dev).__dlpack_device__() + dev = dpt.dldevice_to_sycldevice(dldev) + assert type(dev) is dpctl.SyclDevice + assert dev == all_root_devices[dldev[1]] + + +def test_dldevice_conversion_arg_validation(): + bad_dldevice_type = (dpt.DLDeviceType.kDLCPU, 0) + with pytest.raises(ValueError): + dpt.dldevice_to_sycldevice(bad_dldevice_type) + + bad_dldevice_len = bad_dldevice_type + (0,) + with pytest.raises(ValueError): + dpt.dldevice_to_sycldevice(bad_dldevice_len) + + bad_dldevice = dict() + with pytest.raises(TypeError): + dpt.dldevice_to_sycldevice(bad_dldevice) + + bad_sycldevice = dict() + with pytest.raises(TypeError): + dpt.sycldevice_to_dldevice(bad_sycldevice)