Skip to content

Commit

Permalink
move __tilecfg to tests mod & change void * to *mut u8
Browse files Browse the repository at this point in the history
  • Loading branch information
ziyizhang-1 committed Jul 27, 2024
1 parent fd48ea1 commit d3d49fb
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 85 deletions.
125 changes: 83 additions & 42 deletions crates/core_arch/src/x86_64/amx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#[inline]
#[target_feature(enable = "amx-tile")]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_loadconfig(mem_addr: *const i8) {
pub unsafe fn _tile_loadconfig(mem_addr: *const u8) {
ldtilecfg(mem_addr);
}

Expand All @@ -20,7 +20,7 @@ pub unsafe fn _tile_loadconfig(mem_addr: *const i8) {
#[inline]
#[target_feature(enable = "amx-tile")]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_storeconfig(mem_addr: *mut i8) {
pub unsafe fn _tile_storeconfig(mem_addr: *mut u8) {
sttilecfg(mem_addr);
}

Expand All @@ -31,7 +31,7 @@ pub unsafe fn _tile_storeconfig(mem_addr: *mut i8) {
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-tile")]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_loadd<const DST: i8>(base: *const i8, stride: usize) {
pub unsafe fn _tile_loadd<const DST: i8>(base: *const u8, stride: usize) {
static_assert_uimm_bits!(DST, 3);
tileloadd64(DST, base, stride);
}
Expand All @@ -53,7 +53,7 @@ pub unsafe fn _tile_release() {
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-tile")]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_stored<const DST: i8>(base: *mut i8, stride: usize) {
pub unsafe fn _tile_stored<const DST: i8>(base: *mut u8, stride: usize) {
static_assert_uimm_bits!(DST, 3);
tilestored64(DST, base, stride);
}
Expand All @@ -67,7 +67,7 @@ pub unsafe fn _tile_stored<const DST: i8>(base: *mut i8, stride: usize) {
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-tile")]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_stream_loadd<const DST: i8>(base: *const i8, stride: usize) {
pub unsafe fn _tile_stream_loadd<const DST: i8>(base: *const u8, stride: usize) {
static_assert_uimm_bits!(DST, 3);
tileloaddt164(DST, base, stride);
}
Expand Down Expand Up @@ -227,17 +227,17 @@ pub unsafe fn _tile_cmmrlfp16ps<const DST: i8, const A: i8, const B: i8>() {
#[allow(improper_ctypes)]
extern "C" {
#[link_name = "llvm.x86.ldtilecfg"]
fn ldtilecfg(mem_addr: *const i8);
fn ldtilecfg(mem_addr: *const u8);
#[link_name = "llvm.x86.sttilecfg"]
fn sttilecfg(mem_addr: *mut i8);
fn sttilecfg(mem_addr: *mut u8);
#[link_name = "llvm.x86.tileloadd64"]
fn tileloadd64(dst: i8, base: *const i8, stride: usize);
fn tileloadd64(dst: i8, base: *const u8, stride: usize);
#[link_name = "llvm.x86.tileloaddt164"]
fn tileloaddt164(dst: i8, base: *const i8, stride: usize);
fn tileloaddt164(dst: i8, base: *const u8, stride: usize);
#[link_name = "llvm.x86.tilerelease"]
fn tilerelease();
#[link_name = "llvm.x86.tilestored64"]
fn tilestored64(dst: i8, base: *mut i8, stride: usize);
fn tilestored64(dst: i8, base: *mut u8, stride: usize);
#[link_name = "llvm.x86.tilezero"]
fn tilezero(dst: i8);
#[link_name = "llvm.x86.tdpbf16ps"]
Expand Down Expand Up @@ -267,6 +267,47 @@ mod tests {
#[cfg(target_os = "linux")]
use syscalls::{syscall, Sysno};

#[allow(non_camel_case_types)]
#[repr(packed)]
#[derive(Copy, Clone, Default, Debug, PartialEq)]
struct __tilecfg {
/// 0 `or` 1
palette: u8,
start_row: u8,
/// reserved, must be zero
reserved_a0: [u8; 14],
/// number of bytes of one row in each tile
colsb: [u16; 8],
/// reserved, must be zero
reserved_b0: [u16; 8],
/// number of rows in each tile
rows: [u8; 8],
/// reserved, must be zero
reserved_c0: [u8; 8],
}

impl __tilecfg {
fn new(palette: u8, start_row: u8, colsb: [u16; 8], rows: [u8; 8]) -> Self {
Self {
palette,
start_row,
reserved_a0: [0u8; 14],
colsb,
reserved_b0: [0u16; 8],
rows,
reserved_c0: [0u8; 8],
}
}

const fn as_ptr(&self) -> *const u8 {
self as *const Self as *const u8
}

fn as_mut_ptr(&mut self) -> *mut u8 {
self as *mut Self as *mut u8
}
}

#[cfg(not(target_os = "linux"))]
#[target_feature(enable = "amx-tile")]
fn _init_amx() {}
Expand Down Expand Up @@ -324,7 +365,7 @@ mod tests {
_tile_loadconfig(config.as_ptr());
_tile_zero::<0>();
let mut out = [[1_i8; 64]; 16];
_tile_stored::<0>(&mut out as *mut [i8; 64] as *mut i8, 64);
_tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
_tile_release();
assert_eq!(out, [[0; 64]; 16]);
}
Expand All @@ -339,7 +380,7 @@ mod tests {
_tile_loadconfig(config.as_ptr());
_tile_zero::<0>();
let mut out = [[1_i8; 64]; 16];
_tile_stored::<0>(&mut out as *mut [i8; 64] as *mut i8, 64);
_tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
_tile_release();
assert_eq!(out, [[0; 64]; 16]);
}
Expand All @@ -354,9 +395,9 @@ mod tests {
_tile_loadconfig(config.as_ptr());
_tile_zero::<0>();
let mat = [1_i8; 1024];
_tile_loadd::<0>(&mat as *const i8, 64);
_tile_loadd::<0>(&mat as *const i8 as *const u8, 64);
let mut out = [[0_i8; 64]; 16];
_tile_stored::<0>(&mut out as *mut [i8; 64] as *mut i8, 64);
_tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
_tile_release();
assert_eq!(out, [[1; 64]; 16]);
}
Expand All @@ -371,9 +412,9 @@ mod tests {
_tile_loadconfig(config.as_ptr());
_tile_zero::<0>();
let mat = [1_i8; 1024];
_tile_stream_loadd::<0>(&mat as *const i8, 64);
_tile_stream_loadd::<0>(&mat as *const i8 as *const u8, 64);
let mut out = [[0_i8; 64]; 16];
_tile_stored::<0>(&mut out as *mut [i8; 64] as *mut i8, 64);
_tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
_tile_release();
assert_eq!(out, [[1; 64]; 16]);
}
Expand All @@ -388,8 +429,8 @@ mod tests {
_init_amx();
let bf16_1: u16 = _mm_cvtness_sbh(1.0).to_bits();
let bf16_2: u16 = _mm_cvtness_sbh(2.0).to_bits();
let ones: [i8; 1024] = transmute([bf16_1; 512]);
let twos: [i8; 1024] = transmute([bf16_2; 512]);
let ones: [u8; 1024] = transmute([bf16_1; 512]);
let twos: [u8; 1024] = transmute([bf16_2; 512]);
let mut res = [[0f32; 16]; 16];
let mut config = __tilecfg::default();
config.palette = 1;
Expand All @@ -399,10 +440,10 @@ mod tests {
});
_tile_loadconfig(config.as_ptr());
_tile_zero::<0>();
_tile_loadd::<1>(&ones as *const i8, 64);
_tile_loadd::<2>(&twos as *const i8, 64);
_tile_loadd::<1>(&ones as *const u8, 64);
_tile_loadd::<2>(&twos as *const u8, 64);
_tile_dpbf16ps::<0, 1, 2>();
_tile_stored::<0>(&mut res as *mut [f32; 16] as *mut i8, 64);
_tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64);
_tile_release();
assert_eq!(res, [[64f32; 16]; 16]);
}
Expand All @@ -421,10 +462,10 @@ mod tests {
});
_tile_loadconfig(config.as_ptr());
_tile_zero::<0>();
_tile_loadd::<1>(&ones as *const i8, 64);
_tile_loadd::<2>(&twos as *const i8, 64);
_tile_loadd::<1>(&ones as *const i8 as *const u8, 64);
_tile_loadd::<2>(&twos as *const i8 as *const u8, 64);
_tile_dpbssd::<0, 1, 2>();
_tile_stored::<0>(&mut res as *mut [i32; 16] as *mut i8, 64);
_tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64);
_tile_release();
assert_eq!(res, [[128_i32; 16]; 16]);
}
Expand All @@ -443,10 +484,10 @@ mod tests {
});
_tile_loadconfig(config.as_ptr());
_tile_zero::<0>();
_tile_loadd::<1>(&ones as *const i8, 64);
_tile_loadd::<2>(&twos as *const u8 as *const i8, 64);
_tile_loadd::<1>(&ones as *const i8 as *const u8, 64);
_tile_loadd::<2>(&twos as *const u8, 64);
_tile_dpbsud::<0, 1, 2>();
_tile_stored::<0>(&mut res as *mut [i32; 16] as *mut i8, 64);
_tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64);
_tile_release();
assert_eq!(res, [[-128_i32; 16]; 16]);
}
Expand All @@ -465,10 +506,10 @@ mod tests {
});
_tile_loadconfig(config.as_ptr());
_tile_zero::<0>();
_tile_loadd::<1>(&ones as *const u8 as *const i8, 64);
_tile_loadd::<2>(&twos as *const i8, 64);
_tile_loadd::<1>(&ones as *const u8, 64);
_tile_loadd::<2>(&twos as *const i8 as *const u8, 64);
_tile_dpbusd::<0, 1, 2>();
_tile_stored::<0>(&mut res as *mut [i32; 16] as *mut i8, 64);
_tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64);
_tile_release();
assert_eq!(res, [[-128_i32; 16]; 16]);
}
Expand All @@ -487,10 +528,10 @@ mod tests {
});
_tile_loadconfig(config.as_ptr());
_tile_zero::<0>();
_tile_loadd::<1>(&ones as *const u8 as *const i8, 64);
_tile_loadd::<2>(&twos as *const u8 as *const i8, 64);
_tile_loadd::<1>(&ones as *const u8, 64);
_tile_loadd::<2>(&twos as *const u8, 64);
_tile_dpbuud::<0, 1, 2>();
_tile_stored::<0>(&mut res as *mut [i32; 16] as *mut i8, 64);
_tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64);
_tile_release();
assert_eq!(res, [[128_i32; 16]; 16]);
}
Expand All @@ -509,10 +550,10 @@ mod tests {
});
_tile_loadconfig(config.as_ptr());
_tile_zero::<0>();
_tile_loadd::<1>(&ones as *const f16 as *const i8, 64);
_tile_loadd::<2>(&twos as *const f16 as *const i8, 64);
_tile_loadd::<1>(&ones as *const f16 as *const u8, 64);
_tile_loadd::<2>(&twos as *const f16 as *const u8, 64);
_tile_dpfp16ps::<0, 1, 2>();
_tile_stored::<0>(&mut res as *mut [f32; 16] as *mut i8, 64);
_tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64);
_tile_release();
assert_eq!(res, [[64f32; 16]; 16]);
}
Expand All @@ -531,10 +572,10 @@ mod tests {
});
_tile_loadconfig(config.as_ptr());
_tile_zero::<0>();
_tile_loadd::<1>(&ones as *const f16 as *const i8, 64);
_tile_loadd::<2>(&twos as *const f16 as *const i8, 64);
_tile_loadd::<1>(&ones as *const f16 as *const u8, 64);
_tile_loadd::<2>(&twos as *const f16 as *const u8, 64);
_tile_cmmimfp16ps::<0, 1, 2>();
_tile_stored::<0>(&mut res as *mut [f32; 16] as *mut i8, 64);
_tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64);
_tile_release();
assert_eq!(res, [[64f32; 16]; 16]);
}
Expand All @@ -553,10 +594,10 @@ mod tests {
});
_tile_loadconfig(config.as_ptr());
_tile_zero::<0>();
_tile_loadd::<1>(&ones as *const f16 as *const i8, 64);
_tile_loadd::<2>(&twos as *const f16 as *const i8, 64);
_tile_loadd::<1>(&ones as *const f16 as *const u8, 64);
_tile_loadd::<2>(&twos as *const f16 as *const u8, 64);
_tile_cmmrlfp16ps::<0, 1, 2>();
_tile_stored::<0>(&mut res as *mut [f32; 16] as *mut i8, 64);
_tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64);
_tile_release();
assert_eq!(res, [[0f32; 16]; 16]);
}
Expand Down
43 changes: 0 additions & 43 deletions crates/core_arch/src/x86_64/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,49 +3,6 @@
#[macro_use]
mod macros;

