Skip to content

Commit

Permalink
[fix] code style
Browse files Browse the repository at this point in the history
  • Loading branch information
cocoshe committed Nov 20, 2024
1 parent 9e7eb91 commit 04ddd03
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions src/frontends/pytorch/src/op/index_fill_.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,16 @@ OutputVector translate_index_fill_(const NodeContext& context) {
auto index_len = context.mark_node(std::make_shared<v8::Slice>(index_shape, const_0, const_1, const_1));

// [A, B, ..., T, ..., K] --> [A, B, ..., len(index), ..., K]
auto target_shape = std::make_shared<v12::ScatterElementsUpdate>(input_shape,
dim_vec,
index_len,
v0::Constant::create(element::i32, Shape{}, {0}));
auto target_shape = std::make_shared<v12::ScatterElementsUpdate>(input_shape,
dim_vec,
index_len,
v0::Constant::create(element::i32, Shape{}, {0}));

// broadcast && index fill
auto broadcasted_value = context.mark_node(std::make_shared<v1::Broadcast>(value_vec, target_shape, dim_vec));
auto broadcasted_index = context.mark_node(std::make_shared<v1::Broadcast>(index, target_shape, dim_vec));
auto result = context.mark_node(std::make_shared<v12::ScatterElementsUpdate>(input, broadcasted_index, broadcasted_value, dim));
auto result = context.mark_node(
std::make_shared<v12::ScatterElementsUpdate>(input, broadcasted_index, broadcasted_value, dim));

return {result};
};
Expand Down

0 comments on commit 04ddd03

Please sign in to comment.