Skip to content

Commit

Permalink
[CPU] Implemented "jit_exp_emitter" (openvinotoolkit#26974)
Browse files Browse the repository at this point in the history
### Details:
- *Previously, we used the dnnl injector for the `Exp` op, which requires 2
`aux_vec_regs`. The Snippets kernel has a pool of aux vec registers
that emitters can use in their implementations. However, the dnnl injector
cannot work with user-provided aux registers and always spills them to the
stack, whereas plugin emitters can use the provided ones. To avoid the extra
push-pop in the Snippets kernel (which leads to performance degradation), we
implemented our own emitter for `Exp` with the same logic so that free aux
vec registers can be passed to it.*
- *Updated `jit_erf_emitter`: it now reuses the new `jit_exp_emitter` to compute
the exponent and works only with `vmm_dst` to avoid corrupting `vmm_src`
data (input registers must not be corrupted).*

### Tickets:
 - *155236*
  • Loading branch information
a-sidorova authored Oct 22, 2024
1 parent d6dc495 commit e3ad821
Show file tree
Hide file tree
Showing 4 changed files with 175 additions and 124 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,19 +64,6 @@ class jit_elu_emitter : public jit_dnnl_emitter {
}
};

/// Exp emitter layered on jit_dnnl_emitter: it selects the eltwise kind
/// `dnnl_eltwise_exp`, zeroes both coefficients and then sets up the injector.
class jit_exp_emitter : public jit_dnnl_emitter {
public:
    jit_exp_emitter(dnnl::impl::cpu::x64::jit_generator *host,
                    dnnl::impl::cpu::x64::cpu_isa_t host_isa,
                    const std::shared_ptr<ov::Node>& n,
                    ov::element::Type exec_prc = ov::element::f32)
        : jit_dnnl_emitter(host, host_isa, n, exec_prc) {
        // Both coefficients are zero for this kind.
        beta = 0.f;
        alpha = 0.f;
        kind = dnnl_eltwise_exp;

        // Invoked last, once kind/alpha/beta have been assigned.
        set_injector();
    }
};

class jit_abs_emitter : public jit_dnnl_emitter {
public:
jit_abs_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr<ov::Node>& n,
Expand Down
252 changes: 145 additions & 107 deletions src/plugins/intel_cpu/src/emitters/plugin/x64/jit_eltwise_emitters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1822,29 +1822,25 @@ void jit_negative_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, cons
h->uni_vsubps(vmm_dst, vmm_dst, vmm_src);
}

/// ERF ///
jit_erf_emitter::jit_erf_emitter(x64::jit_generator* host, x64::cpu_isa_t host_isa, ov::element::Type exec_prc)

/// EXP ///
// Node-less constructor: initializes the jit_emitter base and then prepares
// this emitter's constant table.
jit_exp_emitter::jit_exp_emitter(x64::jit_generator* host,
                                 x64::cpu_isa_t host_isa,
                                 ov::element::Type exec_prc)
    : jit_emitter(host, host_isa, exec_prc) {
    prepare_table();
}

jit_erf_emitter::jit_erf_emitter(x64::jit_generator* host,
x64::cpu_isa_t host_isa,
const std::shared_ptr<ov::Node>& node,
ov::element::Type exec_prc)
// Node-based overload. `node` is not consulted here, so construction is
// delegated to the node-less constructor (which initializes the base and
// prepares the table) instead of repeating its body; jit_erf_emitter's
// node-based constructor delegates in the same way.
jit_exp_emitter::jit_exp_emitter(x64::jit_generator* host,
                                 x64::cpu_isa_t host_isa,
                                 const std::shared_ptr<ov::Node>& node,
                                 ov::element::Type exec_prc)
    : jit_exp_emitter(host, host_isa, exec_prc) {}

size_t jit_erf_emitter::get_inputs_num() const { return 1; }
// The emitter consumes exactly one input.
size_t jit_exp_emitter::get_inputs_num() const {
    return 1;
}

std::set<std::vector<element::Type>> jit_erf_emitter::get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
// Reports the supported input precision combinations: a single combination
// holding only f32. The `node` argument is not used.
std::set<std::vector<element::Type>> jit_exp_emitter::get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
    const std::vector<element::Type> f32_only{element::f32};
    return {f32_only};
}

