diff --git a/src/plugins/intel_cpu/src/nodes/eltwise.cpp b/src/plugins/intel_cpu/src/nodes/eltwise.cpp index c13f22b0d9b76a..9e6b8ab890adda 100644 --- a/src/plugins/intel_cpu/src/nodes/eltwise.cpp +++ b/src/plugins/intel_cpu/src/nodes/eltwise.cpp @@ -2895,15 +2895,9 @@ void Eltwise::prepareParams() { // FP32 constant inputs may contain values out of BF16 representable range. In case output precision is BF16 we // choose "saturation" mode for fp32->bf16 conversion procedure to prevent getting -Inf/+Inf values in the - // outputs. Since "saturation" conversion is more time consuming, better solution would be to clamp constants on - // compilation stage (ticket: 159589). + // outputs. Since "saturation" conversion during kernel runtime is more time consuming, current solution is + // clamp constants on compilation stage. key.doOutputSaturation = false; - for (size_t i = 0; i < getParentEdges().size(); i++) { - if (getParentEdgeAt(i)->getParent()->isConstant()) { - key.doOutputSaturation = true; - break; - } - } auto cache = context->getParamsCache(); auto result = cache->getOrCreate(key, buildExecutor); diff --git a/src/plugins/intel_cpu/src/nodes/input.cpp b/src/plugins/intel_cpu/src/nodes/input.cpp index ed6c8c570f6c15..efd8729ee22c7b 100644 --- a/src/plugins/intel_cpu/src/nodes/input.cpp +++ b/src/plugins/intel_cpu/src/nodes/input.cpp @@ -23,18 +23,18 @@ namespace node { #if defined(OPENVINO_ARCH_X86_64) namespace { -struct jit_has_subnormals_base : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_has_subnormals_base) +struct jit_subnormals_bf16saturation_check_base : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_subnormals_bf16saturation_check_base) typedef struct { const float* src; const size_t count; - bool hasSubnormals; + bool hasTargetValues; } args_t; typedef void (*fn_t)(const args_t*); - jit_has_subnormals_base() : jit_generator(jit_name()) { + jit_subnormals_bf16saturation_check_base() : jit_generator(jit_name()) { jit_ker_ = nullptr; } @@ -110,8 +110,35 @@ struct jit_has_subnormals_base : public jit_generator { uni_vtestps(b, c); // if ((!b & c) == 0) CF = 1 else CF = 0 } + void check_bf16_saturations(const Xbyak::Reg64& src, + const Xbyak::Ymm& bf16_max_mask, + const Xbyak::Ymm& bf16_min_mask) { + auto a = ymm1; + auto b = ymm2; + auto c = ymm3; + vmovdqu(a, yword[src]); // load 8 floats + vcmpps(b, a, bf16_max_mask, 0x1e); // b = (a > bf16_max) ? 1 : 0 + vcmpps(c, a, bf16_min_mask, 0x11); // c = (a < bf16_min) ? 1 : 0 + vorps(b, b, c); // b = b | c + vptest(b, b); // if (b != 0) CF = 1 else CF = 0 + } + + void check_bf16_saturations(const Xbyak::Reg64& src, + const Xbyak::Xmm& bf16_max_mask, + const Xbyak::Xmm& bf16_min_mask) { + auto a = xmm1; + auto b = xmm2; + auto c = xmm3; + + uni_vmovdqu(a, xword[src]); // load 4 floats + uni_vcmpps(b, a, bf16_max_mask, 0x1e); // b = (a > bf16_max) ? 1 : 0 + uni_vcmpps(c, a, bf16_max_mask, 0x11); // c = (a < bf16_min) ? 1 : 0 + uni_vorps(b, b, c); // b = b | c + uni_vtestps(b, b); // if (b != 0) CF = 1 else CF = 0 + } + protected: - Label exit, has_subnormals, no_subnormals; + Label exit, has_target_values, no_target_values; const Reg64& reg_src = rax; const Reg64& reg_dst = rbx; @@ -121,16 +148,35 @@ struct jit_has_subnormals_base : public jit_generator { static const uint32_t exponent_mask_data[8]; static const uint32_t mantissa_mask_data[8]; + static const float bf16_max_mask_data[8]; + static const float bf16_min_mask_data[8]; }; -const uint32_t jit_has_subnormals_base::exponent_mask_data[8] = +const uint32_t jit_subnormals_bf16saturation_check_base::exponent_mask_data[8] = {0x7f800000, 0x7f800000, 0x7f800000, 0x7f800000, 0x7f800000, 0x7f800000, 0x7f800000, 0x7f800000}; -const uint32_t jit_has_subnormals_base::mantissa_mask_data[8] = +const uint32_t jit_subnormals_bf16saturation_check_base::mantissa_mask_data[8] = {0x007fffff, 0x007fffff, 0x007fffff, 0x007fffff, 0x007fffff, 0x007fffff, 0x007fffff, 0x007fffff}; +const float jit_subnormals_bf16saturation_check_base::bf16_max_mask_data[8] = {3.38953139e+38f, + 3.38953139e+38f, + 3.38953139e+38f, + 3.38953139e+38f, + 3.38953139e+38f, + 3.38953139e+38f, + 3.38953139e+38f, + 3.38953139e+38f}; + +const float jit_subnormals_bf16saturation_check_base::bf16_min_mask_data[8] = {-3.38953139e+38f, + -3.38953139e+38f, + -3.38953139e+38f, + -3.38953139e+38f, + -3.38953139e+38f, + -3.38953139e+38f, + -3.38953139e+38f, + -3.38953139e+38f}; template -struct jit_has_subnormals : public jit_has_subnormals_base { +struct jit_has_subnormals : public jit_subnormals_bf16saturation_check_base { using Vmm = typename dnnl::impl::utils::conditional::type; const Vmm rmm4 = Vmm(4); @@ -150,7 +196,7 @@ struct jit_has_subnormals : public jit_has_subnormals_base { // Get arguments addresses mov(reg_src, ptr[param1 + offsetof(args_t, src)]); - lea(reg_dst, ptr[param1 + offsetof(args_t, hasSubnormals)]); + lea(reg_dst, ptr[param1 + offsetof(args_t, hasTargetValues)]); mov(reg_sz, ptr[param1 + offsetof(args_t, count)]); // Initialize necessary consts @@ -167,7 +213,7 @@ struct jit_has_subnormals : public jit_has_subnormals_base { foreach (reg_idx, 1, r8, [&, this](const Xbyak::Reg64& idx) { check_subnormals(reg_src, exponent_mask, mantissa_mask, zero); - jnc(has_subnormals); + jnc(has_target_values); add(reg_src, sizeof(float) * vlen); }) ; @@ -186,16 +232,16 @@ struct jit_has_subnormals : public jit_has_subnormals_base { copy_floats(r8, reg_src, reg_sz); check_subnormals(r8, exponent_mask, mantissa_mask, zero); - jc(no_subnormals); + jc(no_target_values); add(rsp, vlen * sizeof(float)); - L(has_subnormals); + L(has_target_values); mov(rax, 1); mov(byte[reg_dst], al); jmp(exit); - L(no_subnormals); + L(no_target_values); add(rsp, vlen * sizeof(float)); L(exit); @@ -203,8 +249,81 @@ struct jit_has_subnormals : public jit_has_subnormals_base { postamble(); } }; +template +struct jit_has_bf16_overflows : public jit_subnormals_bf16saturation_check_base { + using Vmm = typename dnnl::impl::utils::conditional::type; + + const Vmm rmm4 = Vmm(4); + const Vmm rmm5 = Vmm(5); + const Vmm rmm6 = Vmm(6); + const int length = isa == sse41 ? 4 : 8; + + void generate() override final { // NOLINT + size_t const vlen = length; + const int sh_bits = std::ilogb(vlen); + + auto zero = rmm4; + auto bf16_max_mask = rmm5; + auto bf16_min_mask = rmm6; + + preamble(); + + // Get arguments addresses + mov(reg_src, ptr[param1 + offsetof(args_t, src)]); + lea(reg_dst, ptr[param1 + offsetof(args_t, hasTargetValues)]); + mov(reg_sz, ptr[param1 + offsetof(args_t, count)]); + + // Initialize necessary consts + uni_vpxor(zero, zero, zero); + mov(reg_mask_addr, (size_t)bf16_max_mask_data); + uni_vmovdqu(bf16_max_mask, ptr[reg_mask_addr]); + mov(reg_mask_addr, (size_t)bf16_min_mask_data); + uni_vmovdqu(bf16_min_mask, ptr[reg_mask_addr]); + + // Main loop + xor_(reg_idx, reg_idx); + mov(r8, reg_sz); + shr(r8, sh_bits); + + foreach (reg_idx, 1, r8, [&, this](const Xbyak::Reg64& idx) { + check_bf16_saturations(reg_src, bf16_max_mask, bf16_min_mask); + jnz(has_target_values, T_NEAR); + add(reg_src, sizeof(float) * vlen); + }) + ; + + // Tail + shl(reg_idx, sh_bits); + sub(reg_sz, reg_idx); + test(reg_sz, reg_sz); + jz(exit); -jit_has_subnormals_base::fn_t jit_has_subnormals_function() { + // use space on stack for 4 or 8 floats + sub(rsp, vlen * sizeof(float)); + mov(r8, rsp); + + uni_vmovdqu(ptr[r8], zero); + + copy_floats(r8, reg_src, reg_sz); + check_bf16_saturations(r8, bf16_max_mask, bf16_min_mask); + jz(no_target_values, T_NEAR); + add(rsp, vlen * sizeof(float)); + + L(has_target_values); + + mov(rax, 1); + mov(byte[reg_dst], al); + jmp(exit); + + L(no_target_values); + add(rsp, vlen * sizeof(float)); + + L(exit); + + postamble(); + } +}; +jit_subnormals_bf16saturation_check_base::fn_t jit_has_subnormals_function() { if (mayiuse(cpu_isa_t::avx2)) { static jit_has_subnormals generator; static auto fn = generator.get(); @@ -216,6 +335,18 @@ jit_has_subnormals_base::fn_t jit_has_subnormals_function() { } return nullptr; } +jit_subnormals_bf16saturation_check_base::fn_t jit_has_bf16_overflows_function() { + if (mayiuse(cpu_isa_t::avx2)) { + static jit_has_bf16_overflows generator; + static auto fn = generator.get(); + return fn; + } else if (mayiuse(cpu_isa_t::sse41)) { + static jit_has_bf16_overflows generator; + static auto fn = generator.get(); + return fn; + } + return nullptr; +} } // namespace #endif @@ -271,49 +402,69 @@ void Input::cloneBlobIfRequired() { if (!size) return; - const float bf16_max = 3.3895313899137927e38f; + const bool do_bf16_saturation_check = + (context->getConfig().inferencePrecision == ov::element::bf16) ? true : false; #if defined(OPENVINO_ARCH_X86_64) - if (auto fn = jit_has_subnormals_function()) { + auto fn = jit_has_subnormals_function(); + auto fn_bf16_check = jit_has_bf16_overflows_function(); + if (fn && fn_bf16_check) { static const size_t batch_size = 2048; const size_t iterations_num = size / batch_size + 1; volatile bool has_subnormals_local = false; + volatile bool has_bf16_overflows_local = false; parallel_for(iterations_num, [&](int n) { auto ptr = u32data + n * batch_size; - const jit_has_subnormals_base::args_t args = {reinterpret_cast(ptr), - std::min(batch_size, (size_t)(u32data + size - ptr)), - false}; + const jit_subnormals_bf16saturation_check_base::args_t args1 = { + reinterpret_cast(ptr), + std::min(batch_size, (size_t)(u32data + size - ptr)), + false}; - fn(&args); + fn(&args1); - if (args.hasSubnormals) + if (args1.hasTargetValues) has_subnormals_local = true; }); - has_subnormals = has_subnormals_local; - //TODO: opt with jit - for (size_t i = 0; i < size; ++i) { - if (f32data[i] < -bf16_max || f32data[i] > bf16_max) { - has_bf16_overflows = true; - return; - } + if (do_bf16_saturation_check) { + parallel_for(iterations_num, [&](int n) { + auto ptr2 = f32data + n * batch_size; + const jit_subnormals_bf16saturation_check_base::args_t args2 = { + reinterpret_cast(ptr2), + std::min(batch_size, (size_t)(f32data + size - ptr2)), + false}; + + fn_bf16_check(&args2); + + if (args2.hasTargetValues) + has_bf16_overflows_local = true; + }); } + + has_subnormals = has_subnormals_local; + has_bf16_overflows = has_bf16_overflows_local; + return; } #endif uint32_t mantissaMask = 0x007fffff; uint32_t exponentMask = 0x7f800000; + const float bf16_max = 3.3895313899137927e38f; for (size_t i = 0; i < size; ++i) { if ((u32data[i] & exponentMask) == 0 && (u32data[i] & mantissaMask) != 0) { has_subnormals = true; } - if (f32data[i] < -bf16_max || f32data[i] > bf16_max) { - has_bf16_overflows = true; - } - if (has_subnormals && has_bf16_overflows) { + if (do_bf16_saturation_check) { + if (f32data[i] < -bf16_max || f32data[i] > bf16_max) { + has_bf16_overflows = true; + } + if (has_subnormals && has_bf16_overflows) { + return; + } + } else if (has_subnormals) { return; } }