Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
evkotov committed Apr 1, 2024
1 parent 9adfd6c commit 126d6fe
Show file tree
Hide file tree
Showing 2 changed files with 206 additions and 1 deletion.
164 changes: 163 additions & 1 deletion src/frontends/tensorflow/tests/convert_tricky_models.cpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "common_test_utils/test_common.hpp"
#include "conversion_with_reference.hpp"
#include "gtest/gtest.h"
#include "openvino//op/less_eq.hpp"
#include "openvino/frontend/exception.hpp"
#include "openvino/frontend/extension.hpp"
#include "openvino/frontend/manager.hpp"
Expand Down Expand Up @@ -852,3 +852,165 @@ TEST_F(FrontEndConversionWithReferenceTestsF, UnitializedVariableV2AsInput) {
model_ref = make_shared<Model>(OutputVector{mul}, ParameterVector{x, var});
}
}

TEST_F(FrontEndConversionWithReferenceTestsF, LoopWithInvariant) {
std::shared_ptr<Model> real_loop_model;
{ model = convert_model("loop_with_invariant/loop_with_invariant.pb"); }
{
ParameterVector model_inputs;
std::shared_ptr<ov::Node> loop_exec_condition;
{
auto input = make_shared<v0::Parameter>(element::i32, Shape{1});
auto less_eq_const = make_shared<v0::Constant>(element::i32, Shape{}, Shape{1});
auto less_eq = make_shared<v1::LessEqual>(less_eq_const, input);
auto squeeze_const = make_shared<v0::Constant>(element::i32, Shape{0});
auto squeeze = make_shared<v0::Squeeze>(less_eq, squeeze_const);
loop_exec_condition = squeeze;

model_inputs.emplace_back(input);
}

ParameterVector body_model_inputs;
OutputVector body_model_outputs;
std::shared_ptr<ov::Model> loop_body;
std::shared_ptr<v0::Result> while_identity;
{
std::shared_ptr<v0::Parameter> while_while_loop_counter;
std::shared_ptr<v0::Result> while_identity_5;
{
while_while_loop_counter = make_shared<v0::Parameter>(element::i32, Shape{});
auto input_const = make_shared<v0::Constant>(element::i32, Shape{}, Shape{1});
auto add = make_shared<v1::Add>(while_while_loop_counter, input_const);
while_identity_5 = make_shared<v0::Result>(add);

while_while_loop_counter->set_friendly_name("while_while_loop_counter");
input_const->set_friendly_name("while/add_1/y");
add->set_friendly_name("while/add_1");
while_identity_5->set_friendly_name("while_identity_5");
}
std::shared_ptr<v0::Parameter> while_while_maximum_iterations;
std::shared_ptr<v0::Result> while_identity_1;
{
while_while_maximum_iterations = make_shared<v0::Parameter>(element::i32, Shape{});
while_identity_1 = make_shared<v0::Result>(while_while_maximum_iterations);

while_while_maximum_iterations->set_friendly_name("while_while_maximum_iterations");
while_identity_1->set_friendly_name("while_identity_1");
}
std::shared_ptr<v0::Parameter> while_placeholder, while_placeholder_1;
std::shared_ptr<v0::Result> while_identity_3;
while_placeholder = make_shared<v0::Parameter>(element::i32, Shape{});
{
while_placeholder_1 = make_shared<v0::Parameter>(element::i32, Shape{});
auto multiply = make_shared<v1::Multiply>(while_placeholder_1, while_placeholder);
while_identity_3 = make_shared<v0::Result>(multiply);

while_placeholder->set_friendly_name("while_placeholder");
while_placeholder_1->set_friendly_name("while_placeholder_1");
multiply->set_friendly_name("while/mul");
while_identity_3->set_friendly_name("while_identity_3");
}
std::shared_ptr<v0::Parameter> while_placeholder_2;
std::shared_ptr<v0::Result> while_identity_4;
{
while_placeholder_2 = make_shared<v0::Parameter>(element::i32, Shape{});
while_identity_4 = make_shared<v0::Result>(while_placeholder_2);

while_placeholder_2->set_friendly_name("while_placeholder_2");
while_identity_4->set_friendly_name("while_identity_4");
}
std::shared_ptr<v0::Parameter> while_n_0;
std::shared_ptr<v0::Result> while_n;
{
while_n_0 = make_shared<v0::Parameter>(element::i32, Shape{1});
while_n = make_shared<v0::Result>(while_n_0);

while_n_0->set_friendly_name("while_n_0");
while_n->set_friendly_name("while_n");
}
std::shared_ptr<v0::Result> while_identity_2;
{
auto input_const = make_shared<v0::Constant>(element::i32, Shape{}, Shape{1});
auto add = make_shared<v1::Add>(while_placeholder, input_const);

while_identity_2 = make_shared<v0::Result>(add);

auto less_eq = make_shared<v1::LessEqual>(add, while_n_0);
auto squeeze_const = make_shared<v0::Constant>(element::i32, Shape{0});
auto squeeze = make_shared<v0::Squeeze>(less_eq, squeeze_const);
while_identity = make_shared<v0::Result>(squeeze);

input_const->set_friendly_name("while/add/y");
add->set_friendly_name("while/add");
while_identity_2->set_friendly_name("while_identity_2");
less_eq->set_friendly_name("while/LessEqual_1");
squeeze_const->set_friendly_name("Constant_14_1");
squeeze->set_friendly_name("while/Squeeze_1");
while_identity->set_friendly_name("while_identity");
}

body_model_inputs = ParameterVector({while_while_loop_counter,
while_while_maximum_iterations,
while_placeholder,
while_placeholder_1,
while_placeholder_2,
while_n_0});
body_model_outputs = OutputVector({while_identity_5,
while_identity_1,
while_identity_2,
while_identity_3,
while_identity_4,
while_n,
while_identity});

loop_body = make_shared<Model>(body_model_outputs, body_model_inputs);
}
auto trip_count_input = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{}, -1);

auto loop = std::make_shared<ov::op::v5::Loop>(trip_count_input, loop_exec_condition);
loop->set_function(loop_body);
loop->set_special_body_ports(ov::op::v5::Loop::SpecialBodyPorts{-1, 6});

{
auto input_const = make_shared<v0::Constant>(element::i32, Shape{}, 0);
loop->set_merged_input(body_model_inputs[0], input_const, body_model_outputs[0]);
}
{
auto input_const = make_shared<v0::Constant>(element::i32, Shape{}, 1);
loop->set_merged_input(body_model_inputs[2], input_const, body_model_outputs[2]);
}
{
auto input_const = make_shared<v0::Constant>(element::i32, Shape{}, 1);
loop->set_merged_input(body_model_inputs[3], input_const, body_model_outputs[3]);
}

{
auto input_const = make_shared<v0::Constant>(element::i32, Shape{}, -1);
loop->set_invariant_input(body_model_inputs[1], input_const);
}
{
auto input_const = make_shared<v0::Constant>(element::i32, Shape{}, 1);
loop->set_invariant_input(body_model_inputs[4], input_const);
}

{
auto input = make_shared<v0::Parameter>(element::i32, Shape{1});
loop->set_invariant_input(body_model_inputs[5], input);
model_inputs.emplace_back(input);
}

for (const auto& loop_body_output : body_model_outputs) {
if (loop_body_output == while_identity)
continue;
loop->get_iter_value(loop_body_output);
}

auto const_multiply = make_shared<v0::Constant>(element::i32, Shape{}, Shape{1});
auto multiply = make_shared<v1::Multiply>(loop->output(3), const_multiply);
auto result = make_shared<v0::Result>(multiply);

model_ref = make_shared<Model>(OutputVector{result}, model_inputs);
}