void jit_erf_emitter::emit_impl(
const std::vector<size_t> &in_vec_idxs,
const std::vector<size_t> &out_vec_idxs) const {
void jit_exp_emitter::emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
if (host_isa_ == x64::sse41) {
emit_isa<x64::sse41>(in_vec_idxs, out_vec_idxs);
} else if (host_isa_ == x64::avx2) {
Expand All @@ -1857,20 +1853,16 @@ void jit_erf_emitter::emit_impl(
}

template <x64::cpu_isa_t isa>
void jit_erf_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
void jit_exp_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
using Vmm = typename conditional3<isa == x64::sse41, Xmm, isa == x64::avx2, Ymm, Zmm>::type;
Vmm vmm_src = Vmm(in_vec_idxs[0]);
Vmm vmm_dst = Vmm(out_vec_idxs[0]);

Vmm vmm_mask = Vmm(aux_vec_idxs[0]);
Vmm vmm_aux0 = Vmm(aux_vec_idxs[0]);
Vmm vmm_aux1 = Vmm(aux_vec_idxs[1]);
Vmm vmm_aux2 = Vmm(aux_vec_idxs[2]);
Vmm vmm_aux3 = Vmm(aux_vec_idxs[3]);
Vmm vmm_aux4 = Vmm(aux_vec_idxs[4]);
Vmm vmm_mask = need_vmm_mask() ? Vmm(aux_vec_idxs[0]) : Vmm();
Vmm vmm_aux0 = Vmm(aux_vec_idxs[0 + static_cast<size_t>(need_vmm_mask())]);
Vmm vmm_aux1 = Vmm(aux_vec_idxs[1 + static_cast<size_t>(need_vmm_mask())]);

auto compute_cmp_mask = [&](const Vmm &vmm_src,
const Xbyak::Operand &compare_operand, int cmp_predicate) {
auto compute_cmp_mask = [&](const Vmm &vmm_src, const Xbyak::Operand &compare_operand, int cmp_predicate) {
if (host_isa_ == x64::avx512_core) {
h->vcmpps(k_mask, vmm_src, compare_operand, cmp_predicate);
} else {
Expand All @@ -1886,66 +1878,123 @@ void jit_erf_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std
}
};

auto exp_compute_vector_fwd = [&](const Vmm &vmm_src) {
// get mask of values lower than log(FLT_MIN) to zero them in the output
compute_cmp_mask(vmm_src, table_val("exp_ln_flt_min_f"), _cmp_lt_os);

h->uni_vminps(vmm_src, vmm_src, table_val("exp_ln_flt_max_f"));
h->uni_vmaxps(vmm_src, vmm_src, table_val("exp_ln_flt_min_f"));
h->uni_vmovups(vmm_aux1, vmm_src);

// calculate exp(x)
// fx = x * log2ef + 0.5
h->uni_vmulps(vmm_src, vmm_src, table_val("exp_log2ef"));
h->uni_vaddps(vmm_src, vmm_src, table_val("half"));

// tmp = floorf(fx)
const auto _op_floor = 1u;
h->uni_vroundps(vmm_aux2, vmm_src, _op_floor);

// keep vmm_src = fx for further computations
h->uni_vmovups(vmm_src, vmm_aux2);

// x = x - fx * ln2
h->uni_vfnmadd231ps(vmm_aux1, vmm_aux2, table_val("ln2f"));

// compute 2^n
h->uni_vcvtps2dq(vmm_aux2, vmm_src);
h->uni_vpaddd(vmm_aux2, vmm_aux2, table_val("exponent_bias"));
const int n_mantissa_bits = 23;
h->uni_vpslld(vmm_aux2, vmm_aux2, n_mantissa_bits); //Vmm(6) = 2^-fx

// use vmm_src as tmp vmm_zero when applying mask
h->uni_vpxor(vmm_src, vmm_src, vmm_src);
// set zeroes at those points which were < log(FLT_MIN)
blend_with_mask(vmm_aux2, vmm_src);

// compute polynomial
h->uni_vmovups(vmm_src, table_val("ex_pol5"));
h->uni_vfmadd213ps(vmm_src, vmm_aux1, table_val("ex_pol4"));
h->uni_vfmadd213ps(vmm_src, vmm_aux1, table_val("ex_pol3"));
h->uni_vfmadd213ps(vmm_src, vmm_aux1, table_val("ex_pol2"));
h->uni_vfmadd213ps(vmm_src, vmm_aux1, table_val("ex_pol1"));
h->uni_vfmadd213ps(vmm_src, vmm_aux1, table_val("one"));
// y = y * 2^n
h->uni_vmulps(vmm_src, vmm_src, vmm_aux2);
};
h->uni_vmovups(vmm_aux1, table_val("ln_flt_min_f"));
// get mask of values lower than log(FLT_MIN) to zero them in the output
compute_cmp_mask(vmm_src, vmm_aux1, _cmp_lt_os);

auto abs_compute_vector_fwd = [&](const Vmm &vmm_src) {
// compute abs(x) = _mm_and_ps(x, 01111..111));
h->uni_vandps(vmm_src, vmm_src, table_val("positive_mask"));
};
h->uni_vminps(vmm_dst, vmm_src, table_val("ln_flt_max_f"));
h->uni_vmaxps(vmm_dst, vmm_dst, vmm_aux1);
h->uni_vmovups(vmm_aux0, vmm_dst);

// calculate exp(x)
// fx = x * log2ef + 0.5
h->uni_vmulps(vmm_dst, vmm_dst, table_val("log2ef"));
h->uni_vaddps(vmm_dst, vmm_dst, table_val("half"));

// tmp = floorf(fx)
const auto _op_floor = 1u;
h->uni_vroundps(vmm_aux1, vmm_dst, _op_floor);

// keep vmm_dst = fx for further computations
h->uni_vmovups(vmm_dst, vmm_aux1);

// x = x - fx * ln2
h->uni_vfnmadd231ps(vmm_aux0, vmm_aux1, table_val("ln2f"));

// compute 2^n
h->uni_vcvtps2dq(vmm_aux1, vmm_dst);
h->uni_vpaddd(vmm_aux1, vmm_aux1, table_val("exponent_bias"));
const int n_mantissa_bits = 23;
h->uni_vpslld(vmm_aux1, vmm_aux1, n_mantissa_bits);

// use vmm_dst as tmp vmm_zero when applying mask
h->uni_vpxor(vmm_dst, vmm_dst, vmm_dst);
// set zeroes at those points which were < log(FLT_MIN)
blend_with_mask(vmm_aux1, vmm_dst);

// compute polynomial
h->uni_vmovups(vmm_dst, table_val("pol5"));
h->uni_vfmadd213ps(vmm_dst, vmm_aux0, table_val("pol4"));
h->uni_vfmadd213ps(vmm_dst, vmm_aux0, table_val("pol3"));
h->uni_vfmadd213ps(vmm_dst, vmm_aux0, table_val("pol2"));
h->uni_vfmadd213ps(vmm_dst, vmm_aux0, table_val("pol1"));
h->uni_vfmadd213ps(vmm_dst, vmm_aux0, table_val("one"));
// y = y * 2^n
h->uni_vmulps(vmm_dst, vmm_dst, vmm_aux1);
}

void jit_exp_emitter::register_table_entries() {
push_arg_entry_of("pol1", 0x3f7ffffb, true); // p1 = 0.999999701f
push_arg_entry_of("pol2", 0x3efffee3, true); // p2 = 0.499991506f
push_arg_entry_of("pol3", 0x3e2aad40, true); // p3 = 0.166676521f
push_arg_entry_of("pol4", 0x3d2b9d0d, true); // p4 = 0.0418978221f
push_arg_entry_of("pol5", 0x3c07cfce, true); // p5 = 0.00828929059f

push_arg_entry_of("one", CONST_1_F, true);
push_arg_entry_of("half", 0x3f000000, true);
push_arg_entry_of("ln2f", 0x3f317218, true);
push_arg_entry_of("log2ef", 0x3fb8aa3b, true);
push_arg_entry_of("ln_flt_max_f", 0x42b17218, true);
push_arg_entry_of("ln_flt_min_f", 0xc2aeac50, true);
push_arg_entry_of("exponent_bias", 0x0000007f, true);
}

// Number of auxiliary vector registers requested: two working registers,
// plus one more when need_vmm_mask() reports that a vector mask is required.
size_t jit_exp_emitter::aux_vecs_count() const {
    if (need_vmm_mask()) {
        return 3;
    }
    return 2;
}

/// ERF ///
// Node-less constructor: creates the nested Exp emitter with the same host,
// ISA and execution precision, then prepares this emitter's own table.
jit_erf_emitter::jit_erf_emitter(x64::jit_generator* host,
                                 x64::cpu_isa_t host_isa,
                                 ov::element::Type exec_prc)
    : jit_emitter(host, host_isa, exec_prc) {
    m_exp_emitter = std::unique_ptr<jit_exp_emitter>(new jit_exp_emitter(host, host_isa, exec_prc));
    prepare_table();
}

// Node-based overload: `node` is not consulted; all work is delegated to the
// node-less constructor.
jit_erf_emitter::jit_erf_emitter(x64::jit_generator* host,
                                 x64::cpu_isa_t host_isa,
                                 const std::shared_ptr<ov::Node>& node,
                                 ov::element::Type exec_prc)
    : jit_erf_emitter(host, host_isa, exec_prc) {
}

// The emitter consumes exactly one input.
size_t jit_erf_emitter::get_inputs_num() const {
    return 1;
}

// Reports the supported input precision combinations: a single combination
// holding only f32. The `node` argument is not used.
std::set<std::vector<element::Type>> jit_erf_emitter::get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
    const std::vector<element::Type> f32_only{element::f32};
    return {f32_only};
}

// Dispatches code generation to the emit_isa<> instantiation that matches
// host_isa_ (sse41, avx2 or avx512_core); any other ISA is reported through
// OV_CPU_JIT_EMITTER_THROW.
void jit_erf_emitter::emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
    if (host_isa_ == x64::sse41) {
        emit_isa<x64::sse41>(in_vec_idxs, out_vec_idxs);
        return;
    }
    if (host_isa_ == x64::avx2) {
        emit_isa<x64::avx2>(in_vec_idxs, out_vec_idxs);
        return;
    }
    if (host_isa_ == x64::avx512_core) {
        emit_isa<x64::avx512_core>(in_vec_idxs, out_vec_idxs);
        return;
    }
    OV_CPU_JIT_EMITTER_THROW("Unsupported ISA ", host_isa_);
}

