Skip to content

Commit

Permalink
code review fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
evkotov committed Apr 1, 2024
1 parent 126d6fe commit 5fb6194
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 8 deletions.
4 changes: 3 additions & 1 deletion src/frontends/tensorflow/src/op/while.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ OutputVector translate_while_op(const NodeContext& node) {
body_model,
"[TensorFlow Frontend] Internal error or incorrect input model. Cannot find body graph with name " + body_type);

return create_loop_for_tf_while(node.get_name(), body_model, cond_model, ov_inputs);
auto loop_outputs = create_loop_for_tf_while(node.get_name(), body_model, cond_model, ov_inputs);
set_node_name(node.get_name(), loop_outputs[0].get_node_shared_ptr());
return loop_outputs;
}

} // namespace op
Expand Down
8 changes: 1 addition & 7 deletions src/frontends/tensorflow/src/tf_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -467,13 +467,8 @@ bool propagate_conditional_flow(const OutputVector& ov_inputs,
namespace {
shared_ptr<op::v0::Parameter> get_parent_parameter(const shared_ptr<op::v0::Result>& node) {
const auto input_values = node->input_values();
if (input_values.empty())
return {};
return as_type_ptr<v0::Parameter>(input_values[0].get_node_shared_ptr());
}
bool has_parent_parameter(const shared_ptr<op::v0::Result>& node) {
return get_parent_parameter(node) != nullptr;
}
} // namespace

// create Loop operation corresponding to TensorFlow While operation
Expand Down Expand Up @@ -558,7 +553,7 @@ ov::OutputVector create_loop_for_tf_while(const std::string& while_node_name,
std::vector<size_t> invariant_input_indexes;
// body_results may contain less nodes than body_params that means back edge exists not for all body_params
for (size_t input_ind = 0; input_ind < body_condition_output_idx; ++input_ind) {
if (has_parent_parameter(body_results[input_ind])) {
if (get_parent_parameter(body_results[input_ind])) {
invariant_input_indexes.push_back(input_ind);
continue;
}
Expand Down Expand Up @@ -594,7 +589,6 @@ ov::OutputVector create_loop_for_tf_while(const std::string& while_node_name,
}

loop->validate_and_infer_types();
set_node_name(while_node_name, loop);
return loop_outputs;
}

Expand Down

0 comments on commit 5fb6194

Please sign in to comment.