Skip to content

Commit

Permalink
Revert extra interface changes
Browse files Browse the repository at this point in the history
  • Loading branch information
EgorDuplensky committed Nov 21, 2024
1 parent 7c3ce69 commit 3705e87
Show file tree
Hide file tree
Showing 12 changed files with 104 additions and 216 deletions.
16 changes: 8 additions & 8 deletions src/common/transformations/include/ov_ops/fully_connected.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,25 +18,25 @@ class TRANSFORMATIONS_API FullyConnected : public ov::op::Op {

FullyConnected() = default;

FullyConnected(const OutputVector& arguments, const ov::element::Type output_type = ov::element::undefined);

FullyConnected(const ov::Output<Node>& A,
const ov::Output<Node>& B,
const ov::Output<Node>& bias,
const ov::element::Type output_type = ov::element::undefined);

bool visit_attributes(ov::AttributeVisitor& visitor) override;

void validate_and_infer_types() override;

std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;
FullyConnected(const ov::Output<Node>& A,
const ov::Output<Node>& B,
const ov::element::Type output_type = ov::element::undefined);

virtual std::shared_ptr<Node> fuse_bias(const ov::Output<Node>& bias) const;
bool visit_attributes(ov::AttributeVisitor& visitor) override;

ov::element::Type get_output_type() const {
return m_output_type;
}

std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;

void validate_and_infer_types() override;

protected:
ov::element::Type m_output_type;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,22 @@ class TRANSFORMATIONS_API FullyConnectedCompressed : public FullyConnected {

FullyConnectedCompressed() = default;

FullyConnectedCompressed(const OutputVector& arguments,
const ov::element::Type output_type = ov::element::undefined);

FullyConnectedCompressed(const ov::Output<Node>& X,
const ov::Output<Node>& W,
const ov::Output<Node>& bias,
const ov::Output<Node>& weight_scales,
const ov::Output<Node>& weight_zero_points,
const ov::element::Type output_type = ov::element::undefined);

void validate_and_infer_types() override;
FullyConnectedCompressed(const ov::Output<Node>& X,
const ov::Output<Node>& W,
const ov::Output<Node>& bias,
const ov::Output<Node>& weight_scales,
const ov::element::Type output_type = ov::element::undefined);

std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;

std::shared_ptr<Node> fuse_bias(const ov::Output<Node>& bias) const override final;
void validate_and_infer_types() override;
};

} // namespace internal
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@ class TRANSFORMATIONS_API FullyConnectedQuantized : public FullyConnected {

FullyConnectedQuantized() = default;

FullyConnectedQuantized(const OutputVector& arguments,
const ov::element::Type output_type = ov::element::undefined);

FullyConnectedQuantized(const ov::Output<Node>& X,
const ov::Output<Node>& W,
const ov::Output<Node>& bias,
Expand All @@ -32,36 +29,9 @@ class TRANSFORMATIONS_API FullyConnectedQuantized : public FullyConnected {
const ov::Output<Node>& output_zero_points,
const ov::element::Type output_type = ov::element::undefined);

FullyConnectedQuantized(const ov::Output<Node>& X,
const ov::Output<Node>& W,
const ov::Output<Node>& bias,
const ov::Output<Node>& weight_scales,
const ov::Output<Node>& weight_zero_points,
const ov::Output<Node>& input_scales,
const ov::element::Type output_type = ov::element::undefined);

FullyConnectedQuantized(const ov::Output<Node>& X,
const ov::Output<Node>& W,
const ov::Output<Node>& bias,
const ov::Output<Node>& weight_scales,
const ov::Output<Node>& weight_zero_points,
const ov::element::Type output_type = ov::element::undefined);

FullyConnectedQuantized(const ov::Output<Node>& X,
const ov::Output<Node>& W,
const ov::Output<Node>& bias,
const ov::Output<Node>& weight_scales,
const ov::element::Type output_type = ov::element::undefined);

void validate_and_infer_types() override;

std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;

std::shared_ptr<Node> fuse_bias(const ov::Output<Node>& bias) const override final;

ov::element::Type get_output_type() const {
return m_output_type;
}
};

} // namespace internal
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@ class TRANSFORMATIONS_API FullyConnectedQuantizedLegacy : public FullyConnected

FullyConnectedQuantizedLegacy() = default;

FullyConnectedQuantizedLegacy(const OutputVector& arguments,
const ov::element::Type output_type = ov::element::undefined);

FullyConnectedQuantizedLegacy(const ov::Output<Node>& X,
const ov::Output<Node>& W,
const ov::Output<Node>& bias,
Expand All @@ -34,13 +31,9 @@ class TRANSFORMATIONS_API FullyConnectedQuantizedLegacy : public FullyConnected
const ov::Output<Node>& deq_scales,
const ov::element::Type output_type = ov::element::undefined);

bool visit_attributes(ov::AttributeVisitor& visitor) override;

void validate_and_infer_types() override;

std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;