template <x64::cpu_isa_t isa>
void jit_erf_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
using Vmm = typename conditional3<isa == x64::sse41, Xmm, isa == x64::avx2, Ymm, Zmm>::type;
Vmm vmm_src = Vmm(in_vec_idxs[0]);
Vmm vmm_dst = Vmm(out_vec_idxs[0]);

Vmm vmm_aux0 = Vmm(aux_vec_idxs[0]);
Vmm vmm_aux1 = Vmm(aux_vec_idxs[1]);
Vmm vmm_aux2 = Vmm(aux_vec_idxs[2]);
Vmm vmm_aux3 = Vmm(aux_vec_idxs[3]);

// IMPORTANT: we use vmm_aux3 to save `x` as exp_compute does not use it.
h->uni_vmovups(vmm_aux3, vmm_src);

// -exp(-x*x)
h->uni_vmulps(vmm_src, vmm_src, vmm_src);
h->uni_vxorps(vmm_src, vmm_src, table_val("sign_mask"));
h->uni_vmulps(vmm_dst, vmm_src, vmm_src);
h->uni_vxorps(vmm_dst, vmm_dst, table_val("sign_mask"));

exp_compute_vector_fwd(vmm_src);
// pass the current `aux_vec_idxs` to `exp_emitter` excepting `vmm_aux3`
auto exp_aux_vec_idxs = aux_vec_idxs;
exp_aux_vec_idxs.erase(std::find(exp_aux_vec_idxs.begin(), exp_aux_vec_idxs.end(), static_cast<size_t>(vmm_aux3.getIdx())));
m_exp_emitter->emit_code({static_cast<size_t>(vmm_dst.getIdx())}, {static_cast<size_t>(vmm_dst.getIdx())}, exp_aux_vec_idxs);