#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
#[allow(non_camel_case_types)]
#[repr(packed)]
#[derive(Copy, Clone, Default, Debug, PartialEq)]
pub struct __tilecfg {
/// 0 `or` 1
pub palette: u8,
pub start_row: u8,
/// reserved, must be zero
reserved_a0: [u8; 14],
/// number of bytes of one row in each tile
pub colsb: [u16; 8],
/// reserved, must be zero
reserved_b0: [u16; 8],
/// number of rows in each tile
pub rows: [u8; 8],
/// reserved, must be zero
reserved_c0: [u8; 8],
}

#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
impl __tilecfg {
pub fn new(palette: u8, start_row: u8, colsb: [u16; 8], rows: [u8; 8]) -> Self {
Self {
palette,
start_row,
reserved_a0: [0u8; 14],
colsb,
reserved_b0: [0u16; 8],
rows,
reserved_c0: [0u8; 8],
}
}

pub const fn as_ptr(&self) -> *const i8 {
self as *const Self as *const i8
}

pub fn as_mut_ptr(&mut self) -> *mut i8 {
self as *mut Self as *mut i8
}
}

mod fxsr;
#[stable(feature = "simd_x86", since = "1.27.0")]
pub use self::fxsr::*;
Expand Down

0 comments on commit d3d49fb

Please sign in to comment.