comparator.disable(FunctionsComparator::CmpValues::ATTRIBUTES);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#!/usr/bin/python3

import sys
import tensorflow as tf
import os


@tf.function(input_signature=[tf.TensorSpec(shape=[1], dtype=tf.int32)])
def while_loop_with_invariant(n):
invariant_var_input = tf.constant(1, dtype=tf.int32)

def condition(counter, total_product, invariant_var_arg):
return counter <= n

def body(counter, total_product, invariant_var_arg):
return [
counter + 1,
total_product * counter,
invariant_var_arg
]

# Initial values
counter_initial = tf.constant(1, dtype=tf.int32)
total_product_initial = tf.constant(1, dtype=tf.int32)

# Execute the loop
_, final_prod, invariant_var_output = tf.while_loop(
condition,
body,
[counter_initial, total_product_initial, invariant_var_input],
shape_invariants=[
tf.TensorShape([1]),
tf.TensorShape([1]),
tf.TensorShape([1])
]
)

result = final_prod * invariant_var_output
return result


tf_net = while_loop_with_invariant.get_concrete_function(n=tf.TensorSpec(shape=[1], dtype=tf.int32)).graph
tf.io.write_graph(tf_net, os.path.join(sys.argv[1], 'loop_with_invariant'), 'loop_with_invariant.pb', as_text=False)

0 comments on commit 126d6fe

Please sign in to comment.