Skip to content

Commit

Permalink
Use empty const_data map to get dynamic output_shapes when no depth v…
Browse files Browse the repository at this point in the history
…alue

Signed-off-by: yuan.xiong <[email protected]>
  • Loading branch information
yuanxion committed Nov 21, 2024
1 parent f0bc8f1 commit 0ab78a2
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ std::vector<TRShape> shape_infer(const OneHot* op,
auto depth_as_shape =
get_input_const_data_as_shape<TRShape>(op, 1, ta, util::GetNotNegative<typename DimType::value_type>(op));

if (depth_as_shape && depth_as_shape->size() == 1 && (*depth_as_shape)[0].get_length() > 0) {
if (depth_as_shape && depth_as_shape->size() == 1) {
result_shape.insert(result_shape.begin() + axis, (*depth_as_shape)[0]);
} else {
result_shape.insert(result_shape.begin() + axis, DimType());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ struct one_hot : public primitive_base<one_hot> {
const int64_t& one_hot_axis,
const float& on_value = 1.0f,
const float& off_value = 0.0f)
: primitive_base(id, {input, input_depth})
: primitive_base(id, {input, input_depth}, 1, {optional_data_type{output_dt}})
, shape(shape)
, one_hot_axis(one_hot_axis)
, on_value(on_value)
Expand Down
15 changes: 8 additions & 7 deletions src/plugins/intel_gpu/src/graph/one_hot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,15 @@ std::vector<layout> one_hot_inst::calc_output_layouts(const one_hot_node& /*node
};

int64_t depth = desc->depth;

auto depth_tensor = ov::Tensor(ov::element::i64, ov::Shape{1}, static_cast<void*>(&depth));
std::unordered_map<size_t, ov::Tensor> const_data = {
{1, depth_tensor}
};

auto& memory_deps = impl_param.memory_deps;
if (memory_deps.count(1) > 0) {

std::unordered_map<size_t, ov::Tensor> const_data = {};
if (depth != 0) {
auto depth_tensor = ov::Tensor(ov::element::i64, ov::Shape{1}, static_cast<void*>(&depth));
const_data = {
{1, depth_tensor}
};
} else if (memory_deps.count(1) > 0) {
auto depth_mem = memory_deps.at(1);

cldnn::mem_lock<uint8_t, mem_lock_type::read> depth_lock(depth_mem, impl_param.get_stream());
Expand Down

0 comments on commit 0ab78a2

Please sign in to comment.