diff --git a/crates/core_arch/src/x86_64/amx.rs b/crates/core_arch/src/x86_64/amx.rs index 2eea9e0e8a..0d161f9347 100644 --- a/crates/core_arch/src/x86_64/amx.rs +++ b/crates/core_arch/src/x86_64/amx.rs @@ -8,7 +8,7 @@ #[inline] #[target_feature(enable = "amx-tile")] #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] -pub unsafe fn _tile_loadconfig(mem_addr: *const i8) { +pub unsafe fn _tile_loadconfig(mem_addr: *const u8) { ldtilecfg(mem_addr); } @@ -20,7 +20,7 @@ pub unsafe fn _tile_loadconfig(mem_addr: *const i8) { #[inline] #[target_feature(enable = "amx-tile")] #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] -pub unsafe fn _tile_storeconfig(mem_addr: *mut i8) { +pub unsafe fn _tile_storeconfig(mem_addr: *mut u8) { sttilecfg(mem_addr); } @@ -31,7 +31,7 @@ pub unsafe fn _tile_storeconfig(mem_addr: *mut i8) { #[rustc_legacy_const_generics(0)] #[target_feature(enable = "amx-tile")] #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] -pub unsafe fn _tile_loadd(base: *const i8, stride: usize) { +pub unsafe fn _tile_loadd(base: *const u8, stride: usize) { static_assert_uimm_bits!(DST, 3); tileloadd64(DST, base, stride); } @@ -53,7 +53,7 @@ pub unsafe fn _tile_release() { #[rustc_legacy_const_generics(0)] #[target_feature(enable = "amx-tile")] #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] -pub unsafe fn _tile_stored(base: *mut i8, stride: usize) { +pub unsafe fn _tile_stored(base: *mut u8, stride: usize) { static_assert_uimm_bits!(DST, 3); tilestored64(DST, base, stride); } @@ -67,7 +67,7 @@ pub unsafe fn _tile_stored(base: *mut i8, stride: usize) { #[rustc_legacy_const_generics(0)] #[target_feature(enable = "amx-tile")] #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] -pub unsafe fn _tile_stream_loadd(base: *const i8, stride: usize) { +pub unsafe fn _tile_stream_loadd(base: *const u8, stride: usize) { static_assert_uimm_bits!(DST, 3); tileloaddt164(DST, base, stride); } @@ -227,17 +227,17 @@ pub unsafe fn _tile_cmmrlfp16ps() { #[allow(improper_ctypes)] extern "C" { #[link_name = "llvm.x86.ldtilecfg"] - fn ldtilecfg(mem_addr: *const i8); + fn ldtilecfg(mem_addr: *const u8); #[link_name = "llvm.x86.sttilecfg"] - fn sttilecfg(mem_addr: *mut i8); + fn sttilecfg(mem_addr: *mut u8); #[link_name = "llvm.x86.tileloadd64"] - fn tileloadd64(dst: i8, base: *const i8, stride: usize); + fn tileloadd64(dst: i8, base: *const u8, stride: usize); #[link_name = "llvm.x86.tileloaddt164"] - fn tileloaddt164(dst: i8, base: *const i8, stride: usize); + fn tileloaddt164(dst: i8, base: *const u8, stride: usize); #[link_name = "llvm.x86.tilerelease"] fn tilerelease(); #[link_name = "llvm.x86.tilestored64"] - fn tilestored64(dst: i8, base: *mut i8, stride: usize); + fn tilestored64(dst: i8, base: *mut u8, stride: usize); #[link_name = "llvm.x86.tilezero"] fn tilezero(dst: i8); #[link_name = "llvm.x86.tdpbf16ps"] @@ -267,6 +267,47 @@ mod tests { #[cfg(target_os = "linux")] use syscalls::{syscall, Sysno}; + #[allow(non_camel_case_types)] + #[repr(packed)] + #[derive(Copy, Clone, Default, Debug, PartialEq)] + struct __tilecfg { + /// 0 `or` 1 + palette: u8, + start_row: u8, + /// reserved, must be zero + reserved_a0: [u8; 14], + /// number of bytes of one row in each tile + colsb: [u16; 8], + /// reserved, must be zero + reserved_b0: [u16; 8], + /// number of rows in each tile + rows: [u8; 8], + /// reserved, must be zero + reserved_c0: [u8; 8], + } + + impl __tilecfg { + fn new(palette: u8, start_row: u8, colsb: [u16; 8], rows: [u8; 8]) -> Self { + Self { + palette, + start_row, + reserved_a0: [0u8; 14], + colsb, + reserved_b0: [0u16; 8], + rows, + reserved_c0: [0u8; 8], + } + } + + const fn as_ptr(&self) -> *const u8 { + self as *const Self as *const u8 + } + + fn as_mut_ptr(&mut self) -> *mut u8 { + self as *mut Self as *mut u8 + } + } + #[cfg(not(target_os = "linux"))] #[target_feature(enable = "amx-tile")] fn _init_amx() {} @@ -324,7 +365,7 @@ mod tests { _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); let mut out = [[1_i8; 64]; 16]; - _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut i8, 64); + _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64); _tile_release(); assert_eq!(out, [[0; 64]; 16]); } @@ -339,7 +380,7 @@ mod tests { _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); let mut out = [[1_i8; 64]; 16]; - _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut i8, 64); + _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64); _tile_release(); assert_eq!(out, [[0; 64]; 16]); } @@ -354,9 +395,9 @@ mod tests { _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); let mat = [1_i8; 1024]; - _tile_loadd::<0>(&mat as *const i8, 64); + _tile_loadd::<0>(&mat as *const i8 as *const u8, 64); let mut out = [[0_i8; 64]; 16]; - _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut i8, 64); + _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64); _tile_release(); assert_eq!(out, [[1; 64]; 16]); } @@ -371,9 +412,9 @@ mod tests { _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); let mat = [1_i8; 1024]; - _tile_stream_loadd::<0>(&mat as *const i8, 64); + _tile_stream_loadd::<0>(&mat as *const i8 as *const u8, 64); let mut out = [[0_i8; 64]; 16]; - _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut i8, 64); + _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64); _tile_release(); assert_eq!(out, [[1; 64]; 16]); } @@ -388,8 +429,8 @@ mod tests { _init_amx(); let bf16_1: u16 = _mm_cvtness_sbh(1.0).to_bits(); let bf16_2: u16 = _mm_cvtness_sbh(2.0).to_bits(); - let ones: [i8; 1024] = transmute([bf16_1; 512]); - let twos: [i8; 1024] = transmute([bf16_2; 512]); + let ones: [u8; 1024] = transmute([bf16_1; 512]); + let twos: [u8; 1024] = transmute([bf16_2; 512]); let mut res = [[0f32; 16]; 16]; let mut config = __tilecfg::default(); config.palette = 1; @@ -399,10 +440,10 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const i8, 64); - _tile_loadd::<2>(&twos as *const i8, 64); + _tile_loadd::<1>(&ones as *const u8, 64); + _tile_loadd::<2>(&twos as *const u8, 64); _tile_dpbf16ps::<0, 1, 2>(); - _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut i8, 64); + _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64); _tile_release(); assert_eq!(res, [[64f32; 16]; 16]); } @@ -421,10 +462,10 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const i8, 64); - _tile_loadd::<2>(&twos as *const i8, 64); + _tile_loadd::<1>(&ones as *const i8 as *const u8, 64); + _tile_loadd::<2>(&twos as *const i8 as *const u8, 64); _tile_dpbssd::<0, 1, 2>(); - _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut i8, 64); + _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64); _tile_release(); assert_eq!(res, [[128_i32; 16]; 16]); } @@ -443,10 +484,10 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const i8, 64); - _tile_loadd::<2>(&twos as *const u8 as *const i8, 64); + _tile_loadd::<1>(&ones as *const i8 as *const u8, 64); + _tile_loadd::<2>(&twos as *const u8, 64); _tile_dpbsud::<0, 1, 2>(); - _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut i8, 64); + _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64); _tile_release(); assert_eq!(res, [[-128_i32; 16]; 16]); } @@ -465,10 +506,10 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const u8 as *const i8, 64); - _tile_loadd::<2>(&twos as *const i8, 64); + _tile_loadd::<1>(&ones as *const u8, 64); + _tile_loadd::<2>(&twos as *const i8 as *const u8, 64); _tile_dpbusd::<0, 1, 2>(); - _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut i8, 64); + _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64); _tile_release(); assert_eq!(res, [[-128_i32; 16]; 16]); } @@ -487,10 +528,10 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const u8 as *const i8, 64); - _tile_loadd::<2>(&twos as *const u8 as *const i8, 64); + _tile_loadd::<1>(&ones as *const u8, 64); + _tile_loadd::<2>(&twos as *const u8, 64); _tile_dpbuud::<0, 1, 2>(); - _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut i8, 64); + _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64); _tile_release(); assert_eq!(res, [[128_i32; 16]; 16]); } @@ -509,10 +550,10 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const f16 as *const i8, 64); - _tile_loadd::<2>(&twos as *const f16 as *const i8, 64); + _tile_loadd::<1>(&ones as *const f16 as *const u8, 64); + _tile_loadd::<2>(&twos as *const f16 as *const u8, 64); _tile_dpfp16ps::<0, 1, 2>(); - _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut i8, 64); + _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64); _tile_release(); assert_eq!(res, [[64f32; 16]; 16]); } @@ -531,10 +572,10 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const f16 as *const i8, 64); - _tile_loadd::<2>(&twos as *const f16 as *const i8, 64); + _tile_loadd::<1>(&ones as *const f16 as *const u8, 64); + _tile_loadd::<2>(&twos as *const f16 as *const u8, 64); _tile_cmmimfp16ps::<0, 1, 2>(); - _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut i8, 64); + _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64); _tile_release(); assert_eq!(res, [[64f32; 16]; 16]); } @@ -553,10 +594,10 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const f16 as *const i8, 64); - _tile_loadd::<2>(&twos as *const f16 as *const i8, 64); + _tile_loadd::<1>(&ones as *const f16 as *const u8, 64); + _tile_loadd::<2>(&twos as *const f16 as *const u8, 64); _tile_cmmrlfp16ps::<0, 1, 2>(); - _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut i8, 64); + _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64); _tile_release(); assert_eq!(res, [[0f32; 16]; 16]); } diff --git a/crates/core_arch/src/x86_64/mod.rs b/crates/core_arch/src/x86_64/mod.rs index 4bd3499f60..32ebf87d9c 100644 --- a/crates/core_arch/src/x86_64/mod.rs +++ b/crates/core_arch/src/x86_64/mod.rs @@ -3,49 +3,6 @@ #[macro_use] mod macros; -#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] -#[allow(non_camel_case_types)] -#[repr(packed)] -#[derive(Copy, Clone, Default, Debug, PartialEq)] -pub struct __tilecfg { - /// 0 `or` 1 - pub palette: u8, - pub start_row: u8, - /// reserved, must be zero - reserved_a0: [u8; 14], - /// number of bytes of one row in each tile - pub colsb: [u16; 8], - /// reserved, must be zero - reserved_b0: [u16; 8], - /// number of rows in each tile - pub rows: [u8; 8], - /// reserved, must be zero - reserved_c0: [u8; 8], -} - -#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] -impl __tilecfg { - pub fn new(palette: u8, start_row: u8, colsb: [u16; 8], rows: [u8; 8]) -> Self { - Self { - palette, - start_row, - reserved_a0: [0u8; 14], - colsb, - reserved_b0: [0u16; 8], - rows, - reserved_c0: [0u8; 8], - } - } - - pub const fn as_ptr(&self) -> *const i8 { - self as *const Self as *const i8 - } - - pub fn as_mut_ptr(&mut self) -> *mut i8 { - self as *mut Self as *mut i8 - } -} - mod fxsr; #[stable(feature = "simd_x86", since = "1.27.0")] pub use self::fxsr::*;