Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TF FE] Support Case operation #28027

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions src/frontends/tensorflow/src/op/case.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#include "openvino/frontend/tensorflow/node_context.hpp"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add copyright as for other src files.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#include "openvino/frontend/tensorflow/translate_session.hpp"
#include "openvino/opsets/opset11.hpp"
#include "openvino/util/log.hpp"

namespace ov {
namespace frontend {
namespace tensorflow {
namespace op {

using namespace ov::opsets;

OutputVector translate_case_op(const NodeContext& node) {
// Validate the operation type
auto op_type = node.get_op_type();
TENSORFLOW_OP_VALIDATION(node, op_type == "Case",
"Internal error: incorrect usage of translate_case_op.");
Comment on lines +14 to +17
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please use default_op_checks instead


// Retrieve the number of branches and inputs
auto num_branches = node.get_attribute<int>("branches");
TENSORFLOW_OP_VALIDATION(node, num_branches > 0,
"[TensorFlow Frontend] Case operation must have at least one branch.");

// The first input is the condition for selecting the branch
auto cond = node.get_input(0);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let us rename it to branch_index


// Create a list to store sub-graphs for the branches
std::vector<std::shared_ptr<Model>> branch_graphs;
for (int i = 0; i < num_branches; ++i) {
std::string branch_name = "branch_" + std::to_string(i);
auto branch_body = node.get_attribute<std::string>(branch_name);

// Ensure that the branch model is correctly loaded
auto branch_model = node.get_translate_session()->get_body_ov_model(branch_body, node.get_inputs());
TENSORFLOW_OP_VALIDATION(node, branch_model,
"[TensorFlow Frontend] Failed to retrieve body graph for branch: " + branch_name);
branch_graphs.push_back(branch_model);
}

// Create the nested If operation to represent the Case operation
std::shared_ptr<Model> current_model = nullptr;
for (int i = num_branches - 1; i >= 0; --i) {
auto if_op = std::make_shared<If>(cond);
if_op->set_then_body(branch_graphs[i]);

if (current_model) {
if_op->set_else_body(current_model);
} else {
// Default empty else body
auto placeholder_model = std::make_shared<Model>(OutputVector{}, ParameterVector{});
if_op->set_else_body(placeholder_model);
}
Comment on lines +45 to +52
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see that you properly set input parameters for then and else bodies. please check implementation of If translators. The same question for outputs of bodies.

Also, you need to have different conditions for each nested If. It is like a sort of branch_index == i where i is an index of branch.


current_model = if_op->get_body_model();
}

// Set the outputs and names
auto outputs = current_model->get_results();
OutputVector ov_outputs;
for (size_t i = 0; i < outputs.size(); ++i) {
auto tensor = outputs[i]->output(0).get_tensor();
tensor.set_names({node.get_name() + ":" + std::to_string(i)});
ov_outputs.push_back(outputs[i]->output(0));
}

return ov_outputs;
}

} // namespace op
} // namespace tensorflow
} // namespace frontend
} // namespace ov
18 changes: 18 additions & 0 deletions src/frontends/tensorflow/src/op/case.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#ifndef CASE_HPP
#define CASE_HPP

#include "openvino/frontend/tensorflow/node_context.hpp"

namespace ov {
namespace frontend {
namespace tensorflow {
namespace op {

OutputVector translate_case_op(const ov::frontend::tensorflow::NodeContext& node);

} // namespace op
} // namespace tensorflow
} // namespace frontend
} // namespace ov

#endif // CASE_HPP
Comment on lines +1 to +18
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not needed header file. Please check other operation translators for which we don't have header

3 changes: 3 additions & 0 deletions src/frontends/tensorflow/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "openvino/op/bitwise_or.hpp"
#include "openvino/op/bitwise_right_shift.hpp"
#include "openvino/op/bitwise_xor.hpp"
#include "case.hpp"
#include "openvino/op/ceiling.hpp"
#include "openvino/op/cos.hpp"
#include "openvino/op/cosh.hpp"
Expand Down Expand Up @@ -87,6 +88,7 @@ TF_OP_CONVERTER(translate_assignvariable_op);
TF_OP_CONVERTER(translate_add_variable_op);
TF_OP_CONVERTER(translate_sub_variable_op);
TF_OP_CONVERTER(translate_block_lstm_op);
TF_OP_CONVERTER(translate_case_op);
TF_OP_CONVERTER(translate_enter_op);
TF_OP_CONVERTER(translate_exit_op);
TF_OP_CONVERTER(translate_fifo_queue_op);
Expand Down Expand Up @@ -140,6 +142,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"Asinh", CreatorFunction(translate_unary_op<v3::Asinh>)},
{"Atan", CreatorFunction(translate_unary_op<v0::Atan>)},
{"Atanh", CreatorFunction(translate_unary_op<v3::Atanh>)},
{"Case", CreatorFunction(translate_case_op)},
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add layer tests for this Operation

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rkazants It would be greatly appreciated if you could guide me on the appropriate folder where I should add the layer tests for this operation. As I am new to open-source contributions, your guidance would be very helpful in ensuring I follow the correct structure and conventions for adding the test case for this operation.

{"Ceil", CreatorFunction(translate_unary_op<v0::Ceiling>)},
{"Cos", CreatorFunction(translate_unary_op<v0::Cos>)},
{"Cosh", CreatorFunction(translate_unary_op<v0::Cosh>)},
Expand Down
Loading