Skip to content

Commit

Permalink
[CPU] Refactor Convolution node using new executor architecture
Browse files Browse the repository at this point in the history
  • Loading branch information
EgorDuplensky committed Jan 17, 2025
1 parent 93b2567 commit c0562d9
Show file tree
Hide file tree
Showing 45 changed files with 2,141 additions and 1,773 deletions.
329 changes: 308 additions & 21 deletions src/plugins/intel_cpu/src/dnnl_postops_composer.cpp

Large diffs are not rendered by default.

21 changes: 18 additions & 3 deletions src/plugins/intel_cpu/src/dnnl_postops_composer.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@ class DnnlPostOpsComposer {
const bool isINT8,
const int weiScaleMaskPerChannel,
const MemoryArgs& memory,
const dnnl::memory::data_type outDataType);
const dnnl::memory::data_type outDataType,
const std::vector<float>& legacyDqScales = {},
bool useLegacyPostOps = false,
bool useLegacyZeroPoints = false);
DnnlPrimitiveAttrs compose();

void appendDecompressionScales(const MemoryCPtr& scales_ptr,
bool needTranspose,
ov::element::Type dstPrecision,
Expand All @@ -54,8 +56,12 @@ class DnnlPostOpsComposer {
bool isLastPostOp,
bool doRounding,
bool allowBinary = true);
void appendAttrPostOpsLegacy(const ActivationPostOp& postOp);
void appendAttrPostOpsLegacy(const ScaleShiftPostOp& postOp);
void appendAttrPostOpsLegacy(const FakeQuantizePostOp& postOp);
void appendBinary(const dnnl::algorithm alg, const std::vector<float>& data);
void appendEltwise(const dnnl::algorithm alg, float alpha, float beta);
void appendSum(float scale, int32_t zeroPoint);
void appendRoundHTE();
bool appendScale(const std::vector<float>& scale, bool isLastPostOp, bool allowBinary = true);
bool appendShift(const std::vector<float>& shift, bool allowBinary = true);
Expand All @@ -64,7 +70,14 @@ class DnnlPostOpsComposer {
bool isLastPostOp,
bool allowBinary = true);
void appendClip(const std::vector<float>& low, const std::vector<float>& high);

void appendDepthwiseConvolution(int inH,
int inW,
int kerH,
int kerW,
int strH,
int strW,
dnnl::memory::data_type inDataType);
void appendZeroPoints(const MemoryArgs& memory, bool legacy);
const dnnl::engine& engine;
const PostOps& postOps;
const VectorDims outputDims;
Expand All @@ -73,6 +86,8 @@ class DnnlPostOpsComposer {
const int weightScaleMaskPerChannel;
bool weightScaleAvailable = false;
const dnnl::memory::data_type outDataType;
bool useLegacyPostOps;
bool useLegacyZeroPoints;

dnnl::primitive_attr attr;
MemoryArgs cpuArgs;
Expand Down
14 changes: 10 additions & 4 deletions src/plugins/intel_cpu/src/graph_optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,9 @@ void GraphOptimizer::ApplyCommonGraphOptimizations(Graph& graph) {
FuseConvolutionAndDWConvolution(graph);
graph.RemoveDroppedNodes();

OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, "FuseConvolutionSumAndConvolutionSumActivation");
FuseConvolutionSumAndConvolutionSumActivation(graph);
graph.RemoveDroppedNodes();
// OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, "FuseConvolutionSumAndConvolutionSumActivation");
// FuseConvolutionSumAndConvolutionSumActivation(graph);
// graph.RemoveDroppedNodes();

OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, "FuseConvolutionAndSimpleOperation");
FuseConvolutionAndSimpleOperation(graph);
Expand Down Expand Up @@ -959,6 +959,7 @@ void GraphOptimizer::FuseConvolutionAndZeroPoints(Graph& graph) {
dataEltwise->getName(),
" is optimized as zeropoint of Conv ##",
conv->getName());
conv->setOriginalInputPrecisionAtPort(0, dataEltwise->getOriginalInputPrecisionAtPort(0));
graph.RemoveEdge(p_edge);
graph.DropNode(dataEltwise);
initializeOutputCompensation(conv);
Expand Down Expand Up @@ -1174,8 +1175,13 @@ void GraphOptimizer::FuseConvolutionAndDWConvolution(Graph& graph) {
if (parentConvolutionNode == nullptr)
OPENVINO_THROW("Cannot get convolution node ", parentNode->getName());

if (!impl::cpu::x64::mayiuse(impl::cpu::x64::avx2) || impl::cpu::x64::mayiuse(impl::cpu::x64::avx512_core))
if (!impl::cpu::x64::mayiuse(impl::cpu::x64::avx2))
return false;
// there is no optimized implementation for avx512, so two avx512 convolutions
// are expected to be faster than single fused avx2 convolution
if (impl::cpu::x64::mayiuse(impl::cpu::x64::avx512_core)) {
return false;
}

return (dw_conv_input_size + dw_conv_output_size > L3_cache_size / 2);
};
Expand Down
13 changes: 13 additions & 0 deletions src/plugins/intel_cpu/src/memory_format_filter.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#pragma once

#include <oneapi/dnnl/dnnl.hpp>
#include <vector>

// Per-port memory-format constraints for a graph node.
//
// The filters are filled from the node's "cpu:"-prefixed format hints
// (see Node's constructor, which parses comma-separated tags via
// dnnl::utils::str2fmt) and are consumed by
// Node::filterSupportedPrimitiveDescriptors(), which erases any supported
// primitive descriptor whose i-th input/output memory descriptor is not
// compatible with the i-th entry here.
struct MemoryFormatFilter {
    std::vector<dnnl::memory::format_tag> input;   // filter for input port i, in port order
    std::vector<dnnl::memory::format_tag> output;  // filter for output port i, in port order

    // True when no format constraint was specified for any port,
    // i.e. descriptor filtering can be skipped entirely.
    [[nodiscard]] bool empty() const noexcept {
        return input.empty() && output.empty();
    }
};
22 changes: 11 additions & 11 deletions src/plugins/intel_cpu/src/node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ Node::Node(const std::shared_ptr<ov::Node>& op, GraphContext::CPtr ctx, const Sh
while (getline(stream, str, ',')) {
if (str.substr(0, 4) != "cpu:")
continue;
inputMemoryFormatsFilter.push_back(dnnl::utils::str2fmt(str.substr(4, str.size()).c_str()));
memoryFormatFilter.input.push_back(dnnl::utils::str2fmt(str.substr(4, str.size()).c_str()));
}
}

Expand All @@ -158,7 +158,7 @@ Node::Node(const std::shared_ptr<ov::Node>& op, GraphContext::CPtr ctx, const Sh
while (getline(stream, str, ',')) {
if (str.substr(0, 4) != "cpu:")
continue;
outputMemoryFormatsFilter.push_back(dnnl::utils::str2fmt(str.substr(4, str.size()).c_str()));
memoryFormatFilter.output.push_back(dnnl::utils::str2fmt(str.substr(4, str.size()).c_str()));
}
}

Expand Down Expand Up @@ -938,7 +938,7 @@ void Node::initSupportedPrimitiveDescriptors() {
}

void Node::filterSupportedPrimitiveDescriptors() {
if (inputMemoryFormatsFilter.empty() && outputMemoryFormatsFilter.empty())
if (memoryFormatFilter.empty())
return;

// Compare by format tag
Expand All @@ -950,26 +950,26 @@ void Node::filterSupportedPrimitiveDescriptors() {

auto isNotSuitableDesc = [&](const NodeDesc& desc) {
const auto& config = desc.getConfig();
if (inputMemoryFormatsFilter.size() > config.inConfs.size() ||
outputMemoryFormatsFilter.size() > config.outConfs.size())
if (memoryFormatFilter.input.size() > config.inConfs.size() ||
memoryFormatFilter.output.size() > config.outConfs.size())
OPENVINO_THROW("Incorrect number of input or output memory formats");

for (size_t i = 0; i < inputMemoryFormatsFilter.size(); i++) {
if (!areCompatible(*config.inConfs[i].getMemDesc(), inputMemoryFormatsFilter[i])) {
for (size_t i = 0; i < memoryFormatFilter.input.size(); i++) {
if (!areCompatible(*config.inConfs[i].getMemDesc(), memoryFormatFilter.input[i])) {
DEBUG_LOG(getName(),
" input memory format filter: ",
inputMemoryFormatsFilter[i],
memoryFormatFilter.input[i],
" not matched. Erase desc from supported primitive descriptors: ",
desc);
return true;
}
}

for (size_t i = 0; i < outputMemoryFormatsFilter.size(); i++) {
if (!areCompatible(*config.outConfs[i].getMemDesc(), outputMemoryFormatsFilter[i])) {
for (size_t i = 0; i < memoryFormatFilter.output.size(); i++) {
if (!areCompatible(*config.outConfs[i].getMemDesc(), memoryFormatFilter.output[i])) {
DEBUG_LOG(getName(),
" Output memory format filter: ",
outputMemoryFormatsFilter[i],
memoryFormatFilter.output[i],
" not matched. Erase desc from supported primitive descriptors: ",
desc);
return true;
Expand Down
4 changes: 2 additions & 2 deletions src/plugins/intel_cpu/src/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "graph_context.h"
#include "memory_desc/cpu_memory_desc.h"
#include "memory_desc/dnnl_memory_desc.h"
#include "memory_format_filter.hpp"
#include "nodes/executors/executor.hpp"
#include "nodes/node_config.h"
#include "onednn/dnnl.h"
Expand Down Expand Up @@ -713,8 +714,7 @@ class Node {

std::string primitivesPriority;
std::vector<impl_desc_type> customImplPriorities;
std::vector<dnnl::memory::format_tag> inputMemoryFormatsFilter;
std::vector<dnnl::memory::format_tag> outputMemoryFormatsFilter;
MemoryFormatFilter memoryFormatFilter;
bool enforceBF16evenForGraphTail = false;
bool keepOriginalPrecision = false;

Expand Down
Loading

0 comments on commit c0562d9

Please sign in to comment.