optimize com.microsoft.MatMulNbits operator #28504

Open: wants to merge 1 commit into base: master
182 changes: 75 additions & 107 deletions src/frontends/onnx/frontend/src/op/com.microsoft/matmulnbits.cpp
@@ -16,8 +16,12 @@
#include "openvino/op/slice.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/reshape.hpp"
#include "utils/common.hpp"
#include "utils/reshape.hpp"
#include "utils/split.hpp"

using namespace ov::op;

@@ -111,142 +115,106 @@ ov::OutputVector matmulnbits(const ov::frontend::onnx::Node& node) {
bias.get_partial_shape());
}

ov::Output<ov::Node> mm_output;
{
const auto b_const = ov::as_type_ptr<v0::Constant>(b_quantized.get_node_shared_ptr());

ov::Output<ov::Node> casted_b;
ov::Shape casted_b_shape;
ov::Output<ov::Node> default_zp;
// Casting/converting data of source constant.
// For further calculations (sub and/or multiply) we need to reshape it from [N][n_blocks_per_col][blob_size * X]
// to [N * n_blocks_per_col][blob_size * X] (where X is the number of values packed in one byte), because scale and
// zero_point are provided with a flat shape of [N * n_blocks_per_col].
// For further calculations (sub and/or multiply) we need to reshape
// b -> [N][n_blocks_per_col][block_size]
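// Note: each packed byte holds 8 / bits quantized values, so the blob_size bytes of one block expand to
// block_size values once the constant is reinterpreted as u2/u4/u8 with the shapes chosen below.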
switch (bits) {
case 2:
casted_b_shape = ov::Shape{static_cast<size_t>(N * n_blocks_per_col), static_cast<size_t>(blob_size * 4)};
casted_b_shape = ov::Shape{static_cast<size_t>(N),
static_cast<size_t>(n_blocks_per_col),
static_cast<size_t>(blob_size * 4)};
casted_b = std::make_shared<v0::Constant>(ov::element::u2, casted_b_shape, b_const->get_data_ptr());
if (a.get_element_type() != ov::element::dynamic) {
default_zp = std::make_shared<v0::Constant>(a.get_element_type(), Shape{}, 2);
} else {
default_zp =
std::make_shared<v1::ConvertLike>(a,
std::make_shared<v0::Constant>(ov::element::f32, Shape{}, 2.f));
}
default_zp = std::make_shared<v0::Constant>(ov::element::u2, Shape{1}, 2);
break;
case 4:
casted_b_shape = ov::Shape{static_cast<size_t>(N * n_blocks_per_col), static_cast<size_t>(blob_size * 2)};
casted_b_shape = ov::Shape{static_cast<size_t>(N),
static_cast<size_t>(n_blocks_per_col),
static_cast<size_t>(blob_size * 2)};
casted_b = std::make_shared<v0::Constant>(ov::element::u4, casted_b_shape, b_const->get_data_ptr());
if (a.get_element_type() != ov::element::dynamic) {
default_zp = std::make_shared<v0::Constant>(a.get_element_type(), Shape{}, 8);
} else {
default_zp =
std::make_shared<v1::ConvertLike>(a,
std::make_shared<v0::Constant>(ov::element::f32, Shape{}, 8.f));
}
default_zp = std::make_shared<v0::Constant>(ov::element::u4, Shape{1}, 8);
break;
case 8:
casted_b_shape = ov::Shape{static_cast<size_t>(N * n_blocks_per_col), static_cast<size_t>(blob_size)};
casted_b = op::util::reshape(b_const, casted_b_shape);
if (a.get_element_type() != ov::element::dynamic) {
default_zp = std::make_shared<v0::Constant>(a.get_element_type(), Shape{}, 128);
} else {
default_zp =
std::make_shared<v1::ConvertLike>(a,
std::make_shared<v0::Constant>(ov::element::f32, Shape{}, 128.f));
}
casted_b_shape = ov::Shape{static_cast<size_t>(N),
static_cast<size_t>(n_blocks_per_col),
static_cast<size_t>(blob_size)};
casted_b = std::make_shared<v0::Constant>(ov::element::u8, casted_b_shape, b_const->get_data_ptr());
default_zp = std::make_shared<v0::Constant>(ov::element::u8, Shape{1}, 128);
break;
default:
FRONT_END_THROW("Unsupported bits count");
break;
}
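// The default zero points chosen above (2, 8, 128) are 2^(bits - 1), matching the MatMulNBits default
// used when the optional zero_points input is not provided.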

if (!zero_points.get_node_shared_ptr()) {
zero_points = default_zp;
} else {
// https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.MatMulNBits
// according to the link, zero points are:
// Constrain quantized zero point types to uint8/int32/float16/float.
// Input zero_points is stored as uint8_t or same as type(A). It has the same packing method as input B
zero_points =
op::util::reshape(zero_points,
ov::Shape{static_cast<size_t>(N), static_cast<size_t>(n_blocks_per_col), 1});
}
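// After this reshape zero_points has shape [N, n_blocks_per_col, 1] and broadcasts against b over the
// block_size axis in the Subtract below.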

// Possible issue with slice implementation, had to move conversion before slice, instead of slicing uint4
// TODO: Ticket
const auto converted_b = std::make_shared<v1::ConvertLike>(casted_b, a);
// Comments: the slice issue is still there, so b needs to be converted to fp16 first.

// TODO: Need to collect performance data in case constant folding is applied; there may be some perf/mem gap.

// Simple case
if (n_blocks_per_col == 1) {
// Removing unused items in case block is bigger than column count
// For example, if data is (uint8)[1,2,3,4,5,6] then block will be (uint8)[1,2,3,4,5,6,0,0,0,0,0,0,0,0,0,0].
// And last zeros are unused.
const auto zero_const = std::make_shared<v0::Constant>(ov::element::i32, Shape{1}, 0);
const auto one_const = std::make_shared<v0::Constant>(ov::element::i32, Shape{1}, 1);
const auto elements_const =
std::make_shared<v0::Constant>(ov::element::i32, Shape{1}, static_cast<int32_t>(K));
const auto axis_const = std::make_shared<v0::Constant>(ov::element::i32, Shape{1}, 1);
const auto slice_b =
std::make_shared<v8::Slice>(converted_b, zero_const, elements_const, one_const, axis_const);

// Transpose matrix
const auto transposed_shape =
std::make_shared<v0::Constant>(ov::element::i64, Shape{2}, std::vector<int64_t>{1, 0});
const auto transposed_b = std::make_shared<v1::Transpose>(slice_b, transposed_shape);

// If no zero-points are provided, generate a default depending on the weight bit width
if (!zero_points.get_node_shared_ptr()) {
zero_points = default_zp;
}
const auto sub_b = std::make_shared<v1::Subtract>(transposed_b, zero_points);

// Scaling
const auto scaled_b = std::make_shared<v1::Multiply>(sub_b, scales);

// Adding bias if required
if (!bias.get_node_shared_ptr()) {
b = scaled_b;
} else {
b = std::make_shared<v1::Add>(scaled_b, bias);
}
// Comments: in this latest code the constant folding is gone; it triggers the oneDNN kernel
// and uses the u2/u4/u8 weights as the kernel's input, so constant folding no longer happens.

// use fp16 for compute

// convert b to fp16
auto converted_b = std::make_shared<v0::Convert>(casted_b, ov::element::f16);
auto converted_zero_points = std::make_shared<v0::Convert>(zero_points, ov::element::f16);

// sub and scale
const auto sub_b = std::make_shared<v1::Subtract>(converted_b, converted_zero_points);
const auto scales_fp16 = std::make_shared<v0::Convert>(scales, ov::element::f16);
const auto scales_reshaped =
op::util::reshape(scales_fp16,
ov::Shape{static_cast<size_t>(N), static_cast<size_t>(n_blocks_per_col), 1});
const auto scaled_b = std::make_shared<v1::Multiply>(sub_b, scales_reshaped);
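// scaled_b now holds the dequantized weights, (b - zero_point) * scale, in fp16 with shape
// [N, n_blocks_per_col, block_size].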

// reshape b to [N, K]
auto shape_b = v0::Constant::create(ov::element::i32, ov::Shape{2}, {0, -1});
auto reshaped_b = std::make_shared<v1::Reshape>(scaled_b, shape_b, true);

// If n_blocks_per_col * blob_size * X != K, the padded tail needs to be sliced away to K values
// to produce b = [N, K]
const bool slice_needed = (K % block_size != 0);
if (slice_needed) {
const auto zero = std::make_shared<v0::Constant>(ov::element::i32, Shape{1}, 0);
const auto one = std::make_shared<v0::Constant>(ov::element::i32, Shape{1}, 1);
const auto elements = std::make_shared<v0::Constant>(ov::element::i32, Shape{1}, static_cast<int32_t>(K));
const auto axis = std::make_shared<v0::Constant>(ov::element::i32, Shape{1}, 1);
b = std::make_shared<v8::Slice>(reshaped_b, zero, elements, one, axis);
} else {
// Transpose the matrix. The quantized B matrix is transposed and has shape [N, K].
// To apply further operations whose operands have shape [N], we transpose it to [K, N]...
const auto transposed_shape =
std::make_shared<v0::Constant>(ov::element::i64, Shape{2}, std::vector<int64_t>{1, 0});
ov::Output<ov::Node> transposed_b = std::make_shared<v1::Transpose>(converted_b, transposed_shape);

// If no zero-points are provided, generate a default depending on the weight bit width
if (!zero_points.get_node_shared_ptr()) {
zero_points = default_zp;
}
const auto sub_b = std::make_shared<v1::Subtract>(transposed_b, zero_points);

// Scaling
const auto scaled_b = std::make_shared<v1::Multiply>(sub_b, scales);

// Transpose again to make reshaping and slicing
transposed_b = std::make_shared<v1::Transpose>(scaled_b, transposed_shape);

const auto reshaped_b =
op::util::reshape(transposed_b,
ov::Shape{static_cast<size_t>(casted_b_shape[0] / n_blocks_per_col),
static_cast<size_t>(casted_b_shape[1] * n_blocks_per_col)});

// Removing unused items in case block is bigger than column count (see description for
// Slice above)
const auto zero_const = std::make_shared<v0::Constant>(ov::element::i32, Shape{1}, 0);
const auto one_const = std::make_shared<v0::Constant>(ov::element::i32, Shape{1}, 1);
const auto elements_const =
std::make_shared<v0::Constant>(ov::element::i32, Shape{1}, static_cast<int32_t>(K));
const auto axis_const = std::make_shared<v0::Constant>(ov::element::i32, Shape{1}, 1);
const auto slice_b =
std::make_shared<v8::Slice>(reshaped_b, zero_const, elements_const, one_const, axis_const);

// Adding bias if required
if (!bias.get_node_shared_ptr()) {
return {std::make_shared<v0::MatMul>(a, slice_b, false, true)};
} else {
// Transpose again
transposed_b = std::make_shared<v1::Transpose>(slice_b, transposed_shape);

b = std::make_shared<v1::Add>(transposed_b, bias);
}
b = reshaped_b;
}

// mm = matmul(a,b)
auto a_fp16 = std::make_shared<v0::Convert>(a, ov::element::f16);
auto results = std::make_shared<v0::MatMul>(a_fp16, b, false, true);
mm_output = std::make_shared<v0::Convert>(results, a.get_element_type());
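// The fp16 MatMul result is cast back to the original activation element type, so downstream
// consumers see the same precision as before.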
}

return {std::make_shared<v0::MatMul>(a, b)};
if (bias.get_node_shared_ptr()) {
return {std::make_shared<v1::Add>(mm_output, bias)};
} else {
return {mm_output};
}
}

ONNX_OP("MatMulNBits", OPSET_SINCE(1), com_microsoft::opset_1::matmulnbits, MICROSOFT_DOMAIN);