Skip to content

Commit

Permalink
[GPU]: SearchSorted basic implementation. (#27356)
Browse files Browse the repository at this point in the history
Added a GPU reference implementation of the SearchSorted op, with unit and
functional tests. The kernel supports dynamic shapes.

### Details:
- Fixed a bug in the reference implementation that occurred when the sorted
input had exactly one element. Added tests covering that case.

### Tickets:
 - CVS-156238
  • Loading branch information
pkowalc1 authored Nov 26, 2024
1 parent 1b72ca5 commit c85e88b
Show file tree
Hide file tree
Showing 23 changed files with 790 additions and 12 deletions.
12 changes: 5 additions & 7 deletions src/core/reference/include/openvino/reference/search_sorted.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ void search_sorted(const T* sorted,
}

const size_t size = shape_size(values_shape);
const size_t sorted_inner_dim = sorted_shape.back();

auto func = [&](size_t i) {
auto it = values_transform.begin();
Expand All @@ -44,15 +45,12 @@ void search_sorted(const T* sorted,
Coordinate sorted_coord_begin = values_coord;
sorted_coord_begin.back() = 0;

Coordinate sorted_coord_last = values_coord;
sorted_coord_last.back() = sorted_shape.back();

const auto sorted_index_begin = coordinate_index(sorted_coord_begin, sorted_shape);
const auto sorted_index_last = coordinate_index(sorted_coord_last, sorted_shape);

const T* idx_ptr = compare_func(sorted + sorted_index_begin, sorted + sorted_index_last, value);
const T* sorted_begin_ptr = sorted + sorted_index_begin;
const T* sorted_end_ptr = sorted_begin_ptr + sorted_inner_dim;
const T* idx_ptr = compare_func(sorted_begin_ptr, sorted_end_ptr, value);

const ptrdiff_t sorted_index = (idx_ptr - sorted) - sorted_index_begin;
const ptrdiff_t sorted_index = idx_ptr - sorted_begin_ptr;

out[values_index] = static_cast<TOut>(sorted_index);
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ REGISTER_FACTORY(v13, BitwiseXor);
REGISTER_FACTORY(v15, ROIAlignRotated);
REGISTER_FACTORY(v15, BitwiseRightShift);
REGISTER_FACTORY(v15, BitwiseLeftShift);
REGISTER_FACTORY(v15, SearchSorted);

// --------------------------- Supported internal ops --------------------------- //
REGISTER_FACTORY(internal, NonMaxSuppressionIEInternal);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once
#include "primitive.hpp"

namespace cldnn {

/// @brief SearchSorted primitive: computes insertion indices of `values` into the
///        innermost dimension of the `sorted` input.
/// @details GPU counterpart of ov::op::v15::SearchSorted; see the operation
///          specification for the exact left/right-mode semantics.
struct search_sorted : public primitive_base<search_sorted> {
    CLDNN_DECLARE_PRIMITIVE(search_sorted)

    // Default constructor, required by the (de)serialization machinery.
    search_sorted() : primitive_base("", {}) {}

    /// @brief Constructs search_sorted primitive.
    /// @param id This primitive id.
    /// @param sorted Sorted input.
    /// @param values Values input.
    /// @param right_mode Enable/Disable right mode (check specification for details).
    search_sorted(const primitive_id& id, const input_info& sorted, const input_info& values, bool right_mode)
        : primitive_base(id, {sorted, values}),
          right_mode(right_mode) {}

    /// @brief Enable/Disable right mode (check specification for details).
    bool right_mode = false;

    size_t hash() const override {
        // Fold the only primitive-specific field into the base hash.
        return hash_combine(primitive::hash(), right_mode);
    }

    bool operator==(const primitive& rhs) const override {
        // Short-circuits: the downcast is evaluated only after the common
        // parameters (and therefore the dynamic type) have been matched.
        return compare_common_params(rhs) && right_mode == downcast<const search_sorted>(rhs).right_mode;
    }

    void save(BinaryOutputBuffer& ob) const override {
        primitive_base<search_sorted>::save(ob);
        ob << right_mode;
    }

    void load(BinaryInputBuffer& ib) override {
        primitive_base<search_sorted>::load(ib);
        ib >> right_mode;
    }
};
}  // namespace cldnn
1 change: 1 addition & 0 deletions src/plugins/intel_gpu/src/graph/impls/ocl/register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ void register_implementations() {
REGISTER_OCL(unique_gather);
REGISTER_OCL(scaled_dot_product_attention);
REGISTER_OCL(rope);
REGISTER_OCL(search_sorted);
}

} // namespace ocl
Expand Down
1 change: 1 addition & 0 deletions src/plugins/intel_gpu/src/graph/impls/ocl/register.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ REGISTER_OCL(unique_count);
REGISTER_OCL(unique_gather);
REGISTER_OCL(scaled_dot_product_attention);
REGISTER_OCL(rope);
REGISTER_OCL(search_sorted);

#undef REGISTER_OCL

Expand Down
107 changes: 107 additions & 0 deletions src/plugins/intel_gpu/src/graph/impls/ocl/search_sorted.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "primitive_base.hpp"
#include "search_sorted/search_sorted_kernel_base.h"
#include "search_sorted/search_sorted_kernel_selector.h"
#include "search_sorted_inst.h"

namespace cldnn {
namespace ocl {

// OCL implementation of the SearchSorted primitive: selects and dispatches the
// reference kernel via the kernel-selector, with dynamic-shape support.
struct search_sorted_impl : typed_primitive_impl_ocl<search_sorted> {
using parent = typed_primitive_impl_ocl<search_sorted>;
using parent::parent;
using kernel_selector_t = kernel_selector::search_sorted_kernel_selector;
using kernel_params_t = kernel_selector::search_sorted_params;

DECLARE_OBJECT_TYPE_SERIALIZATION(cldnn::ocl::search_sorted_impl)

std::unique_ptr<primitive_impl> clone() const override {
return make_unique<search_sorted_impl>(*this);
}

// Restores the impl from the model cache; for dynamic shapes the dispatch-data
// update callback is not serialized, so it is re-fetched from the kernel impl.
void load(BinaryInputBuffer& ib) override {
parent::load(ib);
if (is_dynamic()) {
auto& kernel_selector = kernel_selector_t::Instance();
auto kernel_impl = kernel_selector.GetImplementation(_kernel_data.kernelName);
kernel_impl->GetUpdateDispatchDataFunc(_kernel_data);
}
}

// Recomputes kernel dispatch parameters (e.g. global/local work sizes) for the
// current shapes of a dynamic-shape primitive.
void update_dispatch_data(const kernel_impl_params& impl_param) override {
// If model loaded from cache, params are not initialized, so we create a new object and reuse it in the future
if (_kernel_data.params == nullptr) {
_kernel_data.params = std::make_shared<kernel_params_t>(get_kernel_params(impl_param, true));
}

update_shapes(*_kernel_data.params, impl_param);
(_kernel_data.update_dispatch_data_func)(*_kernel_data.params, _kernel_data);
}

// Builds kernel-selector params from the primitive description: both inputs
// (sorted, values) plus the right_mode flag.
static kernel_params_t get_kernel_params(const kernel_impl_params& impl_param, bool shape_agnostic = false) {
const auto& primitive = impl_param.typed_desc<search_sorted>();
auto params = get_default_params<kernel_selector::search_sorted_params>(impl_param, shape_agnostic);

// Manually add all inputs except first one, since get_default_params does not handle it.
for (size_t i = 1; i < impl_param.input_layouts.size(); ++i) {
params.inputs.push_back(convert_data_tensor(impl_param.get_input_layout(i)));
}

params.right_mode = primitive->right_mode;
return params;
}

// [NOTE]: Has to be added as a separate static function, since it is called via static dispatching in
// typed_primitive_impl_ocl::create().
// Pads every input/output shape to the plugin's canonical rank so the kernel can
// rely on a fixed layout.
static kernel_impl_params static_canonicalize_shapes(const kernel_impl_params& impl_params) {
auto updated_impl_params = canonicalize_fused_shapes(impl_params);

for (auto& input_layout : updated_impl_params.input_layouts) {
input_layout.set_partial_shape(extend_shape_to_rank_from_begin(input_layout.get_partial_shape()));
}

for (auto& output_layout : updated_impl_params.output_layouts) {
output_layout.set_partial_shape(extend_shape_to_rank_from_begin(output_layout.get_partial_shape()));
}

return updated_impl_params;
}

kernel_impl_params canonicalize_shapes(const kernel_impl_params& impl_params) const override {
return static_canonicalize_shapes(impl_params);
}
};

namespace detail {

// Registers the OCL SearchSorted implementation for every supported data type
// and format, for both static and dynamic shapes (shape_types::any).
attach_search_sorted_impl::attach_search_sorted_impl() {
// Element types accepted for the sorted/values inputs.
auto types = {
data_types::i8,
data_types::u8,
data_types::i16,
data_types::u16,
data_types::i32,
data_types::u32,
data_types::i64,
data_types::f16,
data_types::f32,
};

// 4D and 5D plain layouts.
auto formats = {format::bfyx, format::bfzyx};

implementation_map<search_sorted>::add(impl_types::ocl,
shape_types::any,
typed_primitive_impl_ocl<search_sorted>::create<search_sorted_impl>,
types,
formats);
}

} // namespace detail
} // namespace ocl
} // namespace cldnn

BIND_BINARY_BUFFER_WITH_TYPE(cldnn::ocl::search_sorted_impl)
BIND_BINARY_BUFFER_WITH_TYPE(cldnn::search_sorted)
Original file line number Diff line number Diff line change
Expand Up @@ -214,3 +214,4 @@ REGISTER_DEFAULT_IMPLS(unique_count, OCL_S, OCL_D);
REGISTER_DEFAULT_IMPLS(unique_gather, OCL_S, OCL_D);
REGISTER_DEFAULT_IMPLS(scaled_dot_product_attention, OCL_S, OCL_D);
REGISTER_DEFAULT_IMPLS(rope, OCL_S, OCL_D);
REGISTER_DEFAULT_IMPLS(search_sorted, OCL_S, OCL_D);
46 changes: 46 additions & 0 deletions src/plugins/intel_gpu/src/graph/include/search_sorted_inst.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once

#include <intel_gpu/primitives/search_sorted.hpp>

#include "primitive_inst.h"

namespace cldnn {

template <>
struct typed_program_node<search_sorted> : public typed_program_node_base<search_sorted> {
    using parent = typed_program_node_base<search_sorted>;

    // Take the primitive by const reference: the original by-value parameter copied
    // the shared_ptr (an atomic refcount bump) only to forward it to the base class
    // (clang-tidy: performance-unnecessary-value-param).
    typed_program_node(const std::shared_ptr<search_sorted>& prim, program& prog) : parent(prim, prog) {}

public:
    using parent::parent;

    /// @brief Returns the dependency at @p idx (0 = sorted input, 1 = values input).
    program_node& input(size_t idx = 0) const {
        return get_dependency(idx);
    }

    /// @brief No input data needs to be read on host for shape inference.
    std::vector<size_t> get_shape_infer_dependencies() const override {
        return {};
    }
};

using search_sorted_node = typed_program_node<search_sorted>;

// Runtime instance of the SearchSorted primitive; declares the shape-inference
// and debug-dump entry points implemented in src/graph/search_sorted.cpp.
template <>
class typed_primitive_inst<search_sorted> : public typed_primitive_inst_base<search_sorted> {
using parent = typed_primitive_inst_base<search_sorted>;
using parent::parent;

public:
typed_primitive_inst(network& network, search_sorted_node const& desc);
// Shape inference for (possibly dynamic) output layouts.
template <typename ShapeType>
static std::vector<layout> calc_output_layouts(search_sorted_node const& node,
kernel_impl_params const& impl_param);
// Static-shape convenience wrapper over calc_output_layouts.
static layout calc_output_layout(search_sorted_node const& node, kernel_impl_params const& impl_param);
// JSON-style description for logs/graph dumps.
static std::string to_string(search_sorted_node const& node);
};

using search_sorted_inst = typed_primitive_inst<search_sorted>;

} // namespace cldnn
59 changes: 59 additions & 0 deletions src/plugins/intel_gpu/src/graph/search_sorted.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <json_object.h>
#include <search_sorted_inst.h>

#include <sstream>

#include "openvino/core/enum_names.hpp"
#include "primitive_type_base.h"
#include "search_sorted_shape_inference.hpp"

namespace cldnn {
GPU_DEFINE_PRIMITIVE_TYPE_ID(search_sorted)

// SearchSorted instance carries no extra state beyond the base primitive instance.
search_sorted_inst::typed_primitive_inst(network& network, search_sorted_node const& node) : parent(network, node) {}

layout search_sorted_inst::calc_output_layout(search_sorted_node const& node, kernel_impl_params const& impl_param) {
    // Single-output primitive: delegate to the generic shape inference and
    // return its first (only) layout.
    return calc_output_layouts<ov::PartialShape>(node, impl_param).front();
}

// Infers the output layout(s): output shape follows the `values` input, and the
// element type defaults to i64 indices unless explicitly overridden.
template <typename ShapeType>
std::vector<layout> search_sorted_inst::calc_output_layouts(search_sorted_node const& node,
kernel_impl_params const& impl_param) {
auto primitive = impl_param.typed_desc<search_sorted>();

auto input0_layout = impl_param.get_input_layout(0);
auto input1_layout = impl_param.get_input_layout(1);

// Output holds indices; use i64 when no explicit output type was requested.
const data_types output_type = impl_param.desc->output_data_types[0].value_or(data_types::i64);

std::vector<ShapeType> input_shapes = {
input0_layout.get<ShapeType>(), // sorted shape
input1_layout.get<ShapeType>(), // values shape
};

std::vector<ShapeType> output_shapes;

// Reuse the core op's shape_infer so GPU and reference implementations agree.
ov::op::v15::SearchSorted op;
op.set_right_mode(primitive->right_mode);
output_shapes = shape_infer(&op, input_shapes);

// Keep the format of the values input, since the output is element-wise over it.
return {layout{output_shapes[0], output_type, input1_layout.format}};
}

// Builds a JSON-style textual description of the node for logs and graph dumps.
std::string search_sorted_inst::to_string(search_sorted_node const& node) {
    auto desc_json = node.desc_to_json();

    json_composite info;
    info.add("sorted id", node.input(0).id());
    info.add("values id", node.input(1).id());
    info.add("right_mode", node.get_primitive()->right_mode);
    desc_json->add("search_sorted info", info);

    std::stringstream ss;
    desc_json->dump(ss);
    return ss.str();
}

} // namespace cldnn
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "include/batch_headers/fetch_data.cl"

#if RIGHT_MODE == 0
#define CMP <=
#else
#define CMP <
#endif

// Binary search over sorted[sorted_begin_idx, sorted_end_idx).
// With RIGHT_MODE == 0 (CMP is '<=') it returns the first index whose element
// is >= search_val (lower bound); with RIGHT_MODE != 0 (CMP is '<') it returns
// the first index whose element is > search_val (upper bound). Returns
// sorted_end_idx when no such element exists.
OUTPUT_TYPE FUNC(binary_search_thread)(const INPUT0_TYPE search_val,
const __global INPUT0_TYPE* restrict sorted,
OUTPUT_TYPE sorted_begin_idx,
OUTPUT_TYPE sorted_end_idx) {
while(sorted_begin_idx != sorted_end_idx) {
const OUTPUT_TYPE half_offset = (sorted_end_idx-sorted_begin_idx)/2;
const OUTPUT_TYPE half_idx = sorted_begin_idx+half_offset;
const INPUT0_TYPE half_val = sorted[half_idx];
if ( search_val CMP half_val )
sorted_end_idx = half_idx; // half_idx is a valid candidate; discard everything after it.
else
sorted_begin_idx = half_idx + 1; // Answer lies strictly after half_idx.
}

return sorted_begin_idx;
}

#undef CMP

// One work item per element of `values`: each thread binary-searches its value
// in the corresponding innermost (X) row of `sorted` and writes the resulting
// row-relative index to `output`.
KERNEL(search_sorted_ref)(
OPTIONAL_SHAPE_INFO_ARG
const __global INPUT0_TYPE* restrict sorted,
const __global INPUT1_TYPE* restrict values,
__global OUTPUT_TYPE* restrict output)
{
// INPUT0_TYPE has to be equal to INPUT1_TYPE
const int this_thread_idx = get_global_id(0);
const INPUT0_TYPE search_val = values[this_thread_idx];

// Number of innermost rows in `sorted` (product of all dims except X).
const int SORTED_STRIDE = INPUT0_BATCH_NUM*INPUT0_FEATURE_NUM*INPUT0_SIZE_Y*INPUT0_SIZE_Z;

// NOTE: SORTED_STRIDE-1 handles here a special case when sorted is actually 1D
// tensor and values is ND tensor. In such case we effectively want sorted_offset
// to be 0.
const int sorted_offset = min(this_thread_idx/INPUT1_SIZE_X, SORTED_STRIDE-1);

// Start of this thread's row; the search result is relative to the row start.
OUTPUT_TYPE sorted_begin_idx = sorted_offset * INPUT0_SIZE_X;
const OUTPUT_TYPE idx = FUNC_CALL(binary_search_thread)(search_val,
sorted + sorted_begin_idx,
0,
INPUT0_SIZE_X);

output[this_thread_idx] = idx;
}
3 changes: 2 additions & 1 deletion src/plugins/intel_gpu/src/kernel_selector/common_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ enum class KernelType {
RMS,
SWIGLU,
ROPE,
DYNAMIC_QUANTIZE
DYNAMIC_QUANTIZE,
SEARCH_SORTED
};

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down
Loading

0 comments on commit c85e88b

Please sign in to comment.