Skip to content

Commit

Permalink
[NPUW] Introduce DQ_FULL property (#27800)
Browse files Browse the repository at this point in the history
Mirror of #27678
  • Loading branch information
smirnov-alexey authored Nov 29, 2024
1 parent d01919e commit 9039484
Show file tree
Hide file tree
Showing 8 changed files with 206 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ DEFINE_OPT(NPUW_PLAN, std::string, "", npuw::partitioning::plan, CompileTime);
DEFINE_OPT(NPUW_FOLD, bool, false, npuw::partitioning::fold, CompileTime);
DEFINE_OPT(NPUW_CWAI, bool, false, npuw::partitioning::cwai, CompileTime);
DEFINE_OPT(NPUW_DQ, bool, false, npuw::partitioning::dyn_quant, CompileTime);
DEFINE_OPT(NPUW_DQ_FULL, bool, true, npuw::partitioning::dyn_quant_full, CompileTime);
DEFINE_OPT(NPUW_PMM, std::string, "2", npuw::partitioning::par_matmul_merge_dims, CompileTime);
DEFINE_OPT(NPUW_SLICE_OUT, bool, false, npuw::partitioning::slice_out, CompileTime);
DEFINE_OPT(NPUW_HOST_GATHER, bool, true, npuw::partitioning::host_gather, CompileTime);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,14 @@ static constexpr ov::Property<bool> cwai{"NPUW_CWAI"};
*/
static constexpr ov::Property<bool> dyn_quant{"NPUW_DQ"};

/**
* @brief
* Type: bool.
* Apply the full DQ transformation pipeline in the plugin.
* Default value: true.
*/
static constexpr ov::Property<bool> dyn_quant_full{"NPUW_DQ_FULL"};

/**
* @brief
* Type: string.
Expand Down
1 change: 1 addition & 0 deletions src/plugins/intel_npu/src/al/src/config/npuw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ void intel_npu::registerNPUWOptions(OptionsDesc& desc) {
desc.add<NPUW_FOLD>();
desc.add<NPUW_CWAI>();
desc.add<NPUW_DQ>();
desc.add<NPUW_DQ_FULL>();
desc.add<NPUW_PMM>();
desc.add<NPUW_SLICE_OUT>();
desc.add<NPUW_SPATIAL>();
Expand Down
1 change: 1 addition & 0 deletions src/plugins/intel_npu/src/plugin/npuw/compiled_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -923,6 +923,7 @@ void ov::npuw::CompiledModel::implement_properties() {
BIND(npuw::partitioning::fold, NPUW_FOLD),
BIND(npuw::partitioning::cwai, NPUW_CWAI),
BIND(npuw::partitioning::dyn_quant, NPUW_DQ),
BIND(npuw::partitioning::dyn_quant_full, NPUW_DQ_FULL),
BIND(npuw::partitioning::par_matmul_merge_dims, NPUW_PMM),
BIND(npuw::partitioning::slice_out, NPUW_SLICE_OUT),
BIND(npuw::partitioning::spatial, NPUW_SPATIAL),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1941,9 +1941,10 @@ void Partitioner::optimize(const std::string& func_name) {
// Run "dynamic quantization"
ov::npuw::patterns::opt::Context ctx;
ctx.is_spatial = f._spatial.has_value();
ctx.mm_dq_full = cfg.get<::intel_npu::NPUW_DQ_FULL>();

ov::pass::GraphRewrite rewr;
rewr.add_matcher<ov::npuw::patterns::opt::DQMatMulCWi>();
rewr.add_matcher<ov::npuw::patterns::opt::DQMatMulCWi>(std::ref(ctx));
rewr.add_matcher<ov::npuw::patterns::opt::DQMatMulGQi>(std::ref(ctx));
rewr.add_matcher<ov::npuw::patterns::opt::DQMatMulGQ2i>(std::ref(ctx));
rewr.add_matcher<ov::npuw::patterns::opt::DQMatMulGQiP>(std::ref(ctx));
Expand Down
149 changes: 125 additions & 24 deletions src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ namespace opt {

// Register that the closure tensor behind `orig_param` must be permuted with
// `order`, and immediately reshape the Parameter itself so that downstream
// shape inference sees the permuted layout.
void Context::permute(PPtr orig_param, const Context::Axes& order) {
    closures_to_permute[orig_param] = order;

    // Build the permuted shape: element i takes the extent of axis order[i].
    const auto& src_shape = orig_param->get_shape();
    ov::Shape permuted_shape;
    for (std::size_t i = 0u; i < order.size(); ++i) {
        permuted_shape.push_back(src_shape[order[i]]);
    }
    orig_param->set_partial_shape(permuted_shape);
    orig_param->validate_and_infer_types();
}

void Context::to_f16(PPtr orig_param) {
Expand Down Expand Up @@ -126,7 +134,7 @@ namespace uat = ov::npuw::util::at;
// Param(S) -> Reshape ----------->
//

DQMatMulCWi::DQMatMulCWi() {
DQMatMulCWi::DQMatMulCWi(Context::Ref ctx) {
auto qweight = opp::wrap_type<ov::op::v0::Parameter>();
auto qcoeff = opp::wrap_type<ov::op::v0::Parameter>();
auto qcvtw = opp::wrap_type<ov::op::v0::Convert>({qweight});
Expand Down Expand Up @@ -157,6 +165,14 @@ DQMatMulCWi::DQMatMulCWi() {
auto matched_node_muls = node_to_output.at(qmuls).get_node_shared_ptr();
auto matched_node_mmi = node_to_output.at(qmmi).get_node_shared_ptr();

if (!ctx.get().mm_dq_full) {
const auto& matm_mul_out_shape = matched_matmul->get_output_shape(0);
const auto& matm_mul_in_shape = matched_matmul->get_input_shape(1);
NPUW_ASSERT(matm_mul_out_shape.back() == matm_mul_in_shape.front());
NPUW_ASSERT(matched_matmul->get_transpose_b());
return false; // root hasn't changed
}

// Reconnect MatMul to read from Convert(W) directly.
// Note: ACT is f32 so has to be converted too.
auto new_cvt_act = std::make_shared<ov::op::v0::Convert>(matched_node_mmi, ov::element::f16);
Expand Down Expand Up @@ -231,7 +247,9 @@ DQMatMulGQi::DQMatMulGQi(Context::Ref ctx) {

auto matched_node_qweight = node_to_output.at(qweight).get_node_shared_ptr();
auto matched_node_qcoeff = node_to_output.at(qcoeff).get_node_shared_ptr();
auto matched_node_qmuls = node_to_output.at(qmuls).get_node_shared_ptr();
auto matched_node_matmul = node_to_output.at(qmm).get_node_shared_ptr();
auto matched_node_qreshp = node_to_output.at(qreshp).get_node_shared_ptr();
auto matched_out_mmi = node_to_output.at(qmmi);

auto matched_qweight = std::static_pointer_cast<ov::op::v0::Parameter>(matched_node_qweight);
Expand All @@ -248,10 +266,36 @@ DQMatMulGQi::DQMatMulGQi(Context::Ref ctx) {
act_shape.size() == 3 && act_shape[1] == 1 && // single-token case
qcoeff_shape[0] == qweight_shape[0] && qcoeff_shape[1] == 1 && qcoeff_shape[2] == qweight_shape[2] &&
!matched_matmul->get_transpose_a() && !matched_matmul->get_transpose_b()) {
if (!ctx.get().mm_dq_full) {
// Transpose weight and coeff
ctx.get().permute(matched_qweight, {0, 2, 1});
ctx.get().permute(matched_qcoeff, {0, 2, 1});

// Add Transpose and insert it
std::vector<std::size_t> new_transpose_order = {1, 0, 2};
auto new_transpose_order_c =
std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{3}, new_transpose_order);
auto new_transpose = std::make_shared<ov::op::v1::Transpose>(matched_node_qmuls, new_transpose_order_c);
matched_node_qreshp->input(0).replace_source_output(new_transpose);
matched_node_qreshp->validate_and_infer_types();

// Change Reshape's shape
std::vector<std::size_t> transposed_shape = {qweight_shape[2], qweight_shape[0] * qweight_shape[1]};
auto transposed_shape_c =
std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{2}, transposed_shape);
matched_node_qreshp->input(1).replace_source_output(transposed_shape_c);
matched_node_qreshp->validate_and_infer_types();

matched_matmul->set_transpose_b(true);
matched_matmul->validate_and_infer_types();

const auto& matm_mul_out_shape = matched_matmul->get_output_shape(0);
const auto& matm_mul_in_shape = matched_matmul->get_input_shape(1);
NPUW_ASSERT(matm_mul_out_shape.back() == matm_mul_in_shape.front());
return false; // root hasn't changed
}

// Mark W closure to transpose, and transpose the respective parameter
ov::Shape tw_shape = {qweight_shape[0], qweight_shape[2], qweight_shape[1]};
matched_qweight->set_partial_shape(tw_shape);
matched_qweight->validate_and_infer_types();
ctx.get().permute(matched_qweight, {0, 2, 1});

// Mark S closure to be lowered to f16
Expand Down Expand Up @@ -346,7 +390,9 @@ DQMatMulGQ2i::DQMatMulGQ2i(Context::Ref ctx) {

auto matched_node_qweight = node_to_output.at(qweight).get_node_shared_ptr();
auto matched_node_qcoeff = node_to_output.at(qcoeff).get_node_shared_ptr();
auto matched_node_qmuls = node_to_output.at(qmuls).get_node_shared_ptr();
auto matched_node_matmul = node_to_output.at(qmm).get_node_shared_ptr();
auto matched_node_qreshp = node_to_output.at(qreshp).get_node_shared_ptr();
auto matched_out_mmi = node_to_output.at(qmmi);

auto matched_qweight = std::static_pointer_cast<ov::op::v0::Parameter>(matched_node_qweight);
Expand All @@ -363,20 +409,33 @@ DQMatMulGQ2i::DQMatMulGQ2i(Context::Ref ctx) {
act_shape.size() == 3 && act_shape[0] == 1 && act_shape[1] == 1 && qcoeff_shape[0] == qweight_shape[0] &&
qcoeff_shape[2] == 1 && qcoeff_shape[1] == qweight_shape[1] && !matched_matmul->get_transpose_a() &&
matched_matmul->get_transpose_b()) {
if (!ctx.get().mm_dq_full) {
// Transpose weight and coeff
ctx.get().permute(matched_qweight, {1, 0, 2});
ctx.get().permute(matched_qcoeff, {1, 0, 2});

// Add Transpose and insert it
std::vector<std::size_t> new_transpose_order = {1, 0, 2};
auto new_transpose_order_c =
std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{3}, new_transpose_order);
auto new_transpose = std::make_shared<ov::op::v1::Transpose>(matched_node_qmuls, new_transpose_order_c);
matched_node_qreshp->input(0).replace_source_output(new_transpose);
matched_node_qreshp->validate_and_infer_types();
matched_matmul->validate_and_infer_types();

const auto& matm_mul_out_shape = matched_matmul->get_output_shape(0);
const auto& matm_mul_in_shape = matched_matmul->get_input_shape(1);
NPUW_ASSERT(matm_mul_out_shape.back() == matm_mul_in_shape.front());
NPUW_ASSERT(matched_matmul->get_transpose_b());
return false; // root hasn't changed
}

// Mark W closure to transpose, and transpose the respective parameter
ctx.get().permute(matched_qweight, {1, 0, 2});

ov::Shape tw_shape = {qweight_shape[1], qweight_shape[0], qweight_shape[2]};
matched_qweight->set_partial_shape(tw_shape);
matched_qweight->validate_and_infer_types();

// Also transpose S, but in a different way (see diagram above)
ctx.get().permute(matched_qcoeff, {1, 2, 0});

ov::Shape ts_shape = {qcoeff_shape[1], qcoeff_shape[2], qcoeff_shape[0]};
matched_qcoeff->set_partial_shape(ts_shape);
matched_qcoeff->validate_and_infer_types();

// Reshape the Act to group format
const auto NSPLIT = qweight_shape[1];
std::vector<std::size_t> rshp_act_v = {NSPLIT, 1, act_shape[2] / NSPLIT};
Expand Down Expand Up @@ -466,7 +525,9 @@ DQMatMulGQiP::DQMatMulGQiP(Context::Ref ctx) {

auto matched_node_qweight = node_to_output.at(qweight).get_node_shared_ptr();
auto matched_node_qcoeff = node_to_output.at(qcoeff).get_node_shared_ptr();
auto matched_node_qmuls = node_to_output.at(qmuls).get_node_shared_ptr();
auto matched_node_matmul = node_to_output.at(qmm).get_node_shared_ptr();
auto matched_node_qreshp = node_to_output.at(qreshp).get_node_shared_ptr();
auto matched_out_mmi = node_to_output.at(qmmi);

auto matched_qweight = std::static_pointer_cast<ov::op::v0::Parameter>(matched_node_qweight);
Expand All @@ -482,15 +543,39 @@ DQMatMulGQiP::DQMatMulGQiP(Context::Ref ctx) {
act_shape.size() == 3 && act_shape[1] > 1 && // multi-token case
qcoeff_shape[0] == qweight_shape[0] && qcoeff_shape[1] == 1 && qcoeff_shape[2] == qweight_shape[2] &&
!matched_matmul->get_transpose_a() && !matched_matmul->get_transpose_b()) {
if (!ctx.get().mm_dq_full) {
// Transpose weight and coeff
ctx.get().permute(matched_qweight, {0, 2, 1});
ctx.get().permute(matched_qcoeff, {0, 2, 1});

// Add Transpose and insert it
std::vector<std::size_t> new_transpose_order = {1, 0, 2};
auto new_transpose_order_c =
std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{3}, new_transpose_order);
auto new_transpose = std::make_shared<ov::op::v1::Transpose>(matched_node_qmuls, new_transpose_order_c);
matched_node_qreshp->input(0).replace_source_output(new_transpose);
matched_node_qreshp->validate_and_infer_types();

// Change Reshape's shape
std::vector<std::size_t> transposed_shape = {qweight_shape[2], qweight_shape[0] * qweight_shape[1]};
auto transposed_shape_c =
std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{2}, transposed_shape);
matched_node_qreshp->input(1).replace_source_output(transposed_shape_c);
matched_node_qreshp->validate_and_infer_types();

matched_matmul->set_transpose_b(true);
matched_matmul->validate_and_infer_types();

const auto& matm_mul_out_shape = matched_matmul->get_output_shape(0);
const auto& matm_mul_in_shape = matched_matmul->get_input_shape(1);
NPUW_ASSERT(matm_mul_out_shape.back() == matm_mul_in_shape.front());
return false; // root hasn't changed
}

// Mark W closure to transpose, and transpose the respective parameter
ov::Shape tw_shape = {qweight_shape[0], qweight_shape[2], qweight_shape[1]};
matched_qweight->set_partial_shape(tw_shape);
matched_qweight->validate_and_infer_types();
ctx.get().permute(matched_qweight, {0, 2, 1});

// Mark S closure to be lowered to f16
matched_qcoeff->set_element_type(ov::element::f16);
matched_qcoeff->validate_and_infer_types();
ctx.get().to_f16(matched_qcoeff);

// Reshape the Act to group format
Expand Down Expand Up @@ -579,7 +664,9 @@ DQMatMulGQ2iP::DQMatMulGQ2iP(Context::Ref ctx) {

auto matched_node_qweight = node_to_output.at(qweight).get_node_shared_ptr();
auto matched_node_qcoeff = node_to_output.at(qcoeff).get_node_shared_ptr();
auto matched_node_qmuls = node_to_output.at(qmuls).get_node_shared_ptr();
auto matched_node_matmul = node_to_output.at(qmm).get_node_shared_ptr();
auto matched_node_qreshp = node_to_output.at(qreshp).get_node_shared_ptr();
auto matched_out_mmi = node_to_output.at(qmmi);

auto matched_qweight = std::static_pointer_cast<ov::op::v0::Parameter>(matched_node_qweight);
Expand All @@ -599,19 +686,33 @@ DQMatMulGQ2iP::DQMatMulGQ2iP(Context::Ref ctx) {
act_shape.size() == 3 && just_one(act_shape[0], act_shape[1]) && // multi-token case
qcoeff_shape[0] == qweight_shape[0] && qcoeff_shape[1] == qweight_shape[1] && qcoeff_shape[2] == 1 &&
!matched_matmul->get_transpose_a() && matched_matmul->get_transpose_b()) {
if (!ctx.get().mm_dq_full) {
// Transpose weight and coeff
ctx.get().permute(matched_qweight, {1, 0, 2});
ctx.get().permute(matched_qcoeff, {1, 0, 2});

// Add Transpose and insert it
std::vector<std::size_t> new_transpose_order = {1, 0, 2};
auto new_transpose_order_c =
std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{3}, new_transpose_order);
auto new_transpose = std::make_shared<ov::op::v1::Transpose>(matched_node_qmuls, new_transpose_order_c);
matched_node_qreshp->input(0).replace_source_output(new_transpose);
matched_node_qreshp->validate_and_infer_types();
matched_matmul->validate_and_infer_types();

const auto& matm_mul_out_shape = matched_matmul->get_output_shape(0);
const auto& matm_mul_in_shape = matched_matmul->get_input_shape(1);
NPUW_ASSERT(matm_mul_out_shape.back() == matm_mul_in_shape.front());
NPUW_ASSERT(matched_matmul->get_transpose_b());
return false; // root hasn't changed
}

// Mark W closure to transpose, and transpose the respective parameter
ov::Shape tw_shape = {qweight_shape[1], qweight_shape[0], qweight_shape[2]};
matched_qweight->set_partial_shape(tw_shape);
matched_qweight->validate_and_infer_types();
ctx.get().permute(matched_qweight, {1, 0, 2});

// Also transpose S, but in a different way (see diagram above)
ctx.get().permute(matched_qcoeff, {1, 2, 0});

ov::Shape ts_shape = {qcoeff_shape[1], qcoeff_shape[2], qcoeff_shape[0]};
matched_qcoeff->set_partial_shape(ts_shape);
matched_qcoeff->validate_and_infer_types();

// Select proper activation shape
std::size_t act_dim = act_shape[0] > act_shape[1] ? 0 : 1;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,10 @@ namespace npuw {
namespace patterns {
namespace opt {

class DQMatMulCWi : public ov::pass::MatcherPass {
public:
DQMatMulCWi();
};

struct Context {
std::string pmm_dims;
bool is_spatial = false;
bool mm_dq_full = true;

using PPtr = std::shared_ptr<ov::op::v0::Parameter>;
using NPtr = std::shared_ptr<ov::Node>;
Expand Down Expand Up @@ -66,6 +62,11 @@ struct Context {
using Ref = std::reference_wrapper<Context>;
};

// Matcher pass for the channel-wise (CWi) dynamically-quantized MatMul
// pattern. Takes the shared optimization Context by reference; the
// Context's mm_dq_full flag (driven by NPUW_DQ_FULL) decides whether the
// full DQ rewrite is applied or the match is left as-is.
class DQMatMulCWi : public ov::pass::MatcherPass {
public:
    explicit DQMatMulCWi(Context::Ref ctx);
};

class DQMatMulGQi : public ov::pass::MatcherPass {
public:
explicit DQMatMulGQi(Context::Ref ctx);
Expand Down
Loading

0 comments on commit 9039484

Please sign in to comment.