h->uni_vxorps(vmm_src, vmm_src, table_val("sign_mask"));
h->uni_vxorps(vmm_dst, vmm_dst, table_val("sign_mask"));

// get sign
h->uni_vmovups(vmm_aux0, vmm_aux3);
Expand All @@ -1954,60 +2003,49 @@ void jit_erf_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std
// abs(x)
h->uni_vmovups(vmm_aux1, vmm_aux3);
// compute abs(x) = _mm_and_ps(x, 01111..111));
abs_compute_vector_fwd(vmm_aux1);
h->uni_vandps(vmm_aux1, vmm_aux1, table_val("positive_mask"));

// t = 1 / (p*x + 1)
h->uni_vmovups(vmm_aux2, table_val("approx_const"));
h->uni_vfmadd213ps(vmm_aux2, vmm_aux1, table_val("one"));
h->uni_vmovups(vmm_aux4, table_val("one"));
h->uni_vdivps(vmm_aux4, vmm_aux4, vmm_aux2);
h->uni_vmovups(vmm_aux3, table_val("one"));
h->uni_vdivps(vmm_aux3, vmm_aux3, vmm_aux2);

// -exp(-x*x)*t
h->uni_vmulps(vmm_src, vmm_src, vmm_aux4);
h->uni_vmulps(vmm_dst, vmm_dst, vmm_aux3);

