NVPTX: Add f16 SIMD intrinsics
kjetilkjeka authored and Kjetil Kjeka committed Aug 12, 2024
1 parent a3beb09 commit 21ff034
Showing 2 changed files with 98 additions and 0 deletions.
5 changes: 5 additions & 0 deletions crates/core_arch/src/nvptx/mod.rs
@@ -13,6 +13,11 @@

use crate::ffi::c_void;

mod packed;

#[unstable(feature = "stdarch_nvptx", issue = "111199")]
pub use packed::*;

#[allow(improper_ctypes)]
extern "C" {
    #[link_name = "llvm.nvvm.barrier0"]
93 changes: 93 additions & 0 deletions crates/core_arch/src/nvptx/packed.rs
@@ -0,0 +1,93 @@
//! NVPTX Packed data types (SIMD)
//!
//! "Packed data types" is the PTX term for SIMD types. See [PTX ISA (Packed Data Types)](https://docs.nvidia.com/cuda/parallel-thread-execution/#packed-data-types) for the full reference.

// Note: the `#[assert_instr]` tests are not actually run on nvptx, since it is a `no_std` target that cannot run tests. Something like FileCheck would be more appropriate for verifying that the correct instruction is emitted.

use crate::intrinsics::simd::*;

#[allow(improper_ctypes)]
extern "C" {
    #[link_name = "llvm.minnum.v2f16"]
    fn llvm_f16x2_min(a: f16x2, b: f16x2) -> f16x2;
    #[link_name = "llvm.maxnum.v2f16"]
    fn llvm_f16x2_max(a: f16x2, b: f16x2) -> f16x2;
}

types! {
    #![unstable(feature = "stdarch_nvptx", issue = "111199")]

    /// PTX-specific 32-bit wide floating-point (f16 x 2) vector type
    pub struct f16x2(2 x f16);

}

/// Add two values
///
/// <https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions-add>
#[inline]
#[cfg_attr(test, assert_instr(add.rn.f16x2))]
#[unstable(feature = "stdarch_nvptx", issue = "111199")]
pub unsafe fn f16x2_add(a: f16x2, b: f16x2) -> f16x2 {
    simd_add(a, b)
}

/// Subtract two values
///
/// <https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions-sub>
#[inline]
#[cfg_attr(test, assert_instr(sub.rn.f16x2))]
#[unstable(feature = "stdarch_nvptx", issue = "111199")]
pub unsafe fn f16x2_sub(a: f16x2, b: f16x2) -> f16x2 {
    simd_sub(a, b)
}

/// Multiply two values
///
/// <https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions-mul>
#[inline]
#[cfg_attr(test, assert_instr(mul.rn.f16x2))]
#[unstable(feature = "stdarch_nvptx", issue = "111199")]
pub unsafe fn f16x2_mul(a: f16x2, b: f16x2) -> f16x2 {
    simd_mul(a, b)
}

/// Fused multiply-add
///
/// <https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions-fma>
#[inline]
#[cfg_attr(test, assert_instr(fma.rn.f16x2))]
#[unstable(feature = "stdarch_nvptx", issue = "111199")]
pub unsafe fn f16x2_fma(a: f16x2, b: f16x2, c: f16x2) -> f16x2 {
    simd_fma(a, b, c)
}

/// Arithmetic negate
///
/// <https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions-neg>
#[inline]
#[cfg_attr(test, assert_instr(neg.f16x2))]
#[unstable(feature = "stdarch_nvptx", issue = "111199")]
pub unsafe fn f16x2_neg(a: f16x2) -> f16x2 {
    simd_neg(a)
}

/// Find the minimum of two values
///
/// <https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions-min>
#[inline]
#[cfg_attr(test, assert_instr(min.f16x2))]
#[unstable(feature = "stdarch_nvptx", issue = "111199")]
pub unsafe fn f16x2_min(a: f16x2, b: f16x2) -> f16x2 {
    llvm_f16x2_min(a, b)
}

/// Find the maximum of two values
///
/// <https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions-max>
#[inline]
#[cfg_attr(test, assert_instr(max.f16x2))]
#[unstable(feature = "stdarch_nvptx", issue = "111199")]
pub unsafe fn f16x2_max(a: f16x2, b: f16x2) -> f16x2 {
    llvm_f16x2_max(a, b)
}
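
As an illustration of how these intrinsics compose, here is a minimal device-side sketch (the `f16x2_lerp` helper is hypothetical, the feature gates are indicative, and the crate must be built for an nvptx64 target):

#![feature(stdarch_nvptx, f16)]
#![no_std]

use core::arch::nvptx::*;

/// Hypothetical helper: per-lane linear interpolation a + t * (b - a),
/// built from the packed sub and fma intrinsics added in this commit.
#[inline]
pub unsafe fn f16x2_lerp(a: f16x2, b: f16x2, t: f16x2) -> f16x2 {
    f16x2_fma(t, f16x2_sub(b, a), a)
}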
