Skip to content

Commit

Permalink
[CPU][ARM64] Add a JIT emitter for SoftPlus operation
Browse files Browse the repository at this point in the history
- Added a jit_sqrt_emitter derived class in
  aarch64/jit_eltwise_emitters
- Created entry Algorithm::EltwiseSqrt in the
  get_supported_precisions in nodes/kernels/aarch64
- Add the EltwiseSqrt entry in the aarch64 executors
  supported algorithms
- Add the ActivationType::Sqrt in the getPrimitiveType
  in activations

Closes: #24109

Signed-off-by: Nashez Zubair <[email protected]>
  • Loading branch information
nashez committed Nov 24, 2024
1 parent 287ab98 commit c812fc3
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2024,6 +2024,75 @@ std::set<std::vector<element::Type>> jit_sigmoid_emitter::get_supported_precisio
return {{element::f32}};
}

/// SOFT_PLUS ///
jit_soft_plus_emitter::jit_soft_plus_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const std::shared_ptr<ov::Node>& node)
: jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) {
prepare_table();
exp_emitter = std::make_unique<jit_exp_emitter>(h, host_isa, exec_prc);
}

jit_soft_plus_emitter::jit_soft_plus_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const ov::element::Type exec_prc) : jit_emitter(host, host_isa, exec_prc) {
prepare_table();
exp_emitter = std::make_unique<jit_exp_emitter>(h, host_isa, exec_prc);
}

size_t jit_soft_plus_emitter::get_inputs_count() const { return 1; }

size_t jit_soft_plus_emitter::get_aux_vecs_count() const { return exp_emitter->get_aux_vecs_count() + 2; }

size_t jit_soft_plus_emitter::get_aux_gprs_count() const { return exp_emitter->get_aux_gprs_count() + 1; }

void jit_soft_plus_emitter::emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) {
emit_isa<dnnl::impl::cpu::aarch64::asimd>(in_vec_idxs, out_vec_idxs);
} else {
OPENVINO_THROW("Can't create jit eltwise kernel");
}
}

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void jit_soft_plus_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
if (exec_prc_ != ov::element::f32) {
OPENVINO_THROW("unsupported precision: " + exec_prc_.to_string());
}

using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
const TReg src(in_vec_idxs[0]);
const TReg dst(out_vec_idxs[0]);
const TReg aux1(aux_vec_idxs[exp_emitter->get_aux_vecs_count()]);
const TReg aux2(aux_vec_idxs[exp_emitter->get_aux_vecs_count() + 1]);

exp_emitter->emit_code(
{ src.getIdx() },
out_vec_idxs,
aux_vec_idxs,
aux_gpr_idxs);
h->ld1r(aux1.s, table_val2("one"));
h->fadd(dst.s, dst.s, aux1.s);
h->fcvtzs(aux2.s, dst.s);
h->cls(aux1.s, aux2.s);
h->ld1r(aux2.s, table_val("bit_count"));
h->fsub(aux1.s, aux2.s, aux1.s);
// aux1.s contains nearest power of 2 for e^x + 1
h->ld1r(aux2.s, table_val("ln2f"));
h->fmul(aux2.s, aux1.s, aux2.s); // Computed n*ln2 in aux2.s
h->fsub(dst.s, dst.s);
}

void jit_soft_plus_emitter::register_table_entries() {
push_arg_entry_of("one", 0x3f800000, true);
push_arg_entry_of("threshold", 0x41a00000, true); // Threshold set to 20
push_arg_entry_of("ln2f", 0x3f317218, true); // Natural log of 2
}

std::set<std::vector<element::Type>> jit_soft_plus_emitter::get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
return {{element::f32}};
}

/// SOFT_SIGN ///
jit_soft_sign_emitter::jit_soft_sign_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -837,6 +837,35 @@ class jit_sigmoid_emitter : public jit_emitter {
void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
};

class jit_soft_plus_emitter : public jit_emitter {
public:
jit_soft_plus_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const ov::element::Type exec_prc = ov::element::f32);

jit_soft_plus_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const std::shared_ptr<ov::Node>& node);

size_t get_inputs_count() const override;

size_t get_aux_vecs_count() const override;

size_t get_aux_gprs_count() const override;

void register_table_entries() override;

static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ov::Node>& node = nullptr);

private:
std::unique_ptr<jit_exp_emitter> exp_emitter;

void emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const override;

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
};

class jit_soft_sign_emitter : public jit_emitter {
public:
jit_soft_sign_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ bool JitEltwiseExecutor::isSupported(
Algorithm::EltwiseRelu,
Algorithm::EltwiseSelect,
Algorithm::EltwiseSigmoid,
Algorithm::EltwiseSoftPlus,
Algorithm::EltwiseSoftSign,
Algorithm::EltwiseSqrt,
Algorithm::EltwiseSubtract,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,7 @@ std::shared_ptr<jit_emitter> jit_uni_eltwise_generic<isa>::create_eltwise_emitte
OV_CASE(Algorithm::EltwiseRelu, ov::intel_cpu::aarch64::jit_relu_emitter),
OV_CASE(Algorithm::EltwiseSelect, ov::intel_cpu::aarch64::jit_select_emitter),
OV_CASE(Algorithm::EltwiseSigmoid, ov::intel_cpu::aarch64::jit_sigmoid_emitter),
OV_CASE(Algorithm::EltwiseSoftPlus, ov::intel_cpu::aarch64::jit_soft_plus_emitter),
OV_CASE(Algorithm::EltwiseSoftSign, ov::intel_cpu::aarch64::jit_soft_sign_emitter),
OV_CASE(Algorithm::EltwiseSqrt, ov::intel_cpu::aarch64::jit_sqrt_emitter),
OV_CASE(Algorithm::EltwiseSubtract, ov::intel_cpu::aarch64::jit_subtract_emitter),
Expand Down Expand Up @@ -851,6 +852,7 @@ std::set<std::vector<element::Type>> eltwise_precision_helper::get_supported_pre
OV_CASE(Algorithm::EltwisePowerStatic, jit_power_static_emitter),
OV_CASE(Algorithm::EltwiseSelect, jit_select_emitter),
OV_CASE(Algorithm::EltwiseSigmoid, jit_sigmoid_emitter),
OV_CASE(Algorithm::EltwiseSoftPlus, jit_soft_plus_emitter),
OV_CASE(Algorithm::EltwiseSoftSign, jit_soft_sign_emitter),
OV_CASE(Algorithm::EltwiseSqrt, jit_sqrt_emitter),
OV_CASE(Algorithm::EltwiseSubtract, jit_subtract_emitter),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ std::string ActivationLayerCPUTest::getPrimitiveType(const utils::ActivationType
(activation_type == utils::ActivationTypes::GeluTanh) ||
(activation_type == utils::ActivationTypes::Relu) ||
(activation_type == utils::ActivationTypes::Sigmoid) ||
(activation_type == utils::ActivationTypes::SoftPlus) ||
(activation_type == utils::ActivationTypes::SoftSign) ||
(activation_type == utils::ActivationTypes::Sqrt) ||
(activation_type == utils::ActivationTypes::Swish) ||
Expand Down

0 comments on commit c812fc3

Please sign in to comment.