// compute polynomial r
h->uni_vmovups(vmm_aux1, table_val("erf_pol5"));
h->uni_vfmadd213ps(vmm_aux1, vmm_aux4, table_val("erf_pol4"));
h->uni_vfmadd213ps(vmm_aux1, vmm_aux4, table_val("erf_pol3"));
h->uni_vfmadd213ps(vmm_aux1, vmm_aux4, table_val("erf_pol2"));
h->uni_vfmadd213ps(vmm_aux1, vmm_aux4, table_val("erf_pol1"));
h->uni_vmovups(vmm_aux1, table_val("pol5"));
h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val("pol4"));
h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val("pol3"));
h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val("pol2"));
h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val("pol1"));

// erf = sign * (1 - r * t * exp(-x*x))
h->uni_vfmadd213ps(vmm_src, vmm_aux1, table_val("one"));
h->uni_vxorps(vmm_dst, vmm_src, vmm_aux0);
h->uni_vfmadd213ps(vmm_dst, vmm_aux1, table_val("one"));
h->uni_vxorps(vmm_dst, vmm_dst, vmm_aux0);
}

// Registers the broadcast constants that the Erf code generator looks up by
// name through table_val().
//
// NOTE(review): this span looks like a rendered diff in which removed and added
// lines are interleaved. As written, "sign_mask" is registered twice, and the
// "erf_pol1".."erf_pol5" entries carry the same bit patterns as "pol1".."pol5".
// Confirm against the actual file which entries are still required before
// changing anything here.
void jit_erf_emitter::register_table_entries() {
push_arg_entry_of("approx_const", 0x3ea7ba05, true); // 0.3275911
push_arg_entry_of("one_over_sqrt_two", 0x3f3504f3, true);
push_arg_entry_of("sign_mask", 0x80000000, true);

// Coefficients under the "ex_pol*" names.
push_arg_entry_of("ex_pol1", 0x3f7ffffb, true); // p1 = 0.999999701f
push_arg_entry_of("ex_pol2", 0x3efffee3, true); // p2 = 0.499991506f
push_arg_entry_of("ex_pol3", 0x3e2aad40, true); // p3 = 0.166676521f
push_arg_entry_of("ex_pol4", 0x3d2b9d0d, true); // p4 = 0.0418978221f
push_arg_entry_of("ex_pol5", 0x3c07cfce, true); // p5 = 0.00828929059f

// Coefficients under the "erf_pol*" names.
push_arg_entry_of("erf_pol1", 0x3e827906, true); // p1 = 0.254829592f
push_arg_entry_of("erf_pol2", 0xbe91a98e, true); // p2 = -0.284496736f
push_arg_entry_of("erf_pol3", 0x3fb5f0e3, true); // p3 = 1.421413741f
push_arg_entry_of("erf_pol4", 0xbfba00e3, true); // p4 = -1.453152027f
push_arg_entry_of("erf_pol5", 0x3f87dc22, true); // p5 = 1.061405429f

push_arg_entry_of("one", CONST_1_F, true);
push_arg_entry_of("half", 0x3f000000, true);

push_arg_entry_of("exp_log2ef", 0x3fb8aa3b, true);
push_arg_entry_of("exp_ln_flt_max_f", 0x42b17218, true);
push_arg_entry_of("exp_ln_flt_min_f", 0xc2aeac50, true);

push_arg_entry_of("ln2f", 0x3f317218, true);
push_arg_entry_of("exponent_bias", 0x0000007f, true);
// Second registration of "sign_mask" (see the review note at the top).
push_arg_entry_of("sign_mask", 0x80000000, true);
push_arg_entry_of("positive_mask", 0x7fffffff, true);

// Same bit patterns as the "erf_pol*" entries, under the shorter "pol*" names.
push_arg_entry_of("pol1", 0x3e827906, true); // p1 = 0.254829592f
push_arg_entry_of("pol2", 0xbe91a98e, true); // p2 = -0.284496736f
push_arg_entry_of("pol3", 0x3fb5f0e3, true); // p3 = 1.421413741f
push_arg_entry_of("pol4", 0xbfba00e3, true); // p4 = -1.453152027f
push_arg_entry_of("pol5", 0x3f87dc22, true); // p5 = 1.061405429f
}

