diff --git a/src/core/reference/include/openvino/reference/search_sorted.hpp b/src/core/reference/include/openvino/reference/search_sorted.hpp index 7ea8ec1078a2a1..629509b28ef78d 100644 --- a/src/core/reference/include/openvino/reference/search_sorted.hpp +++ b/src/core/reference/include/openvino/reference/search_sorted.hpp @@ -32,6 +32,7 @@ void search_sorted(const T* sorted, } const size_t size = shape_size(values_shape); + const size_t sorted_inner_dim = sorted_shape.back(); auto func = [&](size_t i) { auto it = values_transform.begin(); @@ -44,15 +45,12 @@ void search_sorted(const T* sorted, Coordinate sorted_coord_begin = values_coord; sorted_coord_begin.back() = 0; - Coordinate sorted_coord_last = values_coord; - sorted_coord_last.back() = sorted_shape.back(); - const auto sorted_index_begin = coordinate_index(sorted_coord_begin, sorted_shape); - const auto sorted_index_last = coordinate_index(sorted_coord_last, sorted_shape); - - const T* idx_ptr = compare_func(sorted + sorted_index_begin, sorted + sorted_index_last, value); + const T* sorted_begin_ptr = sorted + sorted_index_begin; + const T* sorted_end_ptr = sorted_begin_ptr + sorted_inner_dim; + const T* idx_ptr = compare_func(sorted_begin_ptr, sorted_end_ptr, value); - const ptrdiff_t sorted_index = (idx_ptr - sorted) - sorted_index_begin; + const ptrdiff_t sorted_index = idx_ptr - sorted_begin_ptr; out[values_index] = static_cast(sorted_index); }; diff --git a/src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp b/src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp index ced915d25610e8..e234bc68de0750 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp @@ -272,6 +272,7 @@ REGISTER_FACTORY(v13, BitwiseXor); REGISTER_FACTORY(v15, ROIAlignRotated); REGISTER_FACTORY(v15, BitwiseRightShift); REGISTER_FACTORY(v15, BitwiseLeftShift); +REGISTER_FACTORY(v15, SearchSorted); // --------------------------- Supported internal ops --------------------------- // REGISTER_FACTORY(internal, NonMaxSuppressionIEInternal); diff --git a/src/plugins/intel_gpu/include/intel_gpu/primitives/search_sorted.hpp b/src/plugins/intel_gpu/include/intel_gpu/primitives/search_sorted.hpp new file mode 100644 index 00000000000000..4dfb5c87f8c58c --- /dev/null +++ b/src/plugins/intel_gpu/include/intel_gpu/primitives/search_sorted.hpp @@ -0,0 +1,54 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once +#include "primitive.hpp" + +namespace cldnn { + +/// @brief +/// @details +struct search_sorted : public primitive_base { + CLDNN_DECLARE_PRIMITIVE(search_sorted) + + search_sorted() : primitive_base("", {}) {} + + /// @brief Constructs search_sorted primitive. + /// @param id This primitive id. + /// @param sorted Sorted input. + /// @param values Values input. + /// @param right_mode Enable/Disable right mode(check specification for details).. + search_sorted(const primitive_id& id, const input_info& sorted, const input_info& values, bool right_mode) + : primitive_base(id, {sorted, values}), + right_mode(right_mode) {} + + /// @brief Enable/Disable right mode(check specification for details). + bool right_mode = false; + + size_t hash() const override { + size_t seed = primitive::hash(); + seed = hash_combine(seed, right_mode); + return seed; + } + + bool operator==(const primitive& rhs) const override { + if (!compare_common_params(rhs)) + return false; + + auto rhs_casted = downcast(rhs); + + return right_mode == rhs_casted.right_mode; + } + + void save(BinaryOutputBuffer& ob) const override { + primitive_base::save(ob); + ob << right_mode; + } + + void load(BinaryInputBuffer& ib) override { + primitive_base::load(ib); + ib >> right_mode; + } +}; +} // namespace cldnn diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl/register.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl/register.cpp index 2597e419e66a41..7f2fab7a6d1581 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl/register.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl/register.cpp @@ -88,6 +88,7 @@ void register_implementations() { REGISTER_OCL(unique_gather); REGISTER_OCL(scaled_dot_product_attention); REGISTER_OCL(rope); + REGISTER_OCL(search_sorted); } } // namespace ocl diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl/register.hpp b/src/plugins/intel_gpu/src/graph/impls/ocl/register.hpp index d4b08b5154ef4b..0a605945fcf6cc 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl/register.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl/register.hpp @@ -162,6 +162,7 @@ REGISTER_OCL(unique_count); REGISTER_OCL(unique_gather); REGISTER_OCL(scaled_dot_product_attention); REGISTER_OCL(rope); +REGISTER_OCL(search_sorted); #undef REGISTER_OCL diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl/search_sorted.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl/search_sorted.cpp new file mode 100644 index 00000000000000..4243d75b5c7367 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/ocl/search_sorted.cpp @@ -0,0 +1,107 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "primitive_base.hpp" +#include "search_sorted/search_sorted_kernel_base.h" +#include "search_sorted/search_sorted_kernel_selector.h" +#include "search_sorted_inst.h" + +namespace cldnn { +namespace ocl { + +struct search_sorted_impl : typed_primitive_impl_ocl { + using parent = typed_primitive_impl_ocl; + using parent::parent; + using kernel_selector_t = kernel_selector::search_sorted_kernel_selector; + using kernel_params_t = kernel_selector::search_sorted_params; + + DECLARE_OBJECT_TYPE_SERIALIZATION(cldnn::ocl::search_sorted_impl) + + std::unique_ptr clone() const override { + return make_unique(*this); + } + + void load(BinaryInputBuffer& ib) override { + parent::load(ib); + if (is_dynamic()) { + auto& kernel_selector = kernel_selector_t::Instance(); + auto kernel_impl = kernel_selector.GetImplementation(_kernel_data.kernelName); + kernel_impl->GetUpdateDispatchDataFunc(_kernel_data); + } + } + + void update_dispatch_data(const kernel_impl_params& impl_param) override { + // If model loaded from cache, params are not initialized, so we create a new object and reuse it in the future + if (_kernel_data.params == nullptr) { + _kernel_data.params = std::make_shared(get_kernel_params(impl_param, true)); + } + + update_shapes(*_kernel_data.params, impl_param); + (_kernel_data.update_dispatch_data_func)(*_kernel_data.params, _kernel_data); + } + + static kernel_params_t get_kernel_params(const kernel_impl_params& impl_param, bool shape_agnostic = false) { + const auto& primitive = impl_param.typed_desc(); + auto params = get_default_params(impl_param, shape_agnostic); + + // Manually add all inputs except first one, since get_default_params does not handle it. + for (size_t i = 1; i < impl_param.input_layouts.size(); ++i) { + params.inputs.push_back(convert_data_tensor(impl_param.get_input_layout(i))); + } + + params.right_mode = primitive->right_mode; + return params; + } + + // [NOTE]: Has to be added as a separete static function, since it is called via static dispatching in + // typed_primitive_impl_ocl::create().. + static kernel_impl_params static_canonicalize_shapes(const kernel_impl_params& impl_params) { + auto updated_impl_params = canonicalize_fused_shapes(impl_params); + + for (auto& input_layout : updated_impl_params.input_layouts) { + input_layout.set_partial_shape(extend_shape_to_rank_from_begin(input_layout.get_partial_shape())); + } + + for (auto& output_layout : updated_impl_params.output_layouts) { + output_layout.set_partial_shape(extend_shape_to_rank_from_begin(output_layout.get_partial_shape())); + } + + return updated_impl_params; + } + + kernel_impl_params canonicalize_shapes(const kernel_impl_params& impl_params) const override { + return static_canonicalize_shapes(impl_params); + } +}; + +namespace detail { + +attach_search_sorted_impl::attach_search_sorted_impl() { + auto types = { + data_types::i8, + data_types::u8, + data_types::i16, + data_types::u16, + data_types::i32, + data_types::u32, + data_types::i64, + data_types::f16, + data_types::f32, + }; + + auto formats = {format::bfyx, format::bfzyx}; + + implementation_map::add(impl_types::ocl, + shape_types::any, + typed_primitive_impl_ocl::create, + types, + formats); +} + +} // namespace detail +} // namespace ocl +} // namespace cldnn + +BIND_BINARY_BUFFER_WITH_TYPE(cldnn::ocl::search_sorted_impl) +BIND_BINARY_BUFFER_WITH_TYPE(cldnn::search_sorted) diff --git a/src/plugins/intel_gpu/src/graph/impls/registry/registry.hpp b/src/plugins/intel_gpu/src/graph/impls/registry/registry.hpp index a6bb8ad6eebcc2..77c4262a7513cc 100644 --- a/src/plugins/intel_gpu/src/graph/impls/registry/registry.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/registry/registry.hpp @@ -214,3 +214,4 @@ REGISTER_DEFAULT_IMPLS(unique_count, OCL_S, OCL_D); REGISTER_DEFAULT_IMPLS(unique_gather, OCL_S, OCL_D); REGISTER_DEFAULT_IMPLS(scaled_dot_product_attention, OCL_S, OCL_D); REGISTER_DEFAULT_IMPLS(rope, OCL_S, OCL_D); +REGISTER_DEFAULT_IMPLS(search_sorted, OCL_S, OCL_D); diff --git a/src/plugins/intel_gpu/src/graph/include/search_sorted_inst.h b/src/plugins/intel_gpu/src/graph/include/search_sorted_inst.h new file mode 100644 index 00000000000000..50ffdf8112e2ae --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/include/search_sorted_inst.h @@ -0,0 +1,46 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include + +#include "primitive_inst.h" + +namespace cldnn { + +template <> +struct typed_program_node : public typed_program_node_base { + using parent = typed_program_node_base; + typed_program_node(const std::shared_ptr prim, program& prog) : parent(prim, prog) {} + +public: + using parent::parent; + + program_node& input(size_t idx = 0) const { + return get_dependency(idx); + } + std::vector get_shape_infer_dependencies() const override { + return {}; + } +}; + +using search_sorted_node = typed_program_node; + +template <> +class typed_primitive_inst : public typed_primitive_inst_base { + using parent = typed_primitive_inst_base; + using parent::parent; + +public: + typed_primitive_inst(network& network, search_sorted_node const& desc); + template + static std::vector calc_output_layouts(search_sorted_node const& node, + kernel_impl_params const& impl_param); + static layout calc_output_layout(search_sorted_node const& node, kernel_impl_params const& impl_param); + static std::string to_string(search_sorted_node const& node); +}; + +using search_sorted_inst = typed_primitive_inst; + +} // namespace cldnn diff --git a/src/plugins/intel_gpu/src/graph/search_sorted.cpp b/src/plugins/intel_gpu/src/graph/search_sorted.cpp new file mode 100644 index 00000000000000..761b6751ace3b7 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/search_sorted.cpp @@ -0,0 +1,59 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include + +#include + +#include "openvino/core/enum_names.hpp" +#include "primitive_type_base.h" +#include "search_sorted_shape_inference.hpp" + +namespace cldnn { +GPU_DEFINE_PRIMITIVE_TYPE_ID(search_sorted) + +search_sorted_inst::typed_primitive_inst(network& network, search_sorted_node const& node) : parent(network, node) {} + +layout search_sorted_inst::calc_output_layout(search_sorted_node const& node, kernel_impl_params const& impl_param) { + return calc_output_layouts(node, impl_param)[0]; +} + +template +std::vector search_sorted_inst::calc_output_layouts(search_sorted_node const& node, + kernel_impl_params const& impl_param) { + auto primitive = impl_param.typed_desc(); + + auto input0_layout = impl_param.get_input_layout(0); + auto input1_layout = impl_param.get_input_layout(1); + + const data_types output_type = impl_param.desc->output_data_types[0].value_or(data_types::i64); + + std::vector input_shapes = { + input0_layout.get(), // sorted shape + input1_layout.get(), // values shape + }; + + std::vector output_shapes; + + ov::op::v15::SearchSorted op; + op.set_right_mode(primitive->right_mode); + output_shapes = shape_infer(&op, input_shapes); + + return {layout{output_shapes[0], output_type, input1_layout.format}}; +} + +std::string search_sorted_inst::to_string(search_sorted_node const& node) { + auto node_info = node.desc_to_json(); + json_composite search_sorted_info; + search_sorted_info.add("sorted id", node.input(0).id()); + search_sorted_info.add("values id", node.input(1).id()); + search_sorted_info.add("right_mode", node.get_primitive()->right_mode); + node_info->add("search_sorted info", search_sorted_info); + std::stringstream primitive_description; + node_info->dump(primitive_description); + return primitive_description.str(); +} + +} // namespace cldnn \ No newline at end of file diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/search_sorted_ref.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/search_sorted_ref.cl new file mode 100644 index 00000000000000..b9e26405688f12 --- /dev/null +++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/search_sorted_ref.cl @@ -0,0 +1,56 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "include/batch_headers/fetch_data.cl" + +#if RIGHT_MODE == 0 +#define CMP <= +#else +#define CMP < +#endif + +OUTPUT_TYPE FUNC(binary_search_thread)(const INPUT0_TYPE search_val, + const __global INPUT0_TYPE* restrict sorted, + OUTPUT_TYPE sorted_begin_idx, + OUTPUT_TYPE sorted_end_idx) { + while(sorted_begin_idx != sorted_end_idx) { + const OUTPUT_TYPE half_offset = (sorted_end_idx-sorted_begin_idx)/2; + const OUTPUT_TYPE half_idx = sorted_begin_idx+half_offset; + const INPUT0_TYPE half_val = sorted[half_idx]; + if ( search_val CMP half_val ) + sorted_end_idx = half_idx; + else + sorted_begin_idx = half_idx + 1; + } + + return sorted_begin_idx; +} + +#undef CMP + +KERNEL(search_sorted_ref)( + OPTIONAL_SHAPE_INFO_ARG + const __global INPUT0_TYPE* restrict sorted, + const __global INPUT1_TYPE* restrict values, + __global OUTPUT_TYPE* restrict output) +{ + // INPUT0_TYPE has to be egual to INPUT1_TYPE + const int this_thread_idx = get_global_id(0); + const INPUT0_TYPE search_val = values[this_thread_idx]; + + const int SORTED_STRIDE = INPUT0_BATCH_NUM*INPUT0_FEATURE_NUM*INPUT0_SIZE_Y*INPUT0_SIZE_Z; + + // NOTE: SORTED_STRIDE-1 handles here a special case when sorted is actually 1D + // tensor and values is ND tensor. In such case we effectively want sorted_offset + // to be 0. + const int sorted_offset = min(this_thread_idx/INPUT1_SIZE_X, SORTED_STRIDE-1); + + OUTPUT_TYPE sorted_begin_idx = sorted_offset * INPUT0_SIZE_X; + const OUTPUT_TYPE idx = FUNC_CALL(binary_search_thread)(search_val, + sorted + sorted_begin_idx, + 0, + INPUT0_SIZE_X); + + output[this_thread_idx] = idx; +} \ No newline at end of file diff --git a/src/plugins/intel_gpu/src/kernel_selector/common_types.h b/src/plugins/intel_gpu/src/kernel_selector/common_types.h index bc9cc9f5b8da07..37139dbaeeffd2 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/common_types.h +++ b/src/plugins/intel_gpu/src/kernel_selector/common_types.h @@ -101,7 +101,8 @@ enum class KernelType { RMS, SWIGLU, ROPE, - DYNAMIC_QUANTIZE + DYNAMIC_QUANTIZE, + SEARCH_SORTED }; //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/search_sorted/search_sorted_kernel_base.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/search_sorted/search_sorted_kernel_base.cpp new file mode 100644 index 00000000000000..ce4527a1f93aa7 --- /dev/null +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/search_sorted/search_sorted_kernel_base.cpp @@ -0,0 +1,72 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "search_sorted_kernel_base.h" + +#include + +#include "kernel_selector_utils.h" + +namespace kernel_selector { +JitConstants SearchSortedKernelBase::GetJitConstants(const search_sorted_params& params) const { + JitConstants jit = MakeBaseParamsJitConstants(params); + + jit.AddConstants({MakeJitConstant("RIGHT_MODE", params.right_mode)}); + + return jit; +} + +void SearchSortedKernelBase::GetUpdateDispatchDataFunc(KernelData& kd) const { + kd.update_dispatch_data_func = [](const Params& params, KernelData& kd) { + const auto& prim_params = static_cast(params); + auto dispatchData = SetDefault(prim_params); + OPENVINO_ASSERT(kd.kernels.size() == 1, "[GPU] Invalid kernels size for update dispatch data func"); + kd.kernels[0].params.workGroups.global = dispatchData.gws; + kd.kernels[0].params.workGroups.local = dispatchData.lws; + kd.kernels[0].skip_execution = KernelData::SkipKernelExecution(prim_params); + }; +} + +SearchSortedKernelBase::DispatchData SearchSortedKernelBase::SetDefault(const search_sorted_params& params) { + DispatchData dispatchData; + dispatchData.gws[0] = params.outputs[0].LogicalSize(); + dispatchData.gws[1] = 1; + dispatchData.gws[2] = 1; + dispatchData.lws = GetOptimalLocalWorkGroupSizes(dispatchData.gws, params.engineInfo); + + return dispatchData; +} + +KernelsData SearchSortedKernelBase::GetCommonKernelsData(const Params& params) const { + assert(params.GetType() == KernelType::SEARCH_SORTED); + + const auto& prim_params = static_cast(params); + + auto dispatchData = SetDefault(prim_params); + KernelData k_data = KernelData::Default(params); + + auto cldnn_jit = GetJitConstants(prim_params); + auto entry_point = GetEntryPoint(kernelName, prim_params.layerID, params); + auto jit = CreateJit(kernelName, cldnn_jit, entry_point); + + GetUpdateDispatchDataFunc(k_data); + + auto& kernel = k_data.kernels[0]; + FillCLKernelData(kernel, + dispatchData, + params.engineInfo, + kernelName, + jit, + entry_point, + "", + false, + false, + 2, + GetFusedPrimitiveInputsCount(params), + 1, + prim_params.is_shape_agnostic); + + return {k_data}; +} +} // namespace kernel_selector diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/search_sorted/search_sorted_kernel_base.h b/src/plugins/intel_gpu/src/kernel_selector/kernels/search_sorted/search_sorted_kernel_base.h new file mode 100644 index 00000000000000..734229b6645fd6 --- /dev/null +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/search_sorted/search_sorted_kernel_base.h @@ -0,0 +1,34 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "kernel_base_opencl.h" +#include "kernel_selector_params.h" + +namespace kernel_selector { +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// search_sorted +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +struct search_sorted_params : public base_params { + search_sorted_params() : base_params(KernelType::SEARCH_SORTED), right_mode(false) {} + bool right_mode; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// SearchSortedKernelBase +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +class SearchSortedKernelBase : public KernelBaseOpenCL { +public: + using KernelBaseOpenCL::KernelBaseOpenCL; + + using DispatchData = CommonDispatchData; + +protected: + JitConstants GetJitConstants(const search_sorted_params& params) const; + static DispatchData SetDefault(const search_sorted_params& params); + KernelsData GetCommonKernelsData(const Params& params) const; + void GetUpdateDispatchDataFunc(KernelData& kd) const override; +}; +} // namespace kernel_selector diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/search_sorted/search_sorted_kernel_ref.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/search_sorted/search_sorted_kernel_ref.cpp new file mode 100644 index 00000000000000..5bbd22f24ebfec --- /dev/null +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/search_sorted/search_sorted_kernel_ref.cpp @@ -0,0 +1,45 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "search_sorted_kernel_ref.h" + +namespace kernel_selector { +ParamsKey SearchSortedKernelRef::GetSupportedKey() const { + ParamsKey k; + + k.EnableInputDataType(Datatype::INT8); + k.EnableInputDataType(Datatype::UINT8); + k.EnableInputDataType(Datatype::INT16); + k.EnableInputDataType(Datatype::UINT16); + k.EnableInputDataType(Datatype::INT32); + k.EnableInputDataType(Datatype::UINT32); + k.EnableInputDataType(Datatype::INT64); + k.EnableInputDataType(Datatype::F32); + k.EnableInputDataType(Datatype::F16); + + k.EnableOutputDataType(Datatype::INT32); + k.EnableOutputDataType(Datatype::INT64); + + k.EnableInputLayout(DataLayout::bfyx); + k.EnableInputLayout(DataLayout::bfzyx); + + k.EnableOutputLayout(DataLayout::bfyx); + k.EnableOutputLayout(DataLayout::bfzyx); + + k.EnableTensorOffset(); + k.EnableTensorPitches(); + k.EnableBatching(); + k.EnableDifferentTypes(); + k.EnableDynamicShapesSupport(); + return k; +} + +KernelsData SearchSortedKernelRef::GetKernelsData(const Params& params) const { + return GetCommonKernelsData(params); +} + +KernelsPriority SearchSortedKernelRef::GetKernelsPriority(const Params& /*params*/) const { + return FORCE_PRIORITY_9; +} +} // namespace kernel_selector diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/search_sorted/search_sorted_kernel_ref.h b/src/plugins/intel_gpu/src/kernel_selector/kernels/search_sorted/search_sorted_kernel_ref.h new file mode 100644 index 00000000000000..bc7738013c4867 --- /dev/null +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/search_sorted/search_sorted_kernel_ref.h @@ -0,0 +1,18 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "search_sorted_kernel_base.h" + +namespace kernel_selector { +class SearchSortedKernelRef : public SearchSortedKernelBase { +public: + SearchSortedKernelRef() : SearchSortedKernelBase("search_sorted_ref") {} + + KernelsData GetKernelsData(const Params& params) const override; + KernelsPriority GetKernelsPriority(const Params& params) const override; + ParamsKey GetSupportedKey() const override; +}; +} // namespace kernel_selector diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/search_sorted/search_sorted_kernel_selector.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/search_sorted/search_sorted_kernel_selector.cpp new file mode 100644 index 00000000000000..b83c4d09fd56dd --- /dev/null +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/search_sorted/search_sorted_kernel_selector.cpp @@ -0,0 +1,14 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "search_sorted_kernel_selector.h" +#include "search_sorted_kernel_ref.h" + +namespace kernel_selector { +search_sorted_kernel_selector::search_sorted_kernel_selector() { Attach(); } + +KernelsData search_sorted_kernel_selector::GetBestKernels(const Params& params) const { + return GetNaiveBestKernel(params, KernelType::SEARCH_SORTED); +} +} // namespace kernel_selector diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/search_sorted/search_sorted_kernel_selector.h b/src/plugins/intel_gpu/src/kernel_selector/kernels/search_sorted/search_sorted_kernel_selector.h new file mode 100644 index 00000000000000..25f9a30fb0d895 --- /dev/null +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/search_sorted/search_sorted_kernel_selector.h @@ -0,0 +1,21 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "kernel_selector.h" + +namespace kernel_selector { +class search_sorted_kernel_selector : public kernel_selector_base { +public: + static search_sorted_kernel_selector& Instance() { + static search_sorted_kernel_selector instance; + return instance; + } + + search_sorted_kernel_selector(); + + KernelsData GetBestKernels(const Params& params) const override; +}; +} // namespace kernel_selector diff --git a/src/plugins/intel_gpu/src/plugin/ops/search_sorted.cpp b/src/plugins/intel_gpu/src/plugin/ops/search_sorted.cpp new file mode 100644 index 00000000000000..dbb4fecbd66ab5 --- /dev/null +++ b/src/plugins/intel_gpu/src/plugin/ops/search_sorted.cpp @@ -0,0 +1,25 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/op/search_sorted.hpp" + +#include "intel_gpu/plugin/common_utils.hpp" +#include "intel_gpu/plugin/program_builder.hpp" +#include "intel_gpu/primitives/search_sorted.hpp" + +namespace ov { +namespace intel_gpu { + +static void CreateSearchSortedOp(ProgramBuilder& p, const std::shared_ptr& op) { + validate_inputs_count(op, {2}); + auto inputs = p.GetInputInfo(op); + auto prim = cldnn::search_sorted(layer_type_name_ID(op), inputs[0], inputs[1], op->get_right_mode()); + prim.output_data_types = get_output_data_types(op, {{ov::element::i64, ov::element::i32}}); + p.add_primitive(*op, prim); +} + +REGISTER_FACTORY_IMPL(v15, SearchSorted); + +} // namespace intel_gpu +} // namespace ov diff --git a/src/plugins/intel_gpu/src/plugin/program_builder.cpp b/src/plugins/intel_gpu/src/plugin/program_builder.cpp index 899110872ba633..a2316270e1ef3a 100644 --- a/src/plugins/intel_gpu/src/plugin/program_builder.cpp +++ b/src/plugins/intel_gpu/src/plugin/program_builder.cpp @@ -8,6 +8,7 @@ #include "openvino/op/variadic_split.hpp" #include "openvino/op/lstm_cell.hpp" #include "openvino/op/loop.hpp" +#include "openvino/op/search_sorted.hpp" #include "intel_gpu/plugin/common_utils.hpp" #include "intel_gpu/plugin/program_builder.hpp" @@ -349,6 +350,12 @@ bool ProgramBuilder::requires_new_shape_infer(const std::shared_ptr& o return true; } + // HACK: SearchSorted has specific shape requirements. + // E.g. static input shapes: sorted:[8], values:[2,3,4] are prefectly fine, + // but sorted:[8,1,1,1], values:[2,3,4,1] is not valid. + if (ov::is_type(op)) + return true; + if (ov::is_type(op)) { const auto body_function = std::static_pointer_cast(op)->get_function(); if (body_function->is_dynamic()) diff --git a/src/plugins/intel_gpu/tests/functional/shared_tests_instances/single_layer_tests/search_sorted.cpp b/src/plugins/intel_gpu/tests/functional/shared_tests_instances/single_layer_tests/search_sorted.cpp new file mode 100644 index 00000000000000..0117463880a607 --- /dev/null +++ b/src/plugins/intel_gpu/tests/functional/shared_tests_instances/single_layer_tests/search_sorted.cpp @@ -0,0 +1,18 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "single_op_tests/search_sorted.hpp" + +namespace ov { +namespace test { + +INSTANTIATE_TEST_SUITE_P(smoke_SearchSortedTest, + SearchSortedLayerTest, + ::testing::Combine(::testing::ValuesIn(SearchSortedLayerTest::GenerateParams()), + testing::Values(ElementType::f32, ElementType::f16, ElementType::i64, ElementType::u32), + testing::Values(ov::test::utils::DEVICE_GPU)), + SearchSortedLayerTest::getTestCaseName); + +} // namespace test +} // namespace ov diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/search_sorted_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/search_sorted_gpu_test.cpp new file mode 100644 index 00000000000000..f9dfa0aeb0fc2b --- /dev/null +++ b/src/plugins/intel_gpu/tests/unit/test_cases/search_sorted_gpu_test.cpp @@ -0,0 +1,148 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include + +#include "test_utils.h" + +using namespace cldnn; +using namespace ::tests; + +namespace { + +constexpr float EPS = 2e-3f; + +namespace helpers { +// TODO: Move to common place. + +// Converts float vector to another type vector. +template +std::vector ConverFloatVector(const std::vector& vec) { + std::vector ret; + ret.reserve(vec.size()); + for (const auto& val : vec) { + ret.push_back(T(val)); + } + return ret; +} + +// Allocates tensoer with given shape and data. +template +memory::ptr AllocateTensor(ov::PartialShape shape, const std::vector& data) { + const layout lo = {shape, ov::element::from(), cldnn::format::bfyx}; + EXPECT_EQ(lo.get_linear_size(), data.size()); + memory::ptr tensor = get_test_engine().allocate_memory(lo); + set_values(tensor, data); + return tensor; +} +} // namespace helpers + +struct SearchSortedTestParams { + ov::PartialShape sortedShape; + ov::PartialShape valuesShape; + bool rightMode; + std::vector sortedData; + std::vector valuesData; + std::vector expectedOutput; + std::string testcaseName; +}; + +class search_sorted_test : public ::testing::TestWithParam { +public: + static std::string getTestCaseName(const testing::TestParamInfo& obj) { + auto param = obj.param; + std::ostringstream result; + result << "sortedShape=" << param.sortedShape; + result << "_valuesShape=" << param.valuesShape; + result << "_rightMode=" << param.rightMode; + result << "_" << param.testcaseName; + return result.str(); + } + + struct SearchSortedInferenceParams { + bool rightMode; + memory::ptr sorted; + memory::ptr values; + memory::ptr expectedOutput; + }; + + template + SearchSortedInferenceParams PrepareInferenceParams(const SearchSortedTestParams& testParam) { + using T = typename ov::element_type_traits::value_type; + SearchSortedInferenceParams ret; + + ret.rightMode = testParam.rightMode; + + ret.sorted = + helpers::AllocateTensor(testParam.sortedShape, helpers::ConverFloatVector(testParam.sortedData)); + ret.values = + helpers::AllocateTensor(testParam.valuesShape, helpers::ConverFloatVector(testParam.valuesData)); + ret.expectedOutput = helpers::AllocateTensor(testParam.valuesShape, testParam.expectedOutput); + + return ret; + } + + void Execute(const SearchSortedInferenceParams& params) { + // Prepare the network. + auto stream = get_test_stream_ptr(get_test_default_config(engine_)); + + topology topology; + topology.add(input_layout("sorted", params.sorted->get_layout())); + topology.add(input_layout("values", params.values->get_layout())); + topology.add(search_sorted("search_sorted", input_info("sorted"), input_info("values"), params.rightMode)); + + cldnn::network::ptr network = get_network(engine_, topology, get_test_default_config(engine_), stream, false); + + network->set_input_data("sorted", params.sorted); + network->set_input_data("values", params.values); + + // Run and check results. + auto outputs = network->execute(); + + auto output = outputs.at("search_sorted").get_memory(); + cldnn::mem_lock output_ptr(output, get_test_stream()); + cldnn::mem_lock wanted_output_ptr(params.expectedOutput, get_test_stream()); + + ASSERT_EQ(output->get_layout(), params.expectedOutput->get_layout()); + ASSERT_EQ(output_ptr.size(), wanted_output_ptr.size()); + for (size_t i = 0; i < output_ptr.size(); ++i) + ASSERT_TRUE(are_equal(wanted_output_ptr[i], output_ptr[i], EPS)); + } + +private: + engine& engine_ = get_test_engine(); +}; + +std::vector generateTestParams() { + std::vector params; +#define TEST_DATA(sorted_shape, values_shape, right_mode, sorted_data, values_data, expected_output_data, description) \ + params.push_back(SearchSortedTestParams{sorted_shape, \ + values_shape, \ + right_mode, \ + sorted_data, \ + values_data, \ + expected_output_data, \ + description}); + +#include "unit_test_utils/tests_data/search_sorted_data.h" +#undef TEST_DATA + return params; +} + +} // namespace + +#define SEARCH_SORTED_TEST_P(precision) \ + TEST_P(search_sorted_test, ref_comp_##precision) { \ + Execute(PrepareInferenceParams(GetParam())); \ + } + +SEARCH_SORTED_TEST_P(f16); +SEARCH_SORTED_TEST_P(u8); + +INSTANTIATE_TEST_SUITE_P(search_sorted_test_suit, + search_sorted_test, + testing::ValuesIn(generateTestParams()), + search_sorted_test::getTestCaseName); diff --git a/src/tests/functional/shared_test_classes/src/single_op/search_sorted.cpp b/src/tests/functional/shared_test_classes/src/single_op/search_sorted.cpp index a92d87d51f9a10..c7c10ad8767ff6 100644 --- a/src/tests/functional/shared_test_classes/src/single_op/search_sorted.cpp +++ b/src/tests/functional/shared_test_classes/src/single_op/search_sorted.cpp @@ -88,11 +88,30 @@ void SearchSortedLayerTest::SetUp() { const std::vector SearchSortedLayerTest::GenerateParams() { const std::vector params = { + SearchSortedSpecificParams{InputShape{PartialShape::dynamic(3), {{1, 18, 104}}}, + InputShape{PartialShape::dynamic(3), {{1, 18, 104}}}, + true}, + SearchSortedSpecificParams{InputShape{PartialShape::dynamic(4), {{1, 2, 3, 100}}}, + InputShape{PartialShape::dynamic(4), {{1, 2, 3, 10}}}, + true}, + SearchSortedSpecificParams{InputShape{PartialShape::dynamic(5), {{2, 1, 2, 3, 10}}}, + InputShape{PartialShape::dynamic(5), {{2, 1, 2, 3, 20}}}, + false}, + SearchSortedSpecificParams{InputShape{PartialShape::dynamic(1), {{1}}}, + InputShape{PartialShape::dynamic(5), {{2, 1, 2, 3, 20}}}, + false}, + SearchSortedSpecificParams{InputShape{PartialShape::dynamic(1), {{50}}}, + InputShape{{1, -1, 10}, {{1, 18, 10}}}, + false}, SearchSortedSpecificParams{InputShape{{}, {{1, 18, 104}}}, InputShape{{}, {{1, 18, 104}}}, true}, SearchSortedSpecificParams{InputShape{{}, {{1, 2, 3, 100}}}, InputShape{{}, {{1, 2, 3, 10}}}, true}, SearchSortedSpecificParams{InputShape{{}, {{2, 1, 2, 3, 10}}}, InputShape{{}, {{2, 1, 2, 3, 20}}}, false}, SearchSortedSpecificParams{InputShape{{}, {{1}}}, InputShape{{}, {{2, 1, 2, 3, 20}}}, false}, SearchSortedSpecificParams{InputShape{{}, {{50}}}, InputShape{{1, -1, 10}, {{1, 18, 10}}}, false}, + SearchSortedSpecificParams{InputShape{{2, -1, 50}, {{2, 3, 50}}}, + InputShape{{-1, -1, 10}, {{2, 3, 10}}}, + false}, + SearchSortedSpecificParams{InputShape{{2, -1, 50}, {{2, 3, 50}}}, InputShape{{-1, 3, 10}, {{2, 3, 10}}}, false}, }; return params; diff --git a/src/tests/test_utils/unit_test_utils/tests_data/search_sorted_data.h b/src/tests/test_utils/unit_test_utils/tests_data/search_sorted_data.h index ee355c2daee15e..43e680aa080686 100644 --- a/src/tests/test_utils/unit_test_utils/tests_data/search_sorted_data.h +++ b/src/tests/test_utils/unit_test_utils/tests_data/search_sorted_data.h @@ -13,6 +13,22 @@ // NOTE: expected output were generated using pyTorch.searchsorted implementation. +TEST_DATA(LIST(5), + LIST(2, 3), + false, + LIST(3, 3, 3, 3, 3), + LIST(3, 6, 9, 3, 6, 9), + LIST(0, 5, 5, 0, 5, 5), + "1d_tensor_0"); + +TEST_DATA(LIST(5), + LIST(2, 3), + true, + LIST(3, 3, 3, 3, 3), + LIST(3, 6, 9, 3, 6, 9), + LIST(5, 5, 5, 5, 5, 5), + "1d_tensor_0_right_mode"); + TEST_DATA(LIST(5), LIST(2, 3), false, @@ -53,6 +69,22 @@ TEST_DATA(LIST(5), LIST(0, 3, 5, 1, 3, 5, 1, 0, 0, 5, 5, 5), "1d_tensor_3_right_mode"); +TEST_DATA(LIST(1), + LIST(2, 2, 3), + false, + LIST(2), + LIST(0, 6, 20, 2, 6, 9, 1, 0, 0, 9, 10, 20), + LIST(0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1), + "1d_tensor_4"); + +TEST_DATA(LIST(1), + LIST(2, 2, 3), + true, + LIST(2), + LIST(0, 6, 20, 2, 6, 9, 1, 0, 0, 9, 10, 20), + LIST(0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1), + "1d_tensor_4_right_mode"); + TEST_DATA(LIST(2, 5), LIST(2, 3), false, @@ -72,15 +104,15 @@ TEST_DATA(LIST(2, 5), TEST_DATA(LIST(2, 2, 5), LIST(2, 2, 3), false, - LIST(1, 3, 5, 7, 9, 0, 2, 4, 6, 8, -20, 5, 10, 23, 41, 100, 125, 130, 132, 139), + LIST(1, 3, 5, 7, 9, 0, 2, 4, 6, 8, 0, 5, 10, 23, 41, 100, 125, 130, 132, 139), LIST(0, 6, 20, 1, 6, 9, 1, 0, 0, 9, 10, 20), - LIST(0, 3, 5, 1, 3, 5, 1, 1, 1, 0, 0, 0), + LIST(0, 3, 5, 1, 3, 5, 1, 0, 0, 0, 0, 0), "nd_tensor_2"); TEST_DATA(LIST(2, 2, 5), LIST(2, 2, 3), true, - LIST(1, 3, 5, 7, 9, 0, 2, 4, 6, 8, -20, 5, 10, 23, 41, 100, 125, 130, 132, 139), + LIST(1, 3, 5, 7, 9, 0, 2, 4, 6, 8, 0, 5, 10, 23, 41, 100, 125, 130, 132, 139), LIST(0, 6, 20, 1, 6, 9, 1, 0, 0, 9, 10, 20), LIST(0, 3, 5, 1, 4, 5, 1, 1, 1, 0, 0, 0), - "nd_tensor_2"); \ No newline at end of file + "nd_tensor_2_right_mode"); \ No newline at end of file