Skip to content

Commit

Permalink
fix loop output init
Browse files Browse the repository at this point in the history
  • Loading branch information
evkotov committed Apr 2, 2024
1 parent 5fb6194 commit 70dce9c
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/frontends/tensorflow/src/tf_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -579,11 +579,11 @@ ov::OutputVector create_loop_for_tf_while(const std::string& while_node_name,
// set external outputs for Loop node
// do not get execution condition outside of the Loop node
for (size_t output_ind = 0; output_ind < body_condition_output_idx; ++output_ind) {
const auto loop_output = loop->get_iter_value(body_results[output_ind]);
auto parent_parameter = get_parent_parameter(body_results[output_ind]);
if (parent_parameter) {
loop_outputs.push_back(loop_input_nodes[parent_parameter->get_instance_id()]);
} else {
const auto loop_output = loop->get_iter_value(body_results[output_ind]);
loop_outputs.push_back(loop_output);
}
}
Expand Down
13 changes: 10 additions & 3 deletions src/frontends/tensorflow/tests/convert_tricky_models.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -854,7 +854,6 @@ TEST_F(FrontEndConversionWithReferenceTestsF, UnitializedVariableV2AsInput) {
}

TEST_F(FrontEndConversionWithReferenceTestsF, LoopWithInvariant) {
std::shared_ptr<Model> real_loop_model;
{ model = convert_model("loop_with_invariant/loop_with_invariant.pb"); }
{
ParameterVector model_inputs;
Expand Down Expand Up @@ -971,17 +970,21 @@ TEST_F(FrontEndConversionWithReferenceTestsF, LoopWithInvariant) {
loop->set_function(loop_body);
loop->set_special_body_ports(ov::op::v5::Loop::SpecialBodyPorts{-1, 6});

OutputVector merged_outputs;
{
auto input_const = make_shared<v0::Constant>(element::i32, Shape{}, 0);
loop->set_merged_input(body_model_inputs[0], input_const, body_model_outputs[0]);
merged_outputs.emplace_back(body_model_outputs[0]);
}
{
auto input_const = make_shared<v0::Constant>(element::i32, Shape{}, 1);
loop->set_merged_input(body_model_inputs[2], input_const, body_model_outputs[2]);
merged_outputs.emplace_back(body_model_outputs[2]);
}
{
auto input_const = make_shared<v0::Constant>(element::i32, Shape{}, 1);
loop->set_merged_input(body_model_inputs[3], input_const, body_model_outputs[3]);
merged_outputs.emplace_back(body_model_outputs[3]);
}

{
Expand All @@ -1000,13 +1003,17 @@ TEST_F(FrontEndConversionWithReferenceTestsF, LoopWithInvariant) {
}

for (const auto& loop_body_output : body_model_outputs) {
if (loop_body_output == while_identity)
const auto it =
std::find_if(merged_outputs.begin(), merged_outputs.end(), [loop_body_output](Output<Node>& output) {
return output == loop_body_output;
});
if (it == merged_outputs.end())
continue;
loop->get_iter_value(loop_body_output);
}

auto const_multiply = make_shared<v0::Constant>(element::i32, Shape{}, Shape{1});
auto multiply = make_shared<v1::Multiply>(loop->output(3), const_multiply);
auto multiply = make_shared<v1::Multiply>(loop->output(2), const_multiply);
auto result = make_shared<v0::Result>(multiply);

model_ref = make_shared<Model>(OutputVector{result}, model_inputs);
Expand Down

0 comments on commit 70dce9c

Please sign in to comment.