Skip to content

Commit

Permalink
Get 'depth' value from then second input ('Select' node) of OneHot no…
Browse files Browse the repository at this point in the history
…de during inference

Signed-off-by: yuan.xiong <[email protected]>
  • Loading branch information
yuanxion committed Nov 21, 2024
1 parent f841f96 commit f0bc8f1
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 19 deletions.
15 changes: 15 additions & 0 deletions src/plugins/intel_gpu/include/intel_gpu/primitives/one_hot.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,21 @@ struct one_hot : public primitive_base<one_hot> {
, on_value(on_value)
, off_value(off_value) {}

/// @brief onehot with depth from Select node
one_hot(const primitive_id& id,
const input_info& input,
const input_info& input_depth,
const tensor& shape,
const data_types output_dt,
const int64_t& one_hot_axis,
const float& on_value = 1.0f,
const float& off_value = 0.0f)
: primitive_base(id, {input, input_depth})
, shape(shape)
, one_hot_axis(one_hot_axis)
, on_value(on_value)
, off_value(off_value) {}

/// @brief Constructs one-hot primitive layer.
/// @param id An identifier of new primitive.
/// @param input An identifier of primitive which is an input for newly created one-hot primitive.
Expand Down
2 changes: 1 addition & 1 deletion src/plugins/intel_gpu/src/graph/include/one_hot_inst.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ struct typed_program_node<one_hot> : typed_program_node_base<one_hot> {
support_padding_all(true);
}
program_node& input() const { return get_dependency(0); }
std::vector<size_t> get_shape_infer_dependencies() const override { return {}; }
std::vector<size_t> get_shape_infer_dependencies() const override { return {1}; }
};

using one_hot_node = typed_program_node<one_hot>;
Expand Down
13 changes: 13 additions & 0 deletions src/plugins/intel_gpu/src/graph/one_hot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,19 @@ std::vector<layout> one_hot_inst::calc_output_layouts(const one_hot_node& /*node
std::unordered_map<size_t, ov::Tensor> const_data = {
{1, depth_tensor}
};

auto& memory_deps = impl_param.memory_deps;
if (memory_deps.count(1) > 0) {
auto depth_mem = memory_deps.at(1);

cldnn::mem_lock<uint8_t, mem_lock_type::read> depth_lock(depth_mem, impl_param.get_stream());
auto depth_ptr = depth_lock.data();

// update depth_tensor if depth value comes from memory_deps instead of Constant node
auto depth_tensor = make_tensor(depth_mem->get_layout(), depth_ptr);
const_data[1] = depth_tensor;
}

std::vector<ShapeType> output_shapes =
ov::op::v1::shape_infer(&op, input_shapes, ov::make_tensor_accessor(const_data));
return {{output_shapes[0], dt, format::get_default_format(output_shapes[0].size())}};
Expand Down
4 changes: 2 additions & 2 deletions src/plugins/intel_gpu/src/graph/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ void program::prepare_nodes(topology const& topology) {
}
}

// add node's dependecies from its primitive dependencies
// add node's dependencies from its primitive dependencies
void program::add_node_dependencies(program_node* node) {
auto deps = node->get_primitive()->dependencies();
// add pointers to node's dependencies
Expand All @@ -453,7 +453,7 @@ void program::add_node_dependencies(program_node* node) {
}

/* helper method for program constructor from list of nodes which
copies src_node dependecies to the destination node dest_node dependencies.
copies src_node dependencies to the destination node dest_node dependencies.
But only to those which appaer in this program implementation nodes_map */
void program::copy_node_dependencies(program_node* dest_node, program_node* src_node) {
if (dest_node->get_primitive()->id != src_node->get_primitive()->id) {
Expand Down
40 changes: 24 additions & 16 deletions src/plugins/intel_gpu/src/plugin/ops/one_hot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#include "transformations/utils/utils.hpp"

#include "openvino/op/one_hot.hpp"

#include "intel_gpu/primitives/one_hot.hpp"

namespace ov {
Expand Down Expand Up @@ -49,24 +48,33 @@ static void CreateOneHotOp(ProgramBuilder& p, const std::shared_ptr<ov::op::v1::
}
}

int64_t depth = 0;
if (depth_value_node) {
depth = depth_value_node->cast_vector<int64_t>()[0];
}

auto out_pshape = op->get_output_partial_shape(0);
cldnn::tensor out_tensor = out_pshape.is_static() ? tensor_from_dims(out_pshape.to_shape()) : cldnn::tensor{};

auto oneHotPrim = cldnn::one_hot(layerName,
inputs[0],
out_tensor,
cldnn::element_type_to_data_type(op->get_output_element_type(0)),
axis,
depth,
on_value,
off_value);

p.add_primitive(*op, oneHotPrim);
if (depth_value_node) {
int64_t depth = depth_value_node->cast_vector<int64_t>()[0];
auto oneHotPrim = cldnn::one_hot(layerName,
inputs[0],
out_tensor,
cldnn::element_type_to_data_type(op->get_output_element_type(0)),
axis,
depth,
on_value,
off_value);

p.add_primitive(*op, oneHotPrim);
} else {
auto oneHotPrim = cldnn::one_hot(layerName,
inputs[0],
inputs[1],
out_tensor,
cldnn::element_type_to_data_type(op->get_output_element_type(0)),
axis,
on_value,
off_value);

p.add_primitive(*op, oneHotPrim);
}
}

REGISTER_FACTORY_IMPL(v1, OneHot);
Expand Down

0 comments on commit f0bc8f1

Please sign in to comment.