Skip to content

Commit

Permalink
[CPU] Fix ScaledDotProductAttention build failure on ubuntu18 (openvi…
Browse files Browse the repository at this point in the history
  • Loading branch information
luo-cheng2021 authored Nov 17, 2023
1 parent 6bdc159 commit e4311aa
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 17 deletions.
5 changes: 0 additions & 5 deletions src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,6 @@ void ScaledDotProductAttention::execute(dnnl::stream strm) {
}

bool ScaledDotProductAttention::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept {
#if defined(OPENVINO_ARCH_X86_64)
try {
if (!std::dynamic_pointer_cast<const ov::op::v13::ScaledDotProductAttention>(op)) {
errorMessage = "Only ScaledDotProductAttention operation are supported";
Expand All @@ -774,10 +773,6 @@ bool ScaledDotProductAttention::isSupportedOperation(const std::shared_ptr<const
return false;
}
return true;
#else
// current optimization is not suitable for ARM
return false;
#endif
}

} // namespace node
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
CPU_SET_CALLBACK_COMMON(manager, nmsCallback, ov::pass::ConvertNMS9ToNMSIEInternal);
CPU_SET_CALLBACK_COMMON(manager, nmsCallback, ov::pass::ConvertMulticlassNmsToMulticlassNmsIE);
CPU_SET_CALLBACK_COMMON(manager, nmsCallback, ov::pass::ConvertMatrixNmsToMatrixNmsIE);
CPU_SET_CALLBACK_COMMON(manager,
CPU_SET_CALLBACK_X64(manager,
[](const_node_ptr &node) -> bool {
std::string errorMsg;
return node::ScaledDotProductAttention::isSupportedOperation(node, errorMsg);
Expand Down Expand Up @@ -470,7 +470,6 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
CPU_DISABLE_PASS_COMMON(manager, ov::pass::ConvertTopK11ToTopK3);
CPU_DISABLE_PASS_COMMON(manager, ov::pass::HSwishDecomposition);
CPU_DISABLE_PASS_COMMON(manager, ov::pass::MatMulConstTransposesExtraction);
CPU_DISABLE_PASS_X64(manager, ov::pass::ScaledDotProductAttentionDecomposition);
CPU_DISABLE_PASS_X64(manager, ov::pass::HSigmoidDecomposition);

CPU_DISABLE_PASS_X64(manager, ov::pass::ReduceL1Decomposition);
Expand Down
16 changes: 6 additions & 10 deletions src/plugins/intel_cpu/src/utils/plain_tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,16 @@ struct precision_of<float16> {
};

#define PLAINTENSOR_RANK_MAX 8
struct PlainTensorBase {

template <typename DT>
struct PlainTensor {
size_t m_strides[PLAINTENSOR_RANK_MAX];
size_t m_dims[PLAINTENSOR_RANK_MAX];
size_t m_rank = 0;
void* m_ptr = nullptr;
size_t m_capacity = 0;
bool with_storage = false;
MemoryPtr m_mem; // hold memory ptr reference

operator bool() const {
return static_cast<bool>(m_ptr);
Expand All @@ -135,13 +138,6 @@ struct PlainTensorBase {
assert(i < m_rank);
return m_strides[i];
}
virtual ov::element::Type get_precision(void) = 0;
virtual void reset(MemoryPtr mem) = 0;
};

template <typename DT>
struct PlainTensor : public PlainTensorBase {
MemoryPtr m_mem; // hold memory ptr reference
PlainTensor(MemoryPtr mem) {
reset(mem);
}
Expand Down Expand Up @@ -172,14 +168,14 @@ struct PlainTensor : public PlainTensorBase {
}
}

void reset(MemoryPtr mem) override {
void reset(MemoryPtr mem) {
assert_dt<DT>(mem->getDesc().getPrecision());
m_mem = mem;
// this reshape_to() can do reshape w/o additional cost
resize(mem->getStaticDims(), reinterpret_cast<DT*>(mem->getData()));
}

ov::element::Type get_precision(void) override {
ov::element::Type get_precision(void) {
return precision_of<DT>::value;
}

Expand Down

0 comments on commit e4311aa

Please sign in to comment.