// Number of auxiliary vector registers requested by the Erf emitter.
// emit_isa binds aux_vec_idxs[0]..aux_vec_idxs[3]: one of them keeps a copy of
// the input, and the remaining three are handed to the nested Exp emitter.
size_t jit_erf_emitter::aux_vecs_count() const {
    return 4ul;
}

void jit_erf_emitter::emit_data() const {
jit_emitter::emit_data();
m_exp_emitter->emit_data();
}

/// SOFT SIGN ///
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,29 @@ class jit_negative_emitter : public jit_emitter {
void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
};

/// Exp emitter derived directly from jit_emitter. It overrides
/// register_table_entries() and aux_vecs_count(), and supports f32 only
/// by default execution precision.
class jit_exp_emitter : public jit_emitter {
public:
    /// Node-less constructor.
    jit_exp_emitter(dnnl::impl::cpu::x64::jit_generator *host,
                    dnnl::impl::cpu::x64::cpu_isa_t host_isa,
                    ov::element::Type exec_prc = ov::element::f32);

    /// Node-based constructor.
    jit_exp_emitter(dnnl::impl::cpu::x64::jit_generator *host,
                    dnnl::impl::cpu::x64::cpu_isa_t host_isa,
                    const std::shared_ptr<ov::Node>& n,
                    ov::element::Type exec_prc = ov::element::f32);

    /// Supported input precision combinations; `node` defaults to nullptr.
    static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ov::Node>& node = nullptr);

    /// Number of inputs consumed by the emitter.
    size_t get_inputs_num() const override;

private:
    /// Number of auxiliary vector registers the emitter asks for.
    size_t aux_vecs_count() const override;

    /// Registers the constants used by the generated code.
    void register_table_entries() override;

    /// True for every ISA except avx512_core.
    bool need_vmm_mask() const {
        return host_isa_ != dnnl::impl::cpu::x64::avx512_core;
    }

    /// Entry point for code generation; forwards to emit_isa<isa>.
    void emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const override;

    template <dnnl::impl::cpu::x64::cpu_isa_t isa>
    void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
};

class jit_erf_emitter : public jit_emitter {
public:
jit_erf_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa,
Expand All @@ -533,6 +556,8 @@ class jit_erf_emitter : public jit_emitter {
jit_erf_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr<ov::Node>& n,
ov::element::Type exec_prc = ov::element::f32);

void emit_data() const override;

size_t get_inputs_num() const override;
static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ov::Node>& node = nullptr);

Expand All @@ -546,6 +571,8 @@ class jit_erf_emitter : public jit_emitter {

void register_table_entries() override;
size_t aux_vecs_count() const override;

std::unique_ptr<jit_exp_emitter> m_exp_emitter {nullptr};
};

class jit_soft_sign_emitter : public jit_emitter {
Expand Down
Loading

0 comments on commit e3ad821

Please sign in to comment.