diff --git a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp index bf02537e5f191a..36a395313079a3 100644 --- a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp +++ b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp @@ -753,7 +753,6 @@ void ScaledDotProductAttention::execute(dnnl::stream strm) { } bool ScaledDotProductAttention::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept { -#if defined(OPENVINO_ARCH_X86_64) try { if (!std::dynamic_pointer_cast<const ov::op::v13::ScaledDotProductAttention>(op)) { errorMessage = "Only ScaledDotProductAttention operation are supported"; @@ -774,10 +773,6 @@ bool ScaledDotProductAttention::isSupportedOperation(const std::shared_ptr& defaultPrecis CPU_SET_CALLBACK_COMMON(manager, nmsCallback, ov::pass::ConvertNMS9ToNMSIEInternal); CPU_SET_CALLBACK_COMMON(manager, nmsCallback, ov::pass::ConvertMulticlassNmsToMulticlassNmsIE); CPU_SET_CALLBACK_COMMON(manager, nmsCallback, ov::pass::ConvertMatrixNmsToMatrixNmsIE); - CPU_SET_CALLBACK_COMMON(manager, + CPU_SET_CALLBACK_X64(manager, [](const_node_ptr &node) -> bool { std::string errorMsg; return node::ScaledDotProductAttention::isSupportedOperation(node, errorMsg); @@ -470,7 +470,6 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis CPU_DISABLE_PASS_COMMON(manager, ov::pass::ConvertTopK11ToTopK3); CPU_DISABLE_PASS_COMMON(manager, ov::pass::HSwishDecomposition); CPU_DISABLE_PASS_COMMON(manager, ov::pass::MatMulConstTransposesExtraction); - CPU_DISABLE_PASS_X64(manager, ov::pass::ScaledDotProductAttentionDecomposition); CPU_DISABLE_PASS_X64(manager, ov::pass::HSigmoidDecomposition); CPU_DISABLE_PASS_X64(manager, ov::pass::ReduceL1Decomposition); diff --git a/src/plugins/intel_cpu/src/utils/plain_tensor.hpp b/src/plugins/intel_cpu/src/utils/plain_tensor.hpp index 0e7b51833ccf3a..c98b16a4067684 100644 --- a/src/plugins/intel_cpu/src/utils/plain_tensor.hpp +++ b/src/plugins/intel_cpu/src/utils/plain_tensor.hpp @@ -109,13 +109,16 @@ struct 
precision_of { }; #define PLAINTENSOR_RANK_MAX 8 -struct PlainTensorBase { + +template <typename DT> +struct PlainTensor { size_t m_strides[PLAINTENSOR_RANK_MAX]; size_t m_dims[PLAINTENSOR_RANK_MAX]; size_t m_rank = 0; void* m_ptr = nullptr; size_t m_capacity = 0; bool with_storage = false; + MemoryPtr m_mem; // hold memory ptr reference operator bool() const { return static_cast<bool>(m_ptr); @@ -135,13 +138,6 @@ struct PlainTensorBase { assert(i < m_rank); return m_strides[i]; } - virtual ov::element::Type get_precision(void) = 0; - virtual void reset(MemoryPtr mem) = 0; -}; - -template <typename DT> -struct PlainTensor : public PlainTensorBase { - MemoryPtr m_mem; // hold memory ptr reference PlainTensor(MemoryPtr mem) { reset(mem); } @@ -172,14 +168,14 @@ struct PlainTensor : public PlainTensorBase { } - void reset(MemoryPtr mem) override { + void reset(MemoryPtr mem) { assert_dt
<DT>(mem->getDesc().getPrecision()); m_mem = mem; // this reshape_to() can do reshape w/o additional cost resize(mem->getStaticDims(), reinterpret_cast<DT*>(mem->getData())); } - ov::element::Type get_precision(void) override { + ov::element::Type get_precision(void) { return precision_of
<DT>::value; }