From 7436de67a285b345ca55d936838266675f9f99f4 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk <21087696+oleksandr-pavlyk@users.noreply.github.com> Date: Sun, 29 Dec 2024 11:06:09 -0600 Subject: [PATCH 1/8] Use Ubuntu 24.04 to run the OS-LLVM workflow to build dpctl with nightly SYCL bundle DPC++ compiler --- .github/workflows/os-llvm-sycl-build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/os-llvm-sycl-build.yml b/.github/workflows/os-llvm-sycl-build.yml index 7648bbb42b..f775a543a7 100644 --- a/.github/workflows/os-llvm-sycl-build.yml +++ b/.github/workflows/os-llvm-sycl-build.yml @@ -9,7 +9,7 @@ permissions: read-all jobs: install-compiler: name: Build with nightly build of DPC++ toolchain - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 timeout-minutes: 90 env: From bff71fcec9a9c2c51e1451fcb6a1bb3bde7b1379 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk <21087696+oleksandr-pavlyk@users.noreply.github.com> Date: Sun, 29 Dec 2024 11:18:01 -0600 Subject: [PATCH 2/8] No need to install libtinfo5 on Ubuntu 24.04 --- .github/workflows/os-llvm-sycl-build.yml | 5 ----- 1 file changed, 5 deletions(-) diff --git a/.github/workflows/os-llvm-sycl-build.yml b/.github/workflows/os-llvm-sycl-build.yml index f775a543a7..74833ec3f9 100644 --- a/.github/workflows/os-llvm-sycl-build.yml +++ b/.github/workflows/os-llvm-sycl-build.yml @@ -93,11 +93,6 @@ jobs: cp oclcpuexp/x64/libOpenCL.so* dpcpp_compiler/lib/ fi - - name: Install system components - shell: bash -l {0} - run: | - sudo apt-get install libtinfo5 - - name: Setup Python uses: actions/setup-python@v5 with: From addb341a6aae56848517365c6f7a96d6e828cb44 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk <21087696+oleksandr-pavlyk@users.noreply.github.com> Date: Tue, 31 Dec 2024 10:15:11 -0600 Subject: [PATCH 3/8] Fixed blunder in work-item id to data_id computation gid-lane_id is already a multiple of sg_size. --- dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp b/dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp index 0ff3fc4723..1828c88eb4 100644 --- a/dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp +++ b/dpctl/tensor/libtensor/include/kernels/sorting/topk.hpp @@ -101,8 +101,7 @@ sycl::event write_out_impl(sycl::queue &exec_q, const std::uint32_t lane_id = sg.get_local_id()[0]; const std::uint32_t sg_size = sg.get_max_local_range()[0]; - const std::size_t start_id = - (gid - lane_id) * sg_size * n_wi + lane_id; + const std::size_t start_id = (gid - lane_id) * n_wi + lane_id; #pragma unroll for (std::uint32_t i = 0; i < n_wi; ++i) { From 399cdd1b47e064a75a896989272ffafa22cd8549 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk <21087696+oleksandr-pavlyk@users.noreply.github.com> Date: Tue, 31 Dec 2024 10:16:11 -0600 Subject: [PATCH 4/8] Replace map_back_impl in sort_utils Change kernel to process few data elements in the work-item. 
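For illustration, a self-contained sketch of the work distribution this change switches to, with each work-item handling n_wi elements at sub-group stride. The constants lws/n_wi and the modulus payload mirror map_back_impl below, but the array sizes and data are arbitrary example values, not part of the patch:

    // sketch only -- standalone illustration, not the patched kernel
    #include <sycl/sycl.hpp>
    #include <cstddef>
    #include <cstdint>
    #include <iostream>

    int main()
    {
        constexpr std::uint32_t lws = 64;  // work-group size
        constexpr std::uint32_t n_wi = 4;  // elements per work-item
        constexpr std::size_t nelems = 1000;
        constexpr std::size_t row_size = 10;

        sycl::queue q;
        auto *flat = sycl::malloc_shared<std::size_t>(nelems, q);
        auto *reduced = sycl::malloc_shared<std::size_t>(nelems, q);
        for (std::size_t i = 0; i < nelems; ++i)
            flat[i] = i;

        const std::size_t n_groups = (nelems + lws * n_wi - 1) / (lws * n_wi);
        sycl::nd_range<1> ndRange{sycl::range<1>{n_groups * lws},
                                  sycl::range<1>{lws}};

        q.parallel_for(ndRange, [=](sycl::nd_item<1> it) {
             const std::size_t gid = it.get_global_linear_id();
             const auto &sg = it.get_sub_group();
             const std::uint32_t lane_id = sg.get_local_id()[0];
             const std::uint32_t sg_size = sg.get_max_local_range()[0];

             // gid - lane_id is the global id of lane 0 of this sub-group,
             // so (gid - lane_id) * n_wi is the first element the sub-group
             // owns; lanes then step through that block with stride sg_size,
             // keeping accesses within a sub-group contiguous.
             const std::size_t start_id = (gid - lane_id) * n_wi + lane_id;
             for (std::uint32_t i = 0; i < n_wi; ++i) {
                 const std::size_t data_id = start_id + i * sg_size;
                 if (data_id < nelems)
                     reduced[data_id] = flat[data_id] % row_size;
             }
         }).wait();

        std::cout << reduced[123] << "\n";  // prints 3, i.e. 123 % 10
        sycl::free(flat, q);
        sycl::free(reduced, q);
        return 0;
    }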
--- .../include/kernels/sorting/sort_utils.hpp | 31 ++++++++++++++++--- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/sorting/sort_utils.hpp b/dpctl/tensor/libtensor/include/kernels/sorting/sort_utils.hpp index f62a6c3fa0..a81f528852 100644 --- a/dpctl/tensor/libtensor/include/kernels/sorting/sort_utils.hpp +++ b/dpctl/tensor/libtensor/include/kernels/sorting/sort_utils.hpp @@ -105,14 +105,35 @@ sycl::event map_back_impl(sycl::queue &exec_q, std::size_t row_size, const std::vector &dependent_events) { + constexpr std::uint32_t lws = 64; + constexpr std::uint32_t n_wi = 4; + const std::size_t n_groups = (nelems + lws * n_wi - 1) / (n_wi * lws); + + sycl::range<1> lRange{lws}; + sycl::range<1> gRange{n_groups * lws}; + sycl::nd_range<1> ndRange{gRange, lRange}; + sycl::event map_back_ev = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(dependent_events); - cgh.parallel_for( - sycl::range<1>(nelems), [=](sycl::id<1> id) { - const IndexTy linear_index = flat_index_data[id]; - reduced_index_data[id] = (linear_index % row_size); - }); + cgh.parallel_for(ndRange, [=](sycl::nd_item<1> it) { + const std::size_t gid = it.get_global_linear_id(); + const auto &sg = it.get_sub_group(); + const std::uint32_t lane_id = sg.get_local_id()[0]; + const std::uint32_t sg_size = sg.get_max_local_range()[0]; + + const std::size_t start_id = (gid - lane_id) * n_wi + lane_id; + +#pragma unroll + for (std::uint32_t i = 0; i < n_wi; ++i) { + const std::size_t data_id = start_id + i * sg_size; + + if (data_id < nelems) { + const IndexTy linear_index = flat_index_data[data_id]; + reduced_index_data[data_id] = (linear_index % row_size); + } + } + }); }); return map_back_ev; From 1a6401464544b1a12ce2e31c4da0e60c5b7ab586 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk <21087696+oleksandr-pavlyk@users.noreply.github.com> Date: Thu, 2 Jan 2025 10:40:26 -0600 Subject: [PATCH 5/8] Use get_global_linear_id instead of get_global_id and rely on implicit conversion --- dpctl/tensor/libtensor/include/kernels/sorting/sort_utils.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dpctl/tensor/libtensor/include/kernels/sorting/sort_utils.hpp b/dpctl/tensor/libtensor/include/kernels/sorting/sort_utils.hpp index a81f528852..d1f166f945 100644 --- a/dpctl/tensor/libtensor/include/kernels/sorting/sort_utils.hpp +++ b/dpctl/tensor/libtensor/include/kernels/sorting/sort_utils.hpp @@ -58,7 +58,7 @@ sycl::event iota_impl(sycl::queue &exec_q, sycl::event e = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(dependent_events); cgh.parallel_for(ndRange, [=](sycl::nd_item<1> it) { - const std::size_t gid = it.get_global_id(); + const std::size_t gid = it.get_global_linear_id(); const auto &sg = it.get_sub_group(); const std::uint32_t lane_id = sg.get_local_id()[0]; From 3bd333871408f3d42cc77b56dd0fa5570c4d9824 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk <21087696+oleksandr-pavlyk@users.noreply.github.com> Date: Thu, 2 Jan 2025 10:41:03 -0600 Subject: [PATCH 6/8] Counters in one-workgroup kernel to use uint16_t from uint32_t Counters can not exceed uint16_t max, because the kernel assumes that the number of elements to sort fits into uint16_t. The change reduces the kernel SLM footprint. Also, remove use of std::move, uint16_t->std::uint16_t, etc Replace size_t->std::size_t, uint32_t->std::uint32_t Use `if constexpr` in order-preservign-cast for better readability. 
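As an aside, a minimal standalone sketch of the `if constexpr` form of the
signed-integer order-preserving cast. The template parameter order and the
constraint are simplified here, and static_cast stands in for sycl::bit_cast
(equivalent for two's-complement signed-to-unsigned conversion) so the
snippet compiles without SYCL:

    #include <cassert>
    #include <cstdint>
    #include <limits>
    #include <type_traits>

    template <bool is_ascending,
              typename IntT,
              std::enable_if_t<std::is_signed_v<IntT>, int> = 0>
    std::make_unsigned_t<IntT> order_preserving_cast(IntT val)
    {
        using UIntT = std::make_unsigned_t<IntT>;
        const UIntT uint_val = static_cast<UIntT>(val);

        if constexpr (is_ascending) {
            // ascending_mask: 100..0 -- flipping the sign bit makes signed
            // values compare correctly as unsigned keys
            constexpr UIntT ascending_mask =
                (UIntT(1) << std::numeric_limits<IntT>::digits);
            return (uint_val ^ ascending_mask);
        }
        else {
            // descending_mask: 011..1 -- flipping the remaining bits
            // reverses the order
            constexpr UIntT descending_mask =
                (std::numeric_limits<UIntT>::max() >> 1);
            return (uint_val ^ descending_mask);
        }
    }

    int main()
    {
        // -5 < 7, so its key is smaller in ascending mode ...
        assert(order_preserving_cast<true>(std::int32_t{-5}) <
               order_preserving_cast<true>(std::int32_t{7}));
        // ... and larger in descending mode
        assert(order_preserving_cast<false>(std::int32_t{-5}) >
               order_preserving_cast<false>(std::int32_t{7}));
        return 0;
    }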
--- .../include/kernels/sorting/radix_sort.hpp | 159 ++++++++++-------- 1 file changed, 86 insertions(+), 73 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp b/dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp index 335b285bbf..dd20e647ff 100644 --- a/dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp +++ b/dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp @@ -70,13 +70,14 @@ template = 0> std::uint32_t ceil_log2(SizeT n) { + // if n > 2^b, n = q * 2^b + r for q > 0 and 0 <= r < 2^b + // floor_log2(q * 2^b + r) == floor_log2(q * 2^b) == q + floor_log2(n1) + // ceil_log2(n) == 1 + floor_log2(n-1) if (n <= 1) return std::uint32_t{1}; std::uint32_t exp{1}; --n; - // if n > 2^b, n = q * 2^b + r for q > 0 and 0 <= r < 2^b - // ceil_log2(q * 2^b + r) == ceil_log2(q * 2^b) == q + ceil_log2(n1) if (n >= (SizeT{1} << 32)) { n >>= 32; exp += 32; @@ -137,16 +138,20 @@ template order_preserving_cast(IntT val) { using UIntT = std::make_unsigned_t; - // ascending_mask: 100..0 - constexpr UIntT ascending_mask = - (UIntT(1) << std::numeric_limits::digits); - // descending_mask: 011..1 - constexpr UIntT descending_mask = (std::numeric_limits::max() >> 1); - - constexpr UIntT mask = (is_ascending) ? ascending_mask : descending_mask; const UIntT uint_val = sycl::bit_cast(val); - return (uint_val ^ mask); + if constexpr (is_ascending) { + // ascending_mask: 100..0 + constexpr UIntT ascending_mask = + (UIntT(1) << std::numeric_limits::digits); + return (uint_val ^ ascending_mask); + } + else { + // descending_mask: 011..1 + constexpr UIntT descending_mask = + (std::numeric_limits::max() >> 1); + return (uint_val ^ descending_mask); + } } template std::uint16_t order_preserving_cast(sycl::half val) @@ -1045,10 +1050,10 @@ template class radix_sort_one_wg_krn; template + std::uint16_t req_sub_group_size = (block_size < 4 ? 
32 : 16)> struct subgroup_radix_sort { private: @@ -1062,8 +1067,8 @@ struct subgroup_radix_sort public: template sycl::event operator()(sycl::queue &exec_q, - size_t n_iters, - size_t n_to_sort, + std::size_t n_iters, + std::size_t n_to_sort, ValueT *input_ptr, OutputT *output_ptr, ProjT proj_op, @@ -1160,8 +1165,8 @@ struct subgroup_radix_sort }; static_assert(wg_size <= 1024); - static constexpr uint16_t bin_count = (1 << radix); - static constexpr uint16_t counter_buf_sz = wg_size * bin_count + 1; + static constexpr std::uint16_t bin_count = (1 << radix); + static constexpr std::uint16_t counter_buf_sz = wg_size * bin_count + 1; enum class temp_allocations { @@ -1177,7 +1182,7 @@ struct subgroup_radix_sort assert(n <= (SizeT(1) << 16)); constexpr auto req_slm_size_counters = - counter_buf_sz * sizeof(uint32_t); + counter_buf_sz * sizeof(std::uint16_t); const auto &dev = exec_q.get_device(); @@ -1212,9 +1217,9 @@ struct subgroup_radix_sort typename SLM_value_tag, typename SLM_counter_tag> sycl::event operator()(sycl::queue &exec_q, - size_t n_iters, - size_t n_batch_size, - size_t n_values, + std::size_t n_iters, + std::size_t n_batch_size, + std::size_t n_values, InputT *input_arr, OutputT *output_arr, const ProjT &proj_op, @@ -1228,7 +1233,7 @@ struct subgroup_radix_sort assert(n_values <= static_cast(block_size) * static_cast(wg_size)); - uint16_t n = static_cast(n_values); + const std::uint16_t n = static_cast(n_values); static_assert(std::is_same_v, OutputT>); using ValueT = OutputT; @@ -1237,17 +1242,18 @@ struct subgroup_radix_sort TempBuf buf_val( n_batch_size, static_cast(block_size * wg_size)); - TempBuf buf_count( + TempBuf buf_count( n_batch_size, static_cast(counter_buf_sz)); sycl::range<1> lRange{wg_size}; sycl::event sort_ev; - std::vector deps = depends; + std::vector deps{depends}; - std::size_t n_batches = (n_iters + n_batch_size - 1) / n_batch_size; + const std::size_t n_batches = + (n_iters + n_batch_size - 1) / n_batch_size; - for (size_t batch_id = 0; batch_id < n_batches; ++batch_id) { + for (std::size_t batch_id = 0; batch_id < n_batches; ++batch_id) { const std::size_t block_start = batch_id * n_batch_size; @@ -1286,46 +1292,49 @@ struct subgroup_radix_sort const std::size_t iter_exchange_offset = iter_id * exchange_acc_iter_stride; - uint16_t wi = ndit.get_local_linear_id(); - uint16_t begin_bit = 0; + std::uint16_t wi = ndit.get_local_linear_id(); + std::uint16_t begin_bit = 0; - constexpr uint16_t end_bit = + constexpr std::uint16_t end_bit = number_of_bits_in_type(); -// copy from input array into values + // copy from input array into values #pragma unroll - for (uint16_t i = 0; i < block_size; ++i) { - const uint16_t id = wi * block_size + i; - if (id < n) - values[i] = std::move( - this_input_arr[iter_val_offset + id]); + for (std::uint16_t i = 0; i < block_size; ++i) { + const std::uint16_t id = wi * block_size + i; + values[i] = + (id < n) ? 
this_input_arr[iter_val_offset + id] + : ValueT{}; } while (true) { // indices for indirect access in the "re-order" // phase - uint16_t indices[block_size]; + std::uint16_t indices[block_size]; { // pointers to bucket's counters - uint32_t *counters[block_size]; + std::uint16_t *counters[block_size]; // counting phase auto pcounter = get_accessor_pointer(counter_acc) + (wi + iter_counter_offset); -// initialize counters + // initialize counters #pragma unroll - for (uint16_t i = 0; i < bin_count; ++i) - pcounter[i * wg_size] = std::uint32_t{0}; + for (std::uint16_t i = 0; i < bin_count; ++i) + pcounter[i * wg_size] = std::uint16_t{0}; sycl::group_barrier(ndit.get_group()); if (is_ascending) { #pragma unroll - for (uint16_t i = 0; i < block_size; ++i) { - const uint16_t id = wi * block_size + i; - constexpr uint16_t bin_mask = + for (std::uint16_t i = 0; i < block_size; + ++i) + { + const std::uint16_t id = + wi * block_size + i; + constexpr std::uint16_t bin_mask = bin_count - 1; // points to the padded element, i.e. id @@ -1334,7 +1343,7 @@ struct subgroup_radix_sort default_out_of_range_bin_id = bin_mask; - const uint16_t bin = + const std::uint16_t bin = (id < n) ? get_bucket_id( order_preserving_cast< @@ -1352,9 +1361,12 @@ struct subgroup_radix_sort } else { #pragma unroll - for (uint16_t i = 0; i < block_size; ++i) { - const uint16_t id = wi * block_size + i; - constexpr uint16_t bin_mask = + for (std::uint16_t i = 0; i < block_size; + ++i) + { + const std::uint16_t id = + wi * block_size + i; + constexpr std::uint16_t bin_mask = bin_count - 1; // points to the padded element, i.e. id @@ -1363,7 +1375,7 @@ struct subgroup_radix_sort default_out_of_range_bin_id = bin_mask; - const uint16_t bin = + const std::uint16_t bin = (id < n) ? get_bucket_id( order_preserving_cast< @@ -1386,13 +1398,14 @@ struct subgroup_radix_sort { // scan contiguous numbers - uint16_t bin_sum[bin_count]; + std::uint16_t bin_sum[bin_count]; const std::size_t counter_offset0 = iter_counter_offset + wi * bin_count; bin_sum[0] = counter_acc[counter_offset0]; #pragma unroll - for (uint16_t i = 1; i < bin_count; ++i) + for (std::uint16_t i = 1; i < bin_count; + ++i) bin_sum[i] = bin_sum[i - 1] + counter_acc[counter_offset0 + i]; @@ -1400,15 +1413,16 @@ struct subgroup_radix_sort sycl::group_barrier(ndit.get_group()); // exclusive scan local sum - uint16_t sum_scan = + std::uint16_t sum_scan = sycl::exclusive_scan_over_group( ndit.get_group(), bin_sum[bin_count - 1], - sycl::plus()); + sycl::plus()); // add to local sum, generate exclusive scan result #pragma unroll - for (uint16_t i = 0; i < bin_count; ++i) + for (std::uint16_t i = 0; i < bin_count; + ++i) counter_acc[counter_offset0 + i + 1] = sum_scan + bin_sum[i]; @@ -1420,11 +1434,13 @@ struct subgroup_radix_sort } #pragma unroll - for (uint16_t i = 0; i < block_size; ++i) { + for (std::uint16_t i = 0; i < block_size; ++i) { // a global index is a local offset plus a // global base index indices[i] += *counters[i]; } + + sycl::group_barrier(ndit.get_group()); } begin_bit += radix; @@ -1432,39 +1448,36 @@ struct subgroup_radix_sort // "re-order" phase sycl::group_barrier(ndit.get_group()); if (begin_bit >= end_bit) { -// the last iteration - writing out the result + // the last iteration - writing out the result #pragma unroll - for (uint16_t i = 0; i < block_size; ++i) { - const uint16_t r = indices[i]; + for (std::uint16_t i = 0; i < block_size; ++i) { + const std::uint16_t r = indices[i]; if (r < n) { - // move the values to source range and - // destroy the 
values this_output_arr[iter_val_offset + r] = - std::move(values[i]); + values[i]; } } return; } -// data exchange + // data exchange #pragma unroll - for (uint16_t i = 0; i < block_size; ++i) { - const uint16_t r = indices[i]; + for (std::uint16_t i = 0; i < block_size; ++i) { + const std::uint16_t r = indices[i]; if (r < n) exchange_acc[iter_exchange_offset + r] = - std::move(values[i]); + values[i]; } sycl::group_barrier(ndit.get_group()); #pragma unroll - for (uint16_t i = 0; i < block_size; ++i) { - const uint16_t id = wi * block_size + i; + for (std::uint16_t i = 0; i < block_size; ++i) { + const std::uint16_t id = wi * block_size + i; if (id < n) - values[i] = std::move( - exchange_acc[iter_exchange_offset + - id]); + values[i] = + exchange_acc[iter_exchange_offset + id]; } sycl::group_barrier(ndit.get_group()); @@ -1736,10 +1749,10 @@ radix_sort_axis1_contig_impl(sycl::queue &exec_q, const bool sort_ascending, // number of sub-arrays to sort (num. of rows in a // matrix when sorting over rows) - size_t iter_nelems, + std::size_t iter_nelems, // size of each array to sort (length of rows, // i.e. number of columns) - size_t sort_nelems, + std::size_t sort_nelems, const char *arg_cp, char *res_cp, ssize_t iter_arg_offset, @@ -1775,10 +1788,10 @@ radix_argsort_axis1_contig_impl(sycl::queue &exec_q, const bool sort_ascending, // number of sub-arrays to sort (num. of // rows in a matrix when sorting over rows) - size_t iter_nelems, + std::size_t iter_nelems, // size of each array to sort (length of // rows, i.e. number of columns) - size_t sort_nelems, + std::size_t sort_nelems, const char *arg_cp, char *res_cp, ssize_t iter_arg_offset, From 7c7d8f9970d2e42c8e4e23fb2ea8b79e59933cb8 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk <21087696+oleksandr-pavlyk@users.noreply.github.com> Date: Fri, 3 Jan 2025 19:18:45 -0600 Subject: [PATCH 7/8] Apply work-around for failing tests with CPU device and short sub-groups The team developing OpenCL:CPU device runtime and compiler was notified. See CMPLRLLVM-64592 Once fixed, the work-around should be removed. 
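For reference, a self-contained sketch of how the work-around condition is
obtained and applied: build the kernel bundle up front, query that kernel's
max_sub_group_size for the target device, and add extra group barriers only
on CPU devices with short sub-groups. The kernel name `demo_krn`, the
nd_range, and the kernel body are placeholders, not the radix-sort kernel
itself:

    #include <sycl/sycl.hpp>
    #include <cstdint>
    #include <iostream>

    class demo_krn;  // stand-in for the one-work-group radix-sort kernel name

    int main()
    {
        sycl::queue q;
        const auto &ctx = q.get_context();
        const auto &dev = q.get_device();

        const auto kernel_id = sycl::get_kernel_id<demo_krn>();
        auto kb = sycl::get_kernel_bundle<sycl::bundle_state::executable>(
            ctx, {dev}, {kernel_id});
        const auto &krn = kb.get_kernel(kernel_id);

        // sub-group size actually chosen for this kernel on this device
        const std::uint32_t krn_sg_size = krn.get_info<
            sycl::info::kernel_device_specific::max_sub_group_size>(dev);

        // extra synchronization is only needed on CPU with short sub-groups
        const bool work_around_needed =
            dev.has(sycl::aspect::cpu) && (krn_sg_size < 16);

        q.submit([&](sycl::handler &cgh) {
             // run the same binary that was queried above
             cgh.use_kernel_bundle(kb);
             cgh.parallel_for<demo_krn>(
                 sycl::nd_range<1>{sycl::range<1>{64}, sycl::range<1>{64}},
                 [=](sycl::nd_item<1> it) {
                     // ... per-lane counter updates would go here ...
                     if (work_around_needed) {
                         sycl::group_barrier(it.get_group());
                     }
                 });
         }).wait();

        std::cout << "sub-group size: " << krn_sg_size
                  << ", work-around needed: " << work_around_needed << "\n";
        return 0;
    }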
--- .../include/kernels/sorting/radix_sort.hpp | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp b/dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp index dd20e647ff..15f22b334e 100644 --- a/dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp +++ b/dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp @@ -1253,6 +1253,24 @@ struct subgroup_radix_sort const std::size_t n_batches = (n_iters + n_batch_size - 1) / n_batch_size; + const auto &kernel_id = sycl::get_kernel_id(); + + auto const &ctx = exec_q.get_context(); + auto const &dev = exec_q.get_device(); + auto kb = sycl::get_kernel_bundle( + ctx, {dev}, {kernel_id}); + + const auto &krn = kb.get_kernel(kernel_id); + + const std::uint32_t krn_sg_size = krn.template get_info< + sycl::info::kernel_device_specific::max_sub_group_size>(dev); + + // due to a bug in CPU device implementation, an additional + // synchronization is necessary for short sub-group sizes + const bool work_around_needed = + exec_q.get_device().has(sycl::aspect::cpu) && + (krn_sg_size < 16); + for (std::size_t batch_id = 0; batch_id < n_batches; ++batch_id) { const std::size_t block_start = batch_id * n_batch_size; @@ -1269,6 +1287,7 @@ struct subgroup_radix_sort sort_ev = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(deps); + cgh.use_kernel_bundle(kb); // allocation to use for value exchanges auto exchange_acc = buf_val.get_acc(cgh); @@ -1357,6 +1376,11 @@ struct subgroup_radix_sort counters[i] = &pcounter[bin * wg_size]; indices[i] = *counters[i]; *counters[i] = indices[i] + 1; + + if (work_around_needed) { + sycl::group_barrier( + ndit.get_group()); + } } } else { @@ -1389,6 +1413,11 @@ struct subgroup_radix_sort counters[i] = &pcounter[bin * wg_size]; indices[i] = *counters[i]; *counters[i] = indices[i] + 1; + + if (work_around_needed) { + sycl::group_barrier( + ndit.get_group()); + } } } From e1b754011e5904a0ac3c8bb99da24e6a60a5eb2a Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk <21087696+oleksandr-pavlyk@users.noreply.github.com> Date: Sat, 28 Dec 2024 11:59:35 -0600 Subject: [PATCH 8/8] Remove skipping of tests for i1/i2 dtypes since work-around was applied in C++. 
Add tests for 2d input arrays, for axis=0 and axis=1 Add a test for non-contiguous input, 0d input, validation 100% coverage of top_k function implementation achieved --- dpctl/tests/test_usm_ndarray_top_k.py | 162 +++++++++++++++++++++++++- 1 file changed, 156 insertions(+), 6 deletions(-) diff --git a/dpctl/tests/test_usm_ndarray_top_k.py b/dpctl/tests/test_usm_ndarray_top_k.py index 01a83ef293..a27853d8c8 100644 --- a/dpctl/tests/test_usm_ndarray_top_k.py +++ b/dpctl/tests/test_usm_ndarray_top_k.py @@ -55,7 +55,7 @@ def _expected_largest_inds(inp, n, shift, k): @pytest.mark.parametrize( "dtype", [ - pytest.param("i1", marks=pytest.mark.skip(reason="CPU bug")), + "i1", "u1", "i2", "u2", @@ -74,8 +74,6 @@ def _expected_largest_inds(inp, n, shift, k): def test_top_k_1d_largest(dtype, n): q = get_queue_or_skip() skip_if_dtype_not_supported(dtype, q) - if dtype == "i1": - pytest.skip() shift, k = 734, 5 o = dpt.ones(n, dtype=dtype) @@ -89,9 +87,9 @@ def test_top_k_1d_largest(dtype, n): assert s.values.shape == (k,) assert s.values.dtype == inp.dtype assert s.indices.shape == (k,) - assert dpt.all(s.indices == expected_inds) assert dpt.all(s.values == dpt.ones(k, dtype=dtype)), s.values assert dpt.all(s.values == inp[s.indices]), s.indices + assert dpt.all(s.indices == expected_inds), (s.indices, expected_inds) def _expected_smallest_inds(inp, n, shift, k): @@ -128,7 +126,7 @@ def _expected_smallest_inds(inp, n, shift, k): @pytest.mark.parametrize( "dtype", [ - pytest.param("i1", marks=pytest.mark.skip(reason="CPU bug")), + "i1", "u1", "i2", "u2", @@ -160,6 +158,158 @@ def test_top_k_1d_smallest(dtype, n): assert s.values.shape == (k,) assert s.values.dtype == inp.dtype assert s.indices.shape == (k,) - assert dpt.all(s.indices == expected_inds) assert dpt.all(s.values == dpt.zeros(k, dtype=dtype)), s.values assert dpt.all(s.values == inp[s.indices]), s.indices + assert dpt.all(s.indices == expected_inds), (s.indices, expected_inds) + + +@pytest.mark.parametrize( + "dtype", + [ + # skip short types to ensure that m*n can be represented + # in the type + "i4", + "u4", + "i8", + "u8", + "f2", + "f4", + "f8", + "c8", + "c16", + ], +) +@pytest.mark.parametrize("n", [37, 39, 61, 255, 257, 513, 1021, 8193]) +def test_top_k_2d_largest(dtype, n): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + m, k = 8, 3 + if dtype == "f2" and m * n > 2000: + pytest.skip( + "f2 can not distinguish between large integers used in this test" + ) + + x = dpt.reshape(dpt.arange(m * n, dtype=dtype), (m, n)) + + r = dpt.top_k(x, k, axis=1) + + assert r.values.shape == (m, k) + assert r.indices.shape == (m, k) + expected_inds = dpt.reshape(dpt.arange(n, dtype=r.indices.dtype), (1, n))[ + :, -k: + ] + assert expected_inds.shape == (1, k) + assert dpt.all( + dpt.sort(r.indices, axis=1) == dpt.sort(expected_inds, axis=1) + ), (r.indices, expected_inds) + expected_vals = x[:, -k:] + assert dpt.all( + dpt.sort(r.values, axis=1) == dpt.sort(expected_vals, axis=1) + ) + + +@pytest.mark.parametrize( + "dtype", + [ + # skip short types to ensure that m*n can be represented + # in the type + "i4", + "u4", + "i8", + "u8", + "f2", + "f4", + "f8", + "c8", + "c16", + ], +) +@pytest.mark.parametrize("n", [37, 39, 61, 255, 257, 513, 1021, 8193]) +def test_top_k_2d_smallest(dtype, n): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + m, k = 8, 3 + if dtype == "f2" and m * n > 2000: + pytest.skip( + "f2 can not distinguish between large integers used in this test" + ) + + x = dpt.reshape(dpt.arange(m 
* n, dtype=dtype), (m, n)) + + r = dpt.top_k(x, k, axis=1, mode="smallest") + + assert r.values.shape == (m, k) + assert r.indices.shape == (m, k) + expected_inds = dpt.reshape(dpt.arange(n, dtype=r.indices.dtype), (1, n))[ + :, :k + ] + assert dpt.all( + dpt.sort(r.indices, axis=1) == dpt.sort(expected_inds, axis=1) + ) + assert dpt.all(dpt.sort(r.values, axis=1) == dpt.sort(x[:, :k], axis=1)) + + +def test_top_k_0d(): + get_queue_or_skip() + + a = dpt.ones(tuple(), dtype="i4") + assert a.ndim == 0 + assert a.size == 1 + + r = dpt.top_k(a, 1) + assert r.values == a + assert r.indices == dpt.zeros_like(a, dtype=r.indices.dtype) + + +def test_top_k_noncontig(): + get_queue_or_skip() + + a = dpt.arange(256, dtype=dpt.int32)[::2] + r = dpt.top_k(a, 3) + + assert dpt.all(dpt.sort(r.values) == dpt.asarray([250, 252, 254])), r.values + assert dpt.all( + dpt.sort(r.indices) == dpt.asarray([125, 126, 127]) + ), r.indices + + +def test_top_k_axis0(): + get_queue_or_skip() + + m, n, k = 128, 8, 3 + x = dpt.reshape(dpt.arange(m * n, dtype=dpt.int32), (m, n)) + + r = dpt.top_k(x, k, axis=0, mode="smallest") + assert r.values.shape == (k, n) + assert r.indices.shape == (k, n) + expected_inds = dpt.reshape(dpt.arange(m, dtype=r.indices.dtype), (m, 1))[ + :k, : + ] + assert dpt.all( + dpt.sort(r.indices, axis=0) == dpt.sort(expected_inds, axis=0) + ) + assert dpt.all(dpt.sort(r.values, axis=0) == dpt.sort(x[:k, :], axis=0)) + + +def test_top_k_validation(): + get_queue_or_skip() + x = dpt.ones(10, dtype=dpt.int64) + with pytest.raises(ValueError): + # k must be positive + dpt.top_k(x, -1) + with pytest.raises(TypeError): + # argument should be usm_ndarray + dpt.top_k(list(), 2) + x2 = dpt.reshape(x, (2, 5)) + with pytest.raises(ValueError): + # k must not exceed array dimension + # along specified axis + dpt.top_k(x2, 100, axis=1) + with pytest.raises(ValueError): + # for 0d arrays, k must be 1 + dpt.top_k(x[0], 2) + with pytest.raises(ValueError): + # mode must be "largest", or "smallest" + dpt.top_k(x, 2, mode="invalid")