Skip to content

Commit

Permalink
Fix arm vfma inlining by using special _vfp4 dup fns.
Browse files Browse the repository at this point in the history
Some VFMA functions have `target_feature(enable = "vfp4")` while the called functions `vdup_n_f32` and `vdupq_n_f32` are `target_feature(enable = "v7")`. LLVM does not inline the functions due to the different feature flags. Using private _vfp4 variants of those functions allows them to be inlined.
  • Loading branch information
hkratz committed Sep 19, 2021
1 parent 3944043 commit 9e33ce7
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 7 deletions.
8 changes: 4 additions & 4 deletions crates/core_arch/src/arm_shared/neon/generated.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8757,7 +8757,7 @@ vfmaq_f32_(b, c, a)
#[cfg_attr(all(test, target_arch = "arm"), assert_instr(vfma))]
#[cfg_attr(all(test, target_arch = "aarch64"), assert_instr(fmla))]
pub unsafe fn vfma_n_f32(a: float32x2_t, b: float32x2_t, c: f32) -> float32x2_t {
vfma_f32(a, b, vdup_n_f32(c))
vfma_f32(a, b, vdup_n_f32_vfp4(c))
}

/// Floating-point fused Multiply-Add to accumulator(vector)
Expand All @@ -8767,7 +8767,7 @@ pub unsafe fn vfma_n_f32(a: float32x2_t, b: float32x2_t, c: f32) -> float32x2_t
#[cfg_attr(all(test, target_arch = "arm"), assert_instr(vfma))]
#[cfg_attr(all(test, target_arch = "aarch64"), assert_instr(fmla))]
pub unsafe fn vfmaq_n_f32(a: float32x4_t, b: float32x4_t, c: f32) -> float32x4_t {
vfmaq_f32(a, b, vdupq_n_f32(c))
vfmaq_f32(a, b, vdupq_n_f32_vfp4(c))
}

/// Floating-point fused multiply-subtract from accumulator
Expand Down Expand Up @@ -8799,7 +8799,7 @@ pub unsafe fn vfmsq_f32(a: float32x4_t, b: float32x4_t, c: float32x4_t) -> float
#[cfg_attr(all(test, target_arch = "arm"), assert_instr(vfms))]
#[cfg_attr(all(test, target_arch = "aarch64"), assert_instr(fmls))]
pub unsafe fn vfms_n_f32(a: float32x2_t, b: float32x2_t, c: f32) -> float32x2_t {
vfms_f32(a, b, vdup_n_f32(c))
vfms_f32(a, b, vdup_n_f32_vfp4(c))
}

/// Floating-point fused Multiply-subtract to accumulator(vector)
Expand All @@ -8809,7 +8809,7 @@ pub unsafe fn vfms_n_f32(a: float32x2_t, b: float32x2_t, c: f32) -> float32x2_t
#[cfg_attr(all(test, target_arch = "arm"), assert_instr(vfms))]
#[cfg_attr(all(test, target_arch = "aarch64"), assert_instr(fmls))]
pub unsafe fn vfmsq_n_f32(a: float32x4_t, b: float32x4_t, c: f32) -> float32x4_t {
vfmsq_f32(a, b, vdupq_n_f32(c))
vfmsq_f32(a, b, vdupq_n_f32_vfp4(c))
}

/// Subtract
Expand Down
26 changes: 26 additions & 0 deletions crates/core_arch/src/arm_shared/neon/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3786,6 +3786,19 @@ pub unsafe fn vdupq_n_f32(value: f32) -> float32x4_t {
float32x4_t(value, value, value, value)
}

/// Duplicate vector element to vector or scalar
///
/// Returns a `float32x4_t` whose four lanes are all set to `value`.
///
/// Private vfp4 version used by FMA intrinsics because LLVM does
/// not inline the non-vfp4 version in vfp4 functions.
///
/// The body is the same as `vdupq_n_f32`; the difference is that on `arm`
/// targets this function additionally enables the `vfp4` target feature.
#[inline]
#[target_feature(enable = "neon")]
#[cfg_attr(target_arch = "arm", target_feature(enable = "vfp4"))]
#[cfg_attr(all(test, target_arch = "arm"), assert_instr("vdup.32"))]
#[cfg_attr(all(test, target_arch = "aarch64"), assert_instr(dup))]
unsafe fn vdupq_n_f32_vfp4(value: f32) -> float32x4_t {
float32x4_t(value, value, value, value)
}

