From ab5766ac85b10264ddf69807aa5ef4d2621cde3e Mon Sep 17 00:00:00 2001 From: shubdas9902 Date: Tue, 10 Dec 2024 09:02:00 +0530 Subject: [PATCH 1/4] Updated the case.cpp file --- src/frontends/tensorflow/src/op/case.cpp | 72 ++++++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 src/frontends/tensorflow/src/op/case.cpp diff --git a/src/frontends/tensorflow/src/op/case.cpp b/src/frontends/tensorflow/src/op/case.cpp new file mode 100644 index 00000000000000..42d13887efc085 --- /dev/null +++ b/src/frontends/tensorflow/src/op/case.cpp @@ -0,0 +1,72 @@ +#include "openvino/frontend/tensorflow/node_context.hpp" +#include "openvino/frontend/tensorflow/translate_session.hpp" +#include "openvino/opsets/opset11.hpp" +#include "openvino/util/log.hpp" + +namespace ov { +namespace frontend { +namespace tensorflow { +namespace op { + +using namespace ov::opsets; + +OutputVector translate_case_op(const NodeContext& node) { + // Validate the operation type + auto op_type = node.get_op_type(); + TENSORFLOW_OP_VALIDATION(node, op_type == "Case", + "Internal error: incorrect usage of translate_case_op."); + + // Retrieve the number of branches and inputs + auto num_branches = node.get_attribute("branches"); + TENSORFLOW_OP_VALIDATION(node, num_branches > 0, + "[TensorFlow Frontend] Case operation must have at least one branch."); + + // The first input is the condition for selecting the branch + auto cond = node.get_input(0); + + // Create a list to store sub-graphs for the branches + std::vector> branch_graphs; + for (int i = 0; i < num_branches; ++i) { + std::string branch_name = "branch_" + std::to_string(i); + auto branch_body = node.get_attribute(branch_name); + + // Ensure that the branch model is correctly loaded + auto branch_model = node.get_translate_session()->get_body_ov_model(branch_body, node.get_inputs()); + TENSORFLOW_OP_VALIDATION(node, branch_model, + "[TensorFlow Frontend] Failed to retrieve body graph for branch: " + branch_name); + branch_graphs.push_back(branch_model); + } + + // Create the nested If operation to represent the Case operation + std::shared_ptr current_model = nullptr; + for (int i = num_branches - 1; i >= 0; --i) { + auto if_op = std::make_shared(cond); + if_op->set_then_body(branch_graphs[i]); + + if (current_model) { + if_op->set_else_body(current_model); + } else { + // Default empty else body + auto placeholder_model = std::make_shared(OutputVector{}, ParameterVector{}); + if_op->set_else_body(placeholder_model); + } + + current_model = if_op->get_body_model(); + } + + // Set the outputs and names + auto outputs = current_model->get_results(); + OutputVector ov_outputs; + for (size_t i = 0; i < outputs.size(); ++i) { + auto tensor = outputs[i]->output(0).get_tensor(); + tensor.set_names({node.get_name() + ":" + std::to_string(i)}); + ov_outputs.push_back(outputs[i]->output(0)); + } + + return ov_outputs; +} + +} // namespace op +} // namespace tensorflow +} // namespace frontend +} // namespace ov From ef01643837725689a9e4b7eb57815706092c48c3 Mon Sep 17 00:00:00 2001 From: shubdas9902 Date: Thu, 12 Dec 2024 09:07:05 +0530 Subject: [PATCH 2/4] Added case.hpp and loaded in op_table.cpp --- src/frontends/tensorflow/src/op/case.hpp | 18 ++++++++++++++++++ src/frontends/tensorflow/src/op_table.cpp | 3 +++ 2 files changed, 21 insertions(+) create mode 100644 src/frontends/tensorflow/src/op/case.hpp diff --git a/src/frontends/tensorflow/src/op/case.hpp b/src/frontends/tensorflow/src/op/case.hpp new file mode 100644 index 00000000000000..05bb076cb84b2b --- /dev/null +++ b/src/frontends/tensorflow/src/op/case.hpp @@ -0,0 +1,18 @@ +#ifndef CASE_HPP +#define CASE_HPP + +#include "openvino/frontend/tensorflow/node_context.hpp" + +namespace ov { +namespace frontend { +namespace tensorflow { +namespace op { + +OutputVector translate_case_op(const ov::frontend::tensorflow::NodeContext& node); + +} // namespace op +} // namespace tensorflow +} // namespace frontend +} // namespace ov + +#endif // CASE_HPP diff --git a/src/frontends/tensorflow/src/op_table.cpp b/src/frontends/tensorflow/src/op_table.cpp index 26b665c275bb48..924b49a9c62e9b 100644 --- a/src/frontends/tensorflow/src/op_table.cpp +++ b/src/frontends/tensorflow/src/op_table.cpp @@ -19,6 +19,7 @@ #include "openvino/op/bitwise_or.hpp" #include "openvino/op/bitwise_right_shift.hpp" #include "openvino/op/bitwise_xor.hpp" +#include "case.hpp" #include "openvino/op/ceiling.hpp" #include "openvino/op/cos.hpp" #include "openvino/op/cosh.hpp" @@ -87,6 +88,7 @@ TF_OP_CONVERTER(translate_assignvariable_op); TF_OP_CONVERTER(translate_add_variable_op); TF_OP_CONVERTER(translate_sub_variable_op); TF_OP_CONVERTER(translate_block_lstm_op); +TF_OP_CONVERTER(translate_case_op); TF_OP_CONVERTER(translate_enter_op); TF_OP_CONVERTER(translate_exit_op); TF_OP_CONVERTER(translate_fifo_queue_op); @@ -140,6 +142,7 @@ const std::map get_supported_ops() { {"Asinh", CreatorFunction(translate_unary_op)}, {"Atan", CreatorFunction(translate_unary_op)}, {"Atanh", CreatorFunction(translate_unary_op)}, + {"Case", CreatorFunction(translate_case_op)}, {"Ceil", CreatorFunction(translate_unary_op)}, {"Cos", CreatorFunction(translate_unary_op)}, {"Cosh", CreatorFunction(translate_unary_op)}, From a87341b898d3994ffeafb09551b2a5ba7af8c83e Mon Sep 17 00:00:00 2001 From: shubdas9902 Date: Mon, 16 Dec 2024 20:48:57 +0530 Subject: [PATCH 3/4] Added Layer Tests --- .../tensorflow_tests/test_tf_Case_op.py | 53 +++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 tests/layer_tests/tensorflow_tests/test_tf_Case_op.py diff --git a/tests/layer_tests/tensorflow_tests/test_tf_Case_op.py b/tests/layer_tests/tensorflow_tests/test_tf_Case_op.py new file mode 100644 index 00000000000000..d69f8185786375 --- /dev/null +++ b/tests/layer_tests/tensorflow_tests/test_tf_Case_op.py @@ -0,0 +1,53 @@ +import numpy as np +import pytest +import tensorflow as tf +from common.tf_layer_test_class import CommonTFLayerTest + + +class TestCaseOp(CommonTFLayerTest): + def _prepare_input(self, inputs_info): + assert 'cond' in inputs_info + assert 'input_data' in inputs_info + inputs_data = { + 'cond': np.array(inputs_info['cond'], dtype=np.bool_), + 'input_data': np.array(inputs_info['input_data'], dtype=np.float32) + } + return inputs_data + + def create_case_net(self, input_shape, branches, default_branch): + tf.compat.v1.reset_default_graph() + with tf.compat.v1.Session() as sess: + # Inputs + cond = tf.compat.v1.placeholder(dtype=tf.bool, shape=(), name="cond") + input_data = tf.compat.v1.placeholder(dtype=tf.float32, shape=input_shape, name="input_data") + + # Define branch functions + def branch_fn_1(): + return tf.add(input_data, tf.constant(1.0, dtype=tf.float32)) + + def branch_fn_2(): + return tf.multiply(input_data, tf.constant(2.0, dtype=tf.float32)) + + branches_fn = [branch_fn_1, branch_fn_2] + + # Create Case operation + case_op = tf.raw_ops.Case(branch_index=cond, branches=branches_fn, output_type=tf.float32) + + tf.compat.v1.global_variables_initializer() + tf_net = sess.graph_def + + return tf_net, None + + test_data_basic = [ + dict(input_shape=[1, 2], branches=2, default_branch=None, cond=True), + dict(input_shape=[3, 3], branches=2, default_branch=None, cond=False), + ] + + @pytest.mark.parametrize("params", test_data_basic) + @pytest.mark.precommit_tf_fe + @pytest.mark.nightly + def test_case_op(self, params, ie_device, precision, ir_version, temp_dir, + use_new_frontend, use_old_api): + self._test(*self.create_case_net(**params), + ie_device, precision, ir_version, temp_dir=temp_dir, + use_new_frontend=use_new_frontend, use_old_api=use_old_api) From ef61090cff7ac7b9f3d75ca1db2128d748b912b1 Mon Sep 17 00:00:00 2001 From: shubdas9902 Date: Mon, 16 Dec 2024 21:06:52 +0530 Subject: [PATCH 4/4] Add layer test for Case operation --- .../tensorflow_tests/test_tf_Case_op.py | 28 +++++++++++++++---- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/tests/layer_tests/tensorflow_tests/test_tf_Case_op.py b/tests/layer_tests/tensorflow_tests/test_tf_Case_op.py index d69f8185786375..12a7815247a26a 100644 --- a/tests/layer_tests/tensorflow_tests/test_tf_Case_op.py +++ b/tests/layer_tests/tensorflow_tests/test_tf_Case_op.py @@ -6,6 +6,9 @@ class TestCaseOp(CommonTFLayerTest): def _prepare_input(self, inputs_info): + """ + Prepares input data based on the given input shapes and data types. + """ assert 'cond' in inputs_info assert 'input_data' in inputs_info inputs_data = { @@ -14,7 +17,17 @@ def _prepare_input(self, inputs_info): } return inputs_data - def create_case_net(self, input_shape, branches, default_branch): + def create_case_net(self, input_shape, cond_value): + """ + Creates a TensorFlow model with a Case operation. + + Args: + input_shape: Shape of the input tensor. + cond_value: The condition value to select the branch. + + Returns: + TensorFlow graph definition and None. + """ tf.compat.v1.reset_default_graph() with tf.compat.v1.Session() as sess: # Inputs @@ -32,22 +45,27 @@ def branch_fn_2(): # Create Case operation case_op = tf.raw_ops.Case(branch_index=cond, branches=branches_fn, output_type=tf.float32) - + tf.identity(case_op, name="output") + tf.compat.v1.global_variables_initializer() tf_net = sess.graph_def return tf_net, None + # Test parameters test_data_basic = [ - dict(input_shape=[1, 2], branches=2, default_branch=None, cond=True), - dict(input_shape=[3, 3], branches=2, default_branch=None, cond=False), + dict(input_shape=[1, 2], cond=True), + dict(input_shape=[3, 3], cond=False), ] @pytest.mark.parametrize("params", test_data_basic) @pytest.mark.precommit_tf_fe @pytest.mark.nightly - def test_case_op(self, params, ie_device, precision, ir_version, temp_dir, + def test_case_op(self, params, ie_device, precision, ir_version, temp_dir, use_new_frontend, use_old_api): + """ + Executes the test for the Case operation. + """ self._test(*self.create_case_net(**params), ie_device, precision, ir_version, temp_dir=temp_dir, use_new_frontend=use_new_frontend, use_old_api=use_old_api)