std::shared_ptr<Node> fuse_bias(const ov::Output<Node>& bias) const override final;
void validate_and_infer_types() override;
};

} // namespace internal
Expand Down
26 changes: 11 additions & 15 deletions src/common/transformations/src/ov_ops/fully_connected.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,12 @@
#include <memory>

#include "matmul_shape_inference.hpp"
#include "ov_ops/placeholder.hpp"

namespace ov {
namespace op {
namespace internal {

FullyConnected::FullyConnected(const OutputVector& arguments, const ov::element::Type output_type)
: Op(arguments),
m_output_type(output_type) {
validate_and_infer_types();
}

FullyConnected::FullyConnected(const ov::Output<Node>& A,
const ov::Output<Node>& B,
const ov::Output<Node>& bias,
Expand All @@ -27,16 +22,22 @@ FullyConnected::FullyConnected(const ov::Output<Node>& A,
validate_and_infer_types();
}

// Convenience constructor for a FullyConnected node built from A and B only:
// delegates to the (A, B, bias, output_type) constructor, passing a freshly
// created Placeholder node in the bias position.
FullyConnected::FullyConnected(const ov::Output<Node>& A,
const ov::Output<Node>& B,
const ov::element::Type output_type)
: FullyConnected(A, B, std::make_shared<Placeholder>(), output_type) {}

// Exposes this node's attributes to the given visitor: registers m_output_type
// under the name "output_type". Always returns true.
bool FullyConnected::visit_attributes(ov::AttributeVisitor& visitor) {
visitor.on_attribute("output_type", m_output_type);
return true;
}

// Builds a new FullyConnected node from the first three entries of new_args
// (passed as A, B and bias), carrying over this node's m_output_type.
// NOTE(review): check_new_args_count() is called first; presumably it rejects
// a new_args whose size differs from this node's input count - confirm
// against its definition.
std::shared_ptr<ov::Node> FullyConnected::clone_with_new_inputs(const ov::OutputVector& new_args) const {
check_new_args_count(this, new_args);

return std::make_shared<FullyConnected>(new_args.at(0), new_args.at(1), new_args.at(2), m_output_type);
}

std::shared_ptr<Node> FullyConnected::fuse_bias(const ov::Output<Node>& bias) const {
return std::make_shared<FullyConnected>(input_value(0), input_value(1), bias, m_output_type);
}

void FullyConnected::validate_and_infer_types() {
const auto input_size = get_input_size();
NODE_VALIDATION_CHECK(this,
Expand All @@ -57,11 +58,6 @@ void FullyConnected::validate_and_infer_types() {
set_output_type(0, output_type, out_shapes[0]);
}

bool FullyConnected::visit_attributes(ov::AttributeVisitor& visitor) {
visitor.on_attribute("output_type", m_output_type);
return true;
}

} // namespace internal
} // namespace op
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,16 @@

#include "ov_ops/fully_connected_compressed.hpp"

#include <memory>

#include "openvino/core/type/element_type.hpp"
#include "ov_ops/fully_connected.hpp"
#include "ov_ops/placeholder.hpp"

namespace ov {
namespace op {
namespace internal {

FullyConnectedCompressed::FullyConnectedCompressed(const OutputVector& arguments, const ov::element::Type output_type)
: FullyConnected(OutputVector(arguments.begin(), arguments.begin() + 3), output_type) {
for (size_t i = 3; i < arguments.size(); i++) {
set_argument(i, arguments[i]);
}
validate_and_infer_types();
}

FullyConnectedCompressed::FullyConnectedCompressed(const ov::Output<Node>& X,
const ov::Output<Node>& W,
const ov::Output<Node>& bias,
Expand All @@ -31,27 +26,35 @@ FullyConnectedCompressed::FullyConnectedCompressed(const ov::Output<Node>& X,
validate_and_infer_types();
}

// Convenience constructor for a compressed FullyConnected node without
// explicit weight zero points: delegates to the
// (X, W, bias, weight_scales, weight_zero_points, output_type) constructor,
// passing a freshly created Placeholder node in the weight_zero_points position.
FullyConnectedCompressed::FullyConnectedCompressed(const ov::Output<Node>& X,
const ov::Output<Node>& W,
const ov::Output<Node>& bias,
const ov::Output<Node>& weight_scales,
const ov::element::Type output_type)
: FullyConnectedCompressed(X, W, bias, weight_scales, std::make_shared<Placeholder>(), output_type) {}

std::shared_ptr<ov::Node> FullyConnectedCompressed::clone_with_new_inputs(const ov::OutputVector& new_args) const {
check_new_args_count(this, new_args);

return std::make_shared<FullyConnectedCompressed>(new_args, m_output_type);
}

std::shared_ptr<Node> FullyConnectedCompressed::fuse_bias(const ov::Output<Node>& bias) const {
auto inputs = input_values();
inputs[2] = bias;

return std::make_shared<FullyConnectedCompressed>(inputs, get_output_type());
return std::make_shared<FullyConnectedCompressed>(new_args.at(0),
new_args.at(1),
new_args.at(2),
new_args.at(3),
new_args.at(4),
m_output_type);
}

// @todo finalize validate_and_infer_types
// Validates the input count of a FullyConnectedCompressed node, then defers to
// FullyConnected::validate_and_infer_types() for the remaining validation and
// type/shape inference.
//
// The node must have exactly `expected_size` (5) inputs, matching the five
// entries that clone_with_new_inputs() forwards to the constructor. Because
// the check is a strict equality, the diagnostic reports the exact expected
// count instead of the misleading "at least".
void FullyConnectedCompressed::validate_and_infer_types() {
    const auto input_size = get_input_size();
    const size_t expected_size = 5;
    NODE_VALIDATION_CHECK(this,
                          input_size == expected_size,
                          "Number of inputs is incorrect. Current value is: ",
                          input_size,
                          ", expected exactly ",
                          expected_size,
                          ".");

    FullyConnected::validate_and_infer_types();
}
Expand Down
67 changes: 15 additions & 52 deletions src/common/transformations/src/ov_ops/fully_connected_quantized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,6 @@ namespace ov {
namespace op {
namespace internal {

FullyConnectedQuantized::FullyConnectedQuantized(const OutputVector& arguments, const ov::element::Type output_type)
: FullyConnected(OutputVector(arguments.begin(), arguments.begin() + 3), output_type) {
for (size_t i = 3; i < arguments.size(); i++) {
set_argument(i, arguments[i]);
}
validate_and_infer_types();
}

FullyConnectedQuantized::FullyConnectedQuantized(const ov::Output<Node>& X,
const ov::Output<Node>& W,
const ov::Output<Node>& bias,
Expand All @@ -39,61 +31,32 @@ FullyConnectedQuantized::FullyConnectedQuantized(const ov::Output<Node>& X,
validate_and_infer_types();
}

FullyConnectedQuantized::FullyConnectedQuantized(const ov::Output<Node>& X,
const ov::Output<Node>& W,
const ov::Output<Node>& bias,
const ov::Output<Node>& weight_scales,
const ov::Output<Node>& weight_zero_points,
const ov::Output<Node>& input_scales,
const ov::element::Type output_type)
: FullyConnected(X, W, bias, output_type) {
set_argument(3, weight_scales);
set_argument(4, weight_zero_points);
set_argument(5, input_scales);
validate_and_infer_types();
}

FullyConnectedQuantized::FullyConnectedQuantized(const ov::Output<Node>& X,
const ov::Output<Node>& W,
const ov::Output<Node>& bias,
const ov::Output<Node>& weight_scales,
const ov::Output<Node>& weight_zero_points,
const ov::element::Type output_type)
: FullyConnected(X, W, bias, output_type) {
set_argument(3, weight_scales);
set_argument(4, weight_zero_points);
}

FullyConnectedQuantized::FullyConnectedQuantized(const ov::Output<Node>& X,
const ov::Output<Node>& W,
const ov::Output<Node>& bias,
const ov::Output<Node>& weight_scales,
const ov::element::Type output_type)
: FullyConnected(X, W, bias, output_type) {
set_argument(3, weight_scales);
}

std::shared_ptr<ov::Node> FullyConnectedQuantized::clone_with_new_inputs(const ov::OutputVector& new_args) const {
check_new_args_count(this, new_args);

return std::make_shared<FullyConnectedQuantized>(new_args, m_output_type);
}

std::shared_ptr<Node> FullyConnectedQuantized::fuse_bias(const ov::Output<Node>& bias) const {
auto inputs = input_values();
inputs[2] = bias;

return std::make_shared<FullyConnectedQuantized>(inputs, get_output_type());
return std::make_shared<FullyConnectedQuantized>(new_args.at(0),
new_args.at(1),
new_args.at(2),
new_args.at(3),
new_args.at(4),
new_args.at(5),
new_args.at(6),
new_args.at(7),
new_args.at(8),
m_output_type);
}

// @todo finalize validate_and_infer_types
// Validates the input count of a FullyConnectedQuantized node, then defers to
// FullyConnected::validate_and_infer_types() for the remaining validation and
// type/shape inference.
//
// The node must have exactly `expected_size` (9) inputs, matching the nine
// entries that clone_with_new_inputs() forwards to the constructor. Because
// the check is a strict equality, the diagnostic reports the exact expected
// count instead of the misleading "at least".
void FullyConnectedQuantized::validate_and_infer_types() {
    const auto input_size = get_input_size();
    const size_t expected_size = 9;
    NODE_VALIDATION_CHECK(this,
                          input_size == expected_size,
                          "Number of inputs is incorrect. Current value is: ",
                          input_size,
                          ", expected exactly ",
                          expected_size,
                          ".");

    FullyConnected::validate_and_infer_types();
}
Expand Down
Loading

0 comments on commit 3705e87

Please sign in to comment.