/// Duplicate vector element to vector or scalar
#[inline]
#[target_feature(enable = "neon")]
Expand Down Expand Up @@ -3896,6 +3909,19 @@ pub unsafe fn vdup_n_f32(value: f32) -> float32x2_t {
float32x2_t(value, value)
}

/// Duplicate vector element to vector or scalar
///
/// Returns a `float32x2_t` whose two lanes are both set to `value`.
///
/// Private vfp4 version used by FMA intrinsics because LLVM does
/// not inline the non-vfp4 version in vfp4 functions.
///
/// The body is the same as `vdup_n_f32`; the difference is that on `arm`
/// targets this function additionally enables the `vfp4` target feature.
#[inline]
#[target_feature(enable = "neon")]
#[cfg_attr(target_arch = "arm", target_feature(enable = "vfp4"))]
#[cfg_attr(all(test, target_arch = "arm"), assert_instr("vdup.32"))]
#[cfg_attr(all(test, target_arch = "aarch64"), assert_instr(dup))]
unsafe fn vdup_n_f32_vfp4(value: f32) -> float32x2_t {
float32x2_t(value, value)
}

/// Duplicate vector element to vector or scalar
#[inline]
#[target_feature(enable = "neon")]
Expand Down
4 changes: 2 additions & 2 deletions crates/stdarch-gen/neon.spec
Original file line number Diff line number Diff line change
Expand Up @@ -2741,7 +2741,7 @@ generate float*_t
/// Floating-point fused Multiply-Add to accumulator(vector)
name = vfma
n-suffix
multi_fn = vfma-self-noext, a, b, {vdup-nself-noext, c}
multi_fn = vfma-self-noext, a, b, {vdup-nselfvfp4-noext, c}
a = 2.0, 3.0, 4.0, 5.0
b = 6.0, 4.0, 7.0, 8.0
c = 8.0
Expand Down Expand Up @@ -2818,7 +2818,7 @@ generate float*_t
/// Floating-point fused Multiply-subtract to accumulator(vector)
name = vfms
n-suffix
multi_fn = vfms-self-noext, a, b, {vdup-nself-noext, c}
multi_fn = vfms-self-noext, a, b, {vdup-nselfvfp4-noext, c}
a = 50.0, 35.0, 60.0, 69.0
b = 6.0, 4.0, 7.0, 8.0
c = 8.0
Expand Down
12 changes: 11 additions & 1 deletion crates/stdarch-gen/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1122,6 +1122,7 @@ fn gen_aarch64(
out_t,
fixed,
None,
true,
));
}
calls
Expand Down Expand Up @@ -1920,6 +1921,7 @@ fn gen_arm(
out_t,
fixed,
None,
false,
));
}
calls
Expand Down Expand Up @@ -2287,6 +2289,7 @@ fn get_call(
out_t: &str,
fixed: &Vec<String>,
n: Option<i32>,
aarch64: bool,
) -> String {
let params: Vec<_> = in_str.split(',').map(|v| v.trim().to_string()).collect();
assert!(params.len() > 0);
Expand Down Expand Up @@ -2454,7 +2457,8 @@ fn get_call(
in_t,
out_t,
fixed,
Some(i as i32)
Some(i as i32),
aarch64
)
);
call.push_str(&sub_match);
Expand Down Expand Up @@ -2503,6 +2507,7 @@ fn get_call(
out_t,
fixed,
n.clone(),
aarch64,
);
if !param_str.is_empty() {
param_str.push_str(", ");
Expand Down Expand Up @@ -2573,6 +2578,11 @@ fn get_call(
fn_name.push_str(type_to_suffix(in_t[1]));
} else if fn_format[1] == "nself" {
fn_name.push_str(type_to_n_suffix(in_t[1]));
} else if fn_format[1] == "nselfvfp4" {
fn_name.push_str(type_to_n_suffix(in_t[1]));
if !aarch64 {
fn_name.push_str("_vfp4");
}
} else if fn_format[1] == "out" {
fn_name.push_str(type_to_suffix(out_t));
} else if fn_format[1] == "in0" {
Expand Down

0 comments on commit 9e33ce7

Please sign in to comment.