[NPUW] Introduce DQ_FULL property #27678

Merged (6 commits, Nov 26, 2024)
Changes from all commits
@@ -42,6 +42,7 @@ DEFINE_OPT(NPUW_PLAN, std::string, "", npuw::partitioning::plan, CompileTime);
DEFINE_OPT(NPUW_FOLD, bool, false, npuw::partitioning::fold, CompileTime);
DEFINE_OPT(NPUW_CWAI, bool, false, npuw::partitioning::cwai, CompileTime);
DEFINE_OPT(NPUW_DQ, bool, false, npuw::partitioning::dyn_quant, CompileTime);
DEFINE_OPT(NPUW_DQ_FULL, bool, true, npuw::partitioning::dyn_quant_full, CompileTime);
DEFINE_OPT(NPUW_PMM, std::string, "2", npuw::partitioning::par_matmul_merge_dims, CompileTime);
DEFINE_OPT(NPUW_SLICE_OUT, bool, false, npuw::partitioning::slice_out, CompileTime);
DEFINE_OPT(NPUW_HOST_GATHER, bool, true, npuw::partitioning::host_gather, CompileTime);
@@ -183,6 +183,14 @@ static constexpr ov::Property<bool> cwai{"NPUW_CWAI"};
*/
static constexpr ov::Property<bool> dyn_quant{"NPUW_DQ"};

/**
* @brief
* Type: bool.
* Apply the full DQ transformation pipeline in the plugin.
* Default value: true.
*/
static constexpr ov::Property<bool> dyn_quant_full{"NPUW_DQ_FULL"};

/**
* @brief
* Type: string.
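For orientation, a minimal usage sketch that is not part of this diff: passing the new key through the compile-time configuration. Only the property names NPUW_DQ and NPUW_DQ_FULL come from this PR; the model path, the "NPU" device string, and the YES/NO string values are assumptions for illustration.

```cpp
#include <openvino/openvino.hpp>

int main() {
    ov::Core core;
    auto model = core.read_model("model.xml");  // placeholder model

    // NPUW_DQ enables dynamic quantization; NPUW_DQ_FULL (default: true)
    // selects whether the full DQ transformation pipeline is applied.
    ov::AnyMap config = {
        {"NPUW_DQ", "YES"},
        {"NPUW_DQ_FULL", "NO"},  // keep only the partial, layout-level rewrites
    };

    auto compiled = core.compile_model(model, "NPU", config);
    return 0;
}
```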
src/plugins/intel_npu/src/al/src/config/npuw.cpp (1 addition & 0 deletions)
@@ -27,6 +27,7 @@ void intel_npu::registerNPUWOptions(OptionsDesc& desc) {
desc.add<NPUW_FOLD>();
desc.add<NPUW_CWAI>();
desc.add<NPUW_DQ>();
desc.add<NPUW_DQ_FULL>();
desc.add<NPUW_PMM>();
desc.add<NPUW_SLICE_OUT>();
desc.add<NPUW_SPATIAL>();
src/plugins/intel_npu/src/plugin/npuw/compiled_model.cpp (1 addition & 0 deletions)
@@ -923,6 +923,7 @@ void ov::npuw::CompiledModel::implement_properties() {
BIND(npuw::partitioning::fold, NPUW_FOLD),
BIND(npuw::partitioning::cwai, NPUW_CWAI),
BIND(npuw::partitioning::dyn_quant, NPUW_DQ),
BIND(npuw::partitioning::dyn_quant_full, NPUW_DQ_FULL),
BIND(npuw::partitioning::par_matmul_merge_dims, NPUW_PMM),
BIND(npuw::partitioning::slice_out, NPUW_SLICE_OUT),
BIND(npuw::partitioning::spatial, NPUW_SPATIAL),
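Because the key is BIND()-ed in implement_properties(), it should also be readable back from the compiled model through the regular property interface. A short sketch under the same assumptions as above:

```cpp
#include <iostream>
#include <openvino/openvino.hpp>

int main() {
    ov::Core core;
    auto model = core.read_model("model.xml");  // placeholder model
    auto compiled = core.compile_model(model, "NPU", ov::AnyMap{{"NPUW_DQ_FULL", "NO"}});

    // The bound option is exposed as a bool property.
    std::cout << std::boolalpha << "NPUW_DQ_FULL = "
              << compiled.get_property("NPUW_DQ_FULL").as<bool>() << std::endl;
    return 0;
}
```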
@@ -1956,9 +1956,10 @@ void Partitioner::optimize(const std::string& func_name) {
// Run "dynamic quantization"
ov::npuw::patterns::opt::Context ctx;
ctx.is_spatial = f._spatial.has_value();
ctx.mm_dq_full = cfg.get<::intel_npu::NPUW_DQ_FULL>();

ov::pass::GraphRewrite rewr;
rewr.add_matcher<ov::npuw::patterns::opt::DQMatMulCWi>();
rewr.add_matcher<ov::npuw::patterns::opt::DQMatMulCWi>(std::ref(ctx));
rewr.add_matcher<ov::npuw::patterns::opt::DQMatMulGQi>(std::ref(ctx));
rewr.add_matcher<ov::npuw::patterns::opt::DQMatMulGQ2i>(std::ref(ctx));
rewr.add_matcher<ov::npuw::patterns::opt::DQMatMulGQiP>(std::ref(ctx));
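The hunk above threads the Context (and with it mm_dq_full) into every DQ matcher. For readers less familiar with OpenVINO's GraphRewrite machinery, here is a minimal, hypothetical MatcherPass sketch, not taken from this PR, showing the two conventions the new code relies on: a context object captured by the callback, and `return false` to signal that the matched root was left unchanged. DemoContext and DemoMatMulPass are made-up names.

```cpp
#include <functional>
#include <memory>

#include <openvino/op/matmul.hpp>
#include <openvino/pass/graph_rewrite.hpp>
#include <openvino/pass/pattern/matcher.hpp>
#include <openvino/pass/pattern/op/wrap_type.hpp>

// Hypothetical stand-in for ov::npuw::patterns::opt::Context.
struct DemoContext {
    bool mm_dq_full = true;
    using Ref = std::reference_wrapper<DemoContext>;
};

class DemoMatMulPass : public ov::pass::MatcherPass {
public:
    explicit DemoMatMulPass(DemoContext::Ref ctx) {
        auto mm = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>();

        auto callback = [ctx](ov::pass::pattern::Matcher& m) {
            if (!ctx.get().mm_dq_full) {
                // Inspect-only mode: check assumptions, do not rewrite.
                return false;  // root hasn't changed
            }
            // ... a full rewrite would modify the graph here ...
            return true;  // root was modified
        };
        register_matcher(std::make_shared<ov::pass::pattern::Matcher>(mm, "DemoMatMulPass"), callback);
    }
};

// Usage, mirroring Partitioner::optimize():
//   DemoContext ctx;
//   ctx.mm_dq_full = false;
//   ov::pass::GraphRewrite rewr;
//   rewr.add_matcher<DemoMatMulPass>(std::ref(ctx));
//   rewr.run_on_model(model);
```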
src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/opt.cpp (125 additions & 24 deletions)
@@ -20,6 +20,14 @@ namespace opt {

void Context::permute(PPtr orig_param, const Context::Axes& order) {
closures_to_permute[orig_param] = order;

const auto& orig_shape = orig_param->get_shape();
ov::Shape tw_shape;
for (const auto& axis : order) {
tw_shape.push_back(orig_shape[axis]);
}
orig_param->set_partial_shape(tw_shape);
orig_param->validate_and_infer_types();
}

void Context::to_f16(PPtr orig_param) {
@@ -126,7 +134,7 @@ namespace uat = ov::npuw::util::at;
// Param/Const(S) -> (Reshape) -> (to(f32)) -> Reshape -->
//

DQMatMulCWi::DQMatMulCWi() {
DQMatMulCWi::DQMatMulCWi(Context::Ref ctx) {
auto qweight = opp::wrap_type<ov::op::v0::Parameter>();
auto qcoeff = opp::any_input();
auto reshapew = opp::optional<ov::op::v1::Reshape>({qweight, opp::any_input()});
@@ -161,6 +169,14 @@ DQMatMulCWi::DQMatMulCWi(Context::Ref ctx) {
auto matched_node_qcoeff_out = uat::_(node_to_output).at_or_at_or_at(qcvtc, reshapec, qcoeff);
auto matched_node_muls_out = uat::_(node_to_output).at_or_at(qcvtm, qmuls);

if (!ctx.get().mm_dq_full) {
const auto& matm_mul_out_shape = matched_matmul->get_output_shape(0);
const auto& matm_mul_in_shape = matched_matmul->get_input_shape(1);
NPUW_ASSERT(matm_mul_out_shape.back() == matm_mul_in_shape.front());
NPUW_ASSERT(matched_matmul->get_transpose_b());
return false; // root hasn't changed
}

// Reconnect MatMul to read from Convert(W) directly.
// Note: ACT has to be converted too.
auto cvt_prec = matched_node_cvtw->output(0).get_element_type();
@@ -238,7 +254,9 @@ DQMatMulGQi::DQMatMulGQi(Context::Ref ctx) {

auto matched_node_qweight = node_to_output.at(qweight).get_node_shared_ptr();
auto matched_node_qcoeff = node_to_output.at(qcoeff).get_node_shared_ptr();
auto matched_node_qmuls = node_to_output.at(qmuls).get_node_shared_ptr();
auto matched_node_matmul = node_to_output.at(qmm).get_node_shared_ptr();
auto matched_node_qreshp = node_to_output.at(qreshp).get_node_shared_ptr();
auto matched_out_mmi = node_to_output.at(qmmi);

auto matched_qweight = std::static_pointer_cast<ov::op::v0::Parameter>(matched_node_qweight);
@@ -255,10 +273,36 @@ DQMatMulGQi::DQMatMulGQi(Context::Ref ctx) {
act_shape.size() == 3 && act_shape[1] == 1 && // single-token case
qcoeff_shape[0] == qweight_shape[0] && qcoeff_shape[1] == 1 && qcoeff_shape[2] == qweight_shape[2] &&
!matched_matmul->get_transpose_a() && !matched_matmul->get_transpose_b()) {
if (!ctx.get().mm_dq_full) {
// Transpose weight and coeff
ctx.get().permute(matched_qweight, {0, 2, 1});
ctx.get().permute(matched_qcoeff, {0, 2, 1});

// Add Transpose and insert it
std::vector<std::size_t> new_transpose_order = {1, 0, 2};
auto new_transpose_order_c =
std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{3}, new_transpose_order);
auto new_transpose = std::make_shared<ov::op::v1::Transpose>(matched_node_qmuls, new_transpose_order_c);
matched_node_qreshp->input(0).replace_source_output(new_transpose);
matched_node_qreshp->validate_and_infer_types();

// Change Reshape's shape
std::vector<std::size_t> transposed_shape = {qweight_shape[2], qweight_shape[0] * qweight_shape[1]};
auto transposed_shape_c =
std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{2}, transposed_shape);
matched_node_qreshp->input(1).replace_source_output(transposed_shape_c);
matched_node_qreshp->validate_and_infer_types();

matched_matmul->set_transpose_b(true);
matched_matmul->validate_and_infer_types();

const auto& matm_mul_out_shape = matched_matmul->get_output_shape(0);
const auto& matm_mul_in_shape = matched_matmul->get_input_shape(1);
NPUW_ASSERT(matm_mul_out_shape.back() == matm_mul_in_shape.front());
return false; // root hasn't changed
}

// Mark W closure to transpose, and transpose the respective parameter
ov::Shape tw_shape = {qweight_shape[0], qweight_shape[2], qweight_shape[1]};
matched_qweight->set_partial_shape(tw_shape);
matched_qweight->validate_and_infer_types();
ctx.get().permute(matched_qweight, {0, 2, 1});

// Mark S closure to be lowered to f16
@@ -353,7 +397,9 @@ DQMatMulGQ2i::DQMatMulGQ2i(Context::Ref ctx) {

auto matched_node_qweight = node_to_output.at(qweight).get_node_shared_ptr();
auto matched_node_qcoeff = node_to_output.at(qcoeff).get_node_shared_ptr();
auto matched_node_qmuls = node_to_output.at(qmuls).get_node_shared_ptr();
auto matched_node_matmul = node_to_output.at(qmm).get_node_shared_ptr();
auto matched_node_qreshp = node_to_output.at(qreshp).get_node_shared_ptr();
auto matched_out_mmi = node_to_output.at(qmmi);

auto matched_qweight = std::static_pointer_cast<ov::op::v0::Parameter>(matched_node_qweight);
@@ -370,20 +416,33 @@ DQMatMulGQ2i::DQMatMulGQ2i(Context::Ref ctx) {
act_shape.size() == 3 && act_shape[0] == 1 && act_shape[1] == 1 && qcoeff_shape[0] == qweight_shape[0] &&
qcoeff_shape[2] == 1 && qcoeff_shape[1] == qweight_shape[1] && !matched_matmul->get_transpose_a() &&
matched_matmul->get_transpose_b()) {
if (!ctx.get().mm_dq_full) {
// Transpose weight and coeff
ctx.get().permute(matched_qweight, {1, 0, 2});
ctx.get().permute(matched_qcoeff, {1, 0, 2});

// Add Transpose and insert it
std::vector<std::size_t> new_transpose_order = {1, 0, 2};
auto new_transpose_order_c =
std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{3}, new_transpose_order);
auto new_transpose = std::make_shared<ov::op::v1::Transpose>(matched_node_qmuls, new_transpose_order_c);
matched_node_qreshp->input(0).replace_source_output(new_transpose);
matched_node_qreshp->validate_and_infer_types();
matched_matmul->validate_and_infer_types();

const auto& matm_mul_out_shape = matched_matmul->get_output_shape(0);
const auto& matm_mul_in_shape = matched_matmul->get_input_shape(1);
NPUW_ASSERT(matm_mul_out_shape.back() == matm_mul_in_shape.front());
NPUW_ASSERT(matched_matmul->get_transpose_b());
return false; // root hasn't changed
}

// Mark W closure to transpose, and transpose the respective parameter
ctx.get().permute(matched_qweight, {1, 0, 2});

ov::Shape tw_shape = {qweight_shape[1], qweight_shape[0], qweight_shape[2]};
matched_qweight->set_partial_shape(tw_shape);
matched_qweight->validate_and_infer_types();

// Also transpose S, but in a different way (see diagram above)
ctx.get().permute(matched_qcoeff, {1, 2, 0});

ov::Shape ts_shape = {qcoeff_shape[1], qcoeff_shape[2], qcoeff_shape[0]};
matched_qcoeff->set_partial_shape(ts_shape);
matched_qcoeff->validate_and_infer_types();

// Reshape the Act to group format
const auto NSPLIT = qweight_shape[1];
std::vector<std::size_t> rshp_act_v = {NSPLIT, 1, act_shape[2] / NSPLIT};
@@ -473,7 +532,9 @@ DQMatMulGQiP::DQMatMulGQiP(Context::Ref ctx) {

auto matched_node_qweight = node_to_output.at(qweight).get_node_shared_ptr();
auto matched_node_qcoeff = node_to_output.at(qcoeff).get_node_shared_ptr();
auto matched_node_qmuls = node_to_output.at(qmuls).get_node_shared_ptr();
auto matched_node_matmul = node_to_output.at(qmm).get_node_shared_ptr();
auto matched_node_qreshp = node_to_output.at(qreshp).get_node_shared_ptr();
auto matched_out_mmi = node_to_output.at(qmmi);

auto matched_qweight = std::static_pointer_cast<ov::op::v0::Parameter>(matched_node_qweight);
@@ -489,15 +550,39 @@ DQMatMulGQiP::DQMatMulGQiP(Context::Ref ctx) {
act_shape.size() == 3 && act_shape[1] > 1 && // multi-token case
qcoeff_shape[0] == qweight_shape[0] && qcoeff_shape[1] == 1 && qcoeff_shape[2] == qweight_shape[2] &&
!matched_matmul->get_transpose_a() && !matched_matmul->get_transpose_b()) {
if (!ctx.get().mm_dq_full) {
// Transpose weight and coeff
ctx.get().permute(matched_qweight, {0, 2, 1});
ctx.get().permute(matched_qcoeff, {0, 2, 1});

// Add Transpose and insert it
std::vector<std::size_t> new_transpose_order = {1, 0, 2};
auto new_transpose_order_c =
std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{3}, new_transpose_order);
auto new_transpose = std::make_shared<ov::op::v1::Transpose>(matched_node_qmuls, new_transpose_order_c);
matched_node_qreshp->input(0).replace_source_output(new_transpose);
matched_node_qreshp->validate_and_infer_types();

// Change Reshape's shape
std::vector<std::size_t> transposed_shape = {qweight_shape[2], qweight_shape[0] * qweight_shape[1]};
auto transposed_shape_c =
std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{2}, transposed_shape);
matched_node_qreshp->input(1).replace_source_output(transposed_shape_c);
matched_node_qreshp->validate_and_infer_types();

matched_matmul->set_transpose_b(true);
matched_matmul->validate_and_infer_types();

const auto& matm_mul_out_shape = matched_matmul->get_output_shape(0);
const auto& matm_mul_in_shape = matched_matmul->get_input_shape(1);
NPUW_ASSERT(matm_mul_out_shape.back() == matm_mul_in_shape.front());
return false; // root hasn't changed
}

// Mark W closure to transpose, and transpose the respective parameter
ov::Shape tw_shape = {qweight_shape[0], qweight_shape[2], qweight_shape[1]};
matched_qweight->set_partial_shape(tw_shape);
matched_qweight->validate_and_infer_types();
ctx.get().permute(matched_qweight, {0, 2, 1});

// Mark S closure to be lowered to f16
matched_qcoeff->set_element_type(ov::element::f16);
matched_qcoeff->validate_and_infer_types();
ctx.get().to_f16(matched_qcoeff);

// Reshape the Act to group format
@@ -586,7 +671,9 @@ DQMatMulGQ2iP::DQMatMulGQ2iP(Context::Ref ctx) {

auto matched_node_qweight = node_to_output.at(qweight).get_node_shared_ptr();
auto matched_node_qcoeff = node_to_output.at(qcoeff).get_node_shared_ptr();
auto matched_node_qmuls = node_to_output.at(qmuls).get_node_shared_ptr();
auto matched_node_matmul = node_to_output.at(qmm).get_node_shared_ptr();
auto matched_node_qreshp = node_to_output.at(qreshp).get_node_shared_ptr();
auto matched_out_mmi = node_to_output.at(qmmi);

auto matched_qweight = std::static_pointer_cast<ov::op::v0::Parameter>(matched_node_qweight);
@@ -606,19 +693,33 @@ DQMatMulGQ2iP::DQMatMulGQ2iP(Context::Ref ctx) {
act_shape.size() == 3 && just_one(act_shape[0], act_shape[1]) && // multi-token case
qcoeff_shape[0] == qweight_shape[0] && qcoeff_shape[1] == qweight_shape[1] && qcoeff_shape[2] == 1 &&
!matched_matmul->get_transpose_a() && matched_matmul->get_transpose_b()) {
if (!ctx.get().mm_dq_full) {
// Transpose weight and coeff
ctx.get().permute(matched_qweight, {1, 0, 2});
ctx.get().permute(matched_qcoeff, {1, 0, 2});

// Add Transpose and insert it
std::vector<std::size_t> new_transpose_order = {1, 0, 2};
auto new_transpose_order_c =
std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{3}, new_transpose_order);
auto new_transpose = std::make_shared<ov::op::v1::Transpose>(matched_node_qmuls, new_transpose_order_c);
matched_node_qreshp->input(0).replace_source_output(new_transpose);
matched_node_qreshp->validate_and_infer_types();
matched_matmul->validate_and_infer_types();

const auto& matm_mul_out_shape = matched_matmul->get_output_shape(0);
const auto& matm_mul_in_shape = matched_matmul->get_input_shape(1);
NPUW_ASSERT(matm_mul_out_shape.back() == matm_mul_in_shape.front());
NPUW_ASSERT(matched_matmul->get_transpose_b());
return false; // root hasn't changed
}

// Mark W closure to transpose, and transpose the respective parameter
ov::Shape tw_shape = {qweight_shape[1], qweight_shape[0], qweight_shape[2]};
matched_qweight->set_partial_shape(tw_shape);
matched_qweight->validate_and_infer_types();
ctx.get().permute(matched_qweight, {1, 0, 2});

// Also transpose S, but in a different way (see diagram above)
ctx.get().permute(matched_qcoeff, {1, 2, 0});

ov::Shape ts_shape = {qcoeff_shape[1], qcoeff_shape[2], qcoeff_shape[0]};
matched_qcoeff->set_partial_shape(ts_shape);
matched_qcoeff->validate_and_infer_types();

// Select proper activation shape
std::size_t act_dim = act_shape[0] > act_shape[1] ? 0 : 1;

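One more illustration, not part of the diff: the updated Context::permute now reshapes the Parameter in place according to the given axis order, so the pattern code above no longer has to compute and set the transposed shape by hand. A standalone sketch of that shape computation follows; the example dimensions are made up.

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
    // Hypothetical 3-D weight shape and the axis order used by DQMatMulGQi: {0, 2, 1}.
    const std::vector<std::size_t> orig_shape = {32, 4096, 128};
    const std::vector<std::size_t> order = {0, 2, 1};

    // Same loop as the new Context::permute body.
    std::vector<std::size_t> tw_shape;
    for (const auto& axis : order) {
        tw_shape.push_back(orig_shape[axis]);
    }

    // Prints: 32 128 4096
    for (const auto& d : tw_shape) {
        std::cout << d << ' ';
    }
    std::cout << '\n';
    return 0;
}
```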
@@ -19,14 +19,10 @@ namespace npuw {
namespace patterns {
namespace opt {

class DQMatMulCWi : public ov::pass::MatcherPass {
public:
DQMatMulCWi();
};

struct Context {
std::string pmm_dims;
bool is_spatial = false;
bool mm_dq_full = true;

using PPtr = std::shared_ptr<ov::op::v0::Parameter>;
using NPtr = std::shared_ptr<ov::Node>;
@@ -66,6 +62,11 @@ struct Context {
using Ref = std::reference_wrapper<Context>;
};

class DQMatMulCWi : public ov::pass::MatcherPass {
public:
explicit DQMatMulCWi(Context::Ref ctx);
};

class DQMatMulGQi : public ov::pass::MatcherPass {
public:
explicit DQMatMulGQi(Context::Ref ctx);