From 0692435656876731cf2926a5325743bd1adcf966 Mon Sep 17 00:00:00 2001
From: taozha2
Date: Sat, 18 Jan 2025 00:51:12 +0800
Subject: [PATCH] Implement full feature of copy/gemm for PVC backend (#174)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Implement full feature of copy/gemm for PVC backend

Implemented features:

1. Implement the full copy/MMA feature set for the PVC backend.
   Full copy/GEMM functionality was not implemented before this commit
   because the CUTLASS CuTe copy/MMA API was not fully compatible with the
   PVC backend. The register layout loaded by the PVC subgroup intrinsics
   does not satisfy the cute::gemm requirements, which leads to problems
   including, but not limited to:
   (1) GEMM only supports specific combinations of tile sizes and copy
       traits, and produces wrong results if the tile-size configuration or
       copy traits are changed. For example, "examples/sycl/pvc/pvc_gemm.cpp"
       fails if sg_tile_k is changed from 32 to 64, so the register data
       layout must be retiled before cute::gemm.
   (2) The register layout had to be hardcoded to satisfy the requirements
       of the CUTLASS CuTe APIs. For example, the data returned by
       "partition_fragment_B" needed hardcoded reshaping.

2. Support different GEMM layouts and data types.
   (1) Support all combinations of RowMajor and ColumnMajor for matrices A
       and B. Refer to test/unit/cute/intel_xe/gemm_data_type.cpp.
   (2) Add GEMM test cases for int8/uint8/fp16/bf16. Refer to
       test/unit/cute/intel_xe/gemm_layout.cpp.

This PR implements the features above without regressing performance.

Code refinements:

1. Refine the layout convention for GEMM. For GEMM C = A x B, let A be
   (m, k, l), B be (n, k, l), and C be (m, n, l). Backend-specific
   differences are hidden inside the PVC copy traits (copy_traits_xe.hpp),
   so upper-level users can write code for Intel Xe GPUs following the
   usual CUTLASS conventions instead of hardcoding for Intel Xe.

2. Refine the "get_pvc_tensor" API. Before this PR, K-slicing and the
   coordinate tensor were mixed together, which made the interface
   parameters unclear and hard to understand. K-slicing is for MMA use,
   while the coordinate tensor is only for copy; they are two separate
   concerns and must stay functionally independent, so a helper function
   "append_pvc_tensor" is provided.
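   To illustrate the split described above, the fragment below mirrors the
   usage pattern that appears later in this patch (e.g. in
   include/cutlass/gemm/collective/xe_mma.hpp): get_pvc_tensor now builds
   only the coordinate tensor for the copy, and append_pvc_tensor adds the
   K-slicing mode needed by the MMA main loop. This is a simplified sketch
   rather than code copied from the patch; identifiers such as
   gmem_tiled_copy_a, tCrA_copy_view, tCrA, m_coord, l_coord, k_tile_count
   and BLK_K stand in for whatever the surrounding collective defines.

       // Coordinate tensor for the copy: start at (m_coord, 0, l_coord)
       // with the per-thread copy shape. No K-slicing is encoded here.
       Tensor iter_2d_a = params.gmem_tiled_copy_a.get_pvc_tensor(
           make_coord(m_coord, 0, l_coord), tCrA_copy_view.shape());

       // K-slicing is appended separately where the main loop needs it:
       // mode 1 advances by BLK_K elements per k-tile, k_tile_count times.
       Tensor iter_a = append_pvc_tensor<1>(iter_2d_a, k_tile_count, BLK_K);

       CUTLASS_PRAGMA_UNROLL
       for (int k_tile = 0; k_tile < k_tile_count; ++k_tile) {
         // Each iteration copies one k-slice of A into registers.
         copy(params.gmem_tiled_copy_a, iter_a(_, _, _, k_tile),
              tCrA_copy_view);
         // ... matching B copy and cute::gemm(tiled_mma, ...) follow, as in
         // the collective mainloop changed by this patch.
       }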
* misc refine * Update copy_traits_xe.hpp * use make_coord * rename variable to make semantics clear * enable tf32 gemm and some refactoring * refine code * fix some comments * fix comments * fix comments * fix build error * refine gemm, add retile_MMA API for xe * fix build error * fix flash atten build issue * update * fix flash attention validation issue * update * update * update benchmark configurations * fix xe visit bug * fix prefetch issue * fix validation issue of uint test * Update test/unit/cute/intel_xe/utils.hpp Co-authored-by: Joe Todd --------- Co-authored-by: Alejandro Acosta Co-authored-by: Joe Todd --- benchmarks/pvc/benchmarks.hpp | 36 ++ benchmarks/pvc/input.in | 3 + .../pvc/flash_attention_v2/pvc_flash_attn.cpp | 2 +- .../pvc_flash_attn_epilogue.hpp | 23 +- .../pvc_flash_attn_gemm_universal.hpp | 32 +- .../flash_attention_v2/pvc_flash_attn_mma.hpp | 102 ++-- include/cute/arch/mma_xe.hpp | 16 +- include/cute/arch/xe_copy_1B.hpp | 34 ++ include/cute/arch/xe_copy_2B.hpp | 47 ++ include/cute/arch/xe_copy_4B.hpp | 68 +++ include/cute/arch/xe_copy_8B.hpp | 6 + include/cute/atom/copy_traits_xe.hpp | 504 ++++++++++++------ .../epilogue/collective/xe_epilogue.hpp | 33 +- .../cutlass/epilogue/fusion/xe_visitor.hpp | 18 +- include/cutlass/gemm/collective/xe_mma.hpp | 154 +++--- include/cutlass/gemm/kernel/xe_gemm.hpp | 31 +- test/unit/cute/intel_xe/CMakeLists.txt | 5 +- test/unit/cute/intel_xe/copy_1d.cpp | 17 +- test/unit/cute/intel_xe/copy_block.cpp | 153 +++--- test/unit/cute/intel_xe/copy_scatter.cpp | 43 +- .../cute/intel_xe/copy_subgroup_block.cpp | 47 +- test/unit/cute/intel_xe/gemm_col_col.cpp | 237 -------- test/unit/cute/intel_xe/gemm_col_row.cpp | 236 -------- test/unit/cute/intel_xe/gemm_common.hpp | 203 +++++++ test/unit/cute/intel_xe/gemm_data_type.cpp | 85 +++ test/unit/cute/intel_xe/gemm_layout.cpp | 69 +++ .../intel_xe/gemm_partition_fragment_abc.cpp | 155 +----- .../cute/intel_xe/gemm_partition_src_dst.cpp | 129 +++-- test/unit/cute/intel_xe/gemm_row_col.cpp | 238 --------- .../cute/intel_xe/gemm_tiled_copy_abc.cpp | 109 ++-- test/unit/cute/intel_xe/mma.cpp | 133 ++--- .../intel_xe/{gemm_utils.hpp => utils.hpp} | 41 +- 32 files changed, 1469 insertions(+), 1540 deletions(-) delete mode 100644 test/unit/cute/intel_xe/gemm_col_col.cpp delete mode 100644 test/unit/cute/intel_xe/gemm_col_row.cpp create mode 100755 test/unit/cute/intel_xe/gemm_common.hpp create mode 100755 test/unit/cute/intel_xe/gemm_data_type.cpp create mode 100755 test/unit/cute/intel_xe/gemm_layout.cpp delete mode 100755 test/unit/cute/intel_xe/gemm_row_col.cpp rename test/unit/cute/intel_xe/{gemm_utils.hpp => utils.hpp} (78%) mode change 100644 => 100755 diff --git a/benchmarks/pvc/benchmarks.hpp b/benchmarks/pvc/benchmarks.hpp index bd06c2566..d9c2c4b96 100644 --- a/benchmarks/pvc/benchmarks.hpp +++ b/benchmarks/pvc/benchmarks.hpp @@ -87,11 +87,44 @@ using PvcGemmBF16BF16FP32_RRR_5 = cutlass::gemm::device::GemmConfiguration< XE_2D_U16x8x32_LD_N, XE_2D_U16x32x32_LD_V, Scheduler::Gemm>; +using PvcGemmBF16BF16FP32_RCR_6 = cutlass::gemm::device::GemmConfiguration< + cutlass::arch::IntelPVC, + cutlass::bfloat16_t, cutlass::layout::RowMajor, + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, + float, cutlass::layout::RowMajor, + float, Shape<_256, _256, _32>, + TiledMMA>>, + XE_2D_U16x8x32_LD_N, XE_2D_U16x16x16_LD_T, + Scheduler::Gemm>; + +using PvcGemmBF16BF16FP32_CRR_7 = cutlass::gemm::device::GemmConfiguration< + cutlass::arch::IntelPVC, + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, + 
cutlass::bfloat16_t, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, + float, Shape<_256, _256, _32>, + TiledMMA>>, + XE_2D_U16x16x16_LD_T, XE_2D_U16x32x32_LD_V, + Scheduler::Gemm>; + +using PvcGemmBF16BF16FP32_CCR_8 = cutlass::gemm::device::GemmConfiguration< + cutlass::arch::IntelPVC, + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, + float, cutlass::layout::RowMajor, + float, Shape<_256, _256, _32>, + TiledMMA>>, + XE_2D_U16x16x16_LD_T, XE_2D_U16x16x16_LD_T, + Scheduler::Gemm>; + CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_1); CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_2); CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_3); CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_4); CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_5); +CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RCR_6); +CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_CRR_7); +CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_CCR_8); using PvcGemmBF16BF16FP32_StreamK_RRR_1 = cutlass::gemm::device::GemmConfiguration< cutlass::arch::IntelPVC, @@ -123,6 +156,9 @@ static void register_benchmarks() { CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_3); CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_4); CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_5); + CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RCR_6); + CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_CRR_7); + CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_CCR_8); CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_StreamK_RRR_1); CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_SplitK_RRR_1); } diff --git a/benchmarks/pvc/input.in b/benchmarks/pvc/input.in index 2d1074e97..63e71eeeb 100644 --- a/benchmarks/pvc/input.in +++ b/benchmarks/pvc/input.in @@ -21,6 +21,9 @@ PvcGemmBF16BF16FP32_RRR_5 --bm_name=bf16_bf16_fp32 --l=4096 --m=8 --k=16384 --n= PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=4 --m=32768 --k=128 --n=4096 PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=4 --m=32768 --k=4096 --n=128 PvcGemmBF16BF16FP32_RRR_3 --bm_name=bf16_bf16_fp32 --l=32 --m=4096 --k=4096 --n=128 +PvcGemmBF16BF16FP32_RCR_6 --bm_name=bf16_bf16_fp32 --l=1 --m=4096 --k=4096 --n=4096 +PvcGemmBF16BF16FP32_CRR_7 --bm_name=bf16_bf16_fp32 --l=1 --m=4096 --k=4096 --n=4096 +PvcGemmBF16BF16FP32_CCR_8 --bm_name=bf16_bf16_fp32 --l=1 --m=4096 --k=4096 --n=4096 PvcGemmBF16BF16FP32_StreamK_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=8192 --n=8192 PvcGemmBF16BF16FP32_StreamK_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=8192 --n=32768 diff --git a/examples/sycl/pvc/flash_attention_v2/pvc_flash_attn.cpp b/examples/sycl/pvc/flash_attention_v2/pvc_flash_attn.cpp index fffb6ee32..c9c74406c 100644 --- a/examples/sycl/pvc/flash_attention_v2/pvc_flash_attn.cpp +++ b/examples/sycl/pvc/flash_attention_v2/pvc_flash_attn.cpp @@ -299,7 +299,7 @@ struct ExampleRunner { stride_Q = cutlass::make_cute_packed_stride(StrideQ{}, cute::make_shape(seq_len, head_size, batch * num_heads)); stride_K = cutlass::make_cute_packed_stride(StrideK{}, cute::make_shape(seq_len, head_size, batch * num_heads)); - stride_V = cutlass::make_cute_packed_stride(StrideV{}, cute::make_shape(seq_len, head_size, batch * num_heads)); + stride_V = cutlass::make_cute_packed_stride(StrideV{}, cute::make_shape(head_size, seq_len, batch * num_heads)); stride_O = cutlass::make_cute_packed_stride(StrideO{}, cute::make_shape(seq_len, head_size, batch * num_heads)); stride_LSE = cutlass::make_cute_packed_stride(StrideLSE{}, cute::make_shape(seq_len, 1, batch * num_heads)); diff --git 
a/examples/sycl/pvc/flash_attention_v2/pvc_flash_attn_epilogue.hpp b/examples/sycl/pvc/flash_attention_v2/pvc_flash_attn_epilogue.hpp index f5233729d..65b992e3e 100644 --- a/examples/sycl/pvc/flash_attention_v2/pvc_flash_attn_epilogue.hpp +++ b/examples/sycl/pvc/flash_attention_v2/pvc_flash_attn_epilogue.hpp @@ -104,11 +104,9 @@ class CollectiveEpilogueAttention< static_assert(cute::rank(StrideLSE{}) == 3, "StrideLSE must be rank-3: [batch, num_heads, seq_len]"); using Trait_O = Copy_Traits; - using XE_Copy_O = decltype(make_tiled_copy(Copy_Atom{} - .with(static_cast(nullptr), int32_t(0), int32_t(0), int32_t(0)), - Layout>>{}, - make_layout(make_shape(get<0>(typename Trait_O::Shape_MN{}), - get<1>(typename Trait_O::Shape_MN{}) / Int{})))); + using XE_Copy_O = decltype(make_xe_2d_copy(Copy_Atom, ElementO>{}.with( + make_tensor(make_gmem_ptr(static_cast(nullptr)), make_layout(make_shape(0, 0, 0), StrideO{}))), + Layout>>{})); private: constexpr static bool is_destination_supported = not cute::is_void_v; @@ -157,11 +155,9 @@ class CollectiveEpilogueAttention< auto [batch, num_heads, seq_len, head_size] = problem_shape; XE_Copy_O xe_store_o = {}; - xe_store_o = make_tiled_copy(Copy_Atom, ElementO>{}.with( - args.ptr_O, head_size, seq_len, head_size), - Layout>>{}, - make_layout(make_shape(get<0>(typename Trait_O::Shape_MN{}), - get<1>(typename Trait_O::Shape_MN{}) / Int{}))); + xe_store_o = make_xe_2d_copy(Copy_Atom, ElementO>{}.with( + make_tensor(make_gmem_ptr(static_cast(args.ptr_O)), make_layout(make_shape(seq_len, head_size, batch * num_heads), args.dO))), + Layout>>{}); return { FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace), @@ -263,11 +259,10 @@ class CollectiveEpilogueAttention< auto [batch, num_heads, seq_len, head_size] = problem_shape; Tensor tOi = params.xe_store_o.get_pvc_tensor( - make_coord(m_offset, n_offset, 0), - make_shape(_, Int{}, Int{}, batch * num_heads), - make_stride(Int(MmaAtomShape{})>{}, Int(MmaAtomShape{})>{}, _1{})); + make_coord(m_offset, n_offset, l_coord), + make_shape(_, Int{}, Int{})); - copy(params.xe_store_o, out, tOi(_,_,_,l_coord)); + copy(params.xe_store_o, out, tOi); const int lse_offset = m_offset + l_coord * seq_len; diff --git a/examples/sycl/pvc/flash_attention_v2/pvc_flash_attn_gemm_universal.hpp b/examples/sycl/pvc/flash_attention_v2/pvc_flash_attn_gemm_universal.hpp index b1c348722..07cb486f5 100644 --- a/examples/sycl/pvc/flash_attention_v2/pvc_flash_attn_gemm_universal.hpp +++ b/examples/sycl/pvc/flash_attention_v2/pvc_flash_attn_gemm_universal.hpp @@ -294,37 +294,45 @@ class GemmUniversalAttention const int item_id = thread_idx % SubgroupSize; const int k_tile_count= head_size / get<1>(subgroup_shape); //m, k - Tensor prefetch_iter_a = params.mainloop.gmem_prefetch_q.get_pvc_tensor( + Tensor prefetch_iter_2d_a = params.mainloop.gmem_prefetch_q.get_pvc_tensor( make_coord(seq_coord + (((sub_group_id % ATOM_N) / get<1>(PrefetchQThrShape{}))* get<0>(PrefetchQTileSize{})), // iteration 0/M/Hight/vertical - (((sub_group_id % ATOM_N) % get<1>(PrefetchQThrShape{})) * get<1>(PrefetchQTileSize{})), // Iteration 1/K/Width/Horisontal - blk_l_coord), - append<4>(make_shape(_1{}, _1{}, _1{}), k_tile_count), - append<3>(make_shape(_, _), BLK_K), seq<0, 0, 1>{}); + ((sub_group_id % ATOM_N) % get<1>(PrefetchQThrShape{})) * get<1>(PrefetchQTileSize{}), // Iteration 1/K/Width/Horisontal + blk_l_coord), + make_shape(_1{}, _1{}, _1{})); + Tensor prefetch_iter_a = append_pvc_tensor<1>(prefetch_iter_2d_a, k_tile_count, BLK_K); + // 
append<4>(make_shape(_1{}, _1{}, _1{}), k_tile_count), + // append<3>(make_shape(_, _), BLK_K), seq<0, 0, 1>{}); // The Key point is 1 is horisontal and zero is vertical // the iteration over K dimention of B matrix (head_size) should be : auto iter_over_head_count = head_size / BLK_N; // k, n - Tensor prefetch_iter_b = params.mainloop.gmem_prefetch_k.get_pvc_tensor( + Tensor prefetch_iter_2d_b = params.mainloop.gmem_prefetch_k.get_pvc_tensor( make_coord(sub_group_id * get<0>(PrefetchKTileSize{}), // iteration 0/K/Hight/vertical (sub_group_id % ATOM_N) * get<1>(PrefetchKTileSize{}), // iteration 1/N/W/Horisontal blk_l_coord), // batch // ?, ?, k, N swap k and n here to match cutlass - append<4>(make_shape(_1{}, _1{}, nblock_limit/*This is N*/), iter_over_head_count/* This is K*/), //(frag, iter_m, iter_n, iter_k) + make_shape(_1{}, _1{}, nblock_limit/*This is N*/)); + // iter_over_head_count/* This is K*/), //(frag, iter_m, iter_n, iter_k) // K, ?, N (The N should move along the N as get<0>(PrefetchKThrShape) load 32 each and we want 128 of N ) // The K should move along the dimmension of Block load as we lay 8x32 using the 8x1 shape for subgroups // leading to load 64x32 of (K,N) per each prefetch (BLOCK_N SHows K DIM) - append<3>(make_shape(_, SG_N), BLK_N), seq<0, 1, 0>{}); // so 64 * iteration 0 (SG_N that is K which is vertical) and 32 * iteration 1 (N which is horisontal) + // append<3>(make_shape(_, SG_N), BLK_N), seq<0, 1, 0>{}); // so 64 * iteration 0 (SG_N that is K which is vertical) and 32 * iteration 1 (N which is horisontal) // V is a transposed matrix, So here the Sequense length is consumed, it is transposed so the consumed dimension looks like B matrix // Hence, the Head size is the fast moving dimention and horisontal and sequence length is vertical. // The prefetch only move along the sequence lenth. 
Here we call sequence length K since it get consumed and head size N since it stay - Tensor prefetch_iter_v = params.mainloop.gmem_prefetch_v.get_pvc_tensor( + + Tensor prefetch_iter_b = append_pvc_tensor<0>(prefetch_iter_2d_b, iter_over_head_count, BLK_N); + + Tensor prefetch_iter_2d_v = params.mainloop.gmem_prefetch_v.get_pvc_tensor( make_coord((sub_group_id / ATOM_N) * get<0>(PrefetchVTileSize{}), // iteration 0/K/Hight/vertical/ sequence lengh - head_size_coord, // iteration 1/N/W/Horisontal / Head size + head_size_coord, // iteration 1/N/W/Horisontal / Head size blk_l_coord), // We loop over the consuming dimension which is the iteration 0(N) here - append<4>(make_shape(_1{}, _1{}, _1{}), nblock_limit), + make_shape(_1{}, _1{}, _1{})); + // , nblock_limit), // first one is to use the intrinsic along the vertical , Second one is N/M and third one is K - append<3>(make_shape(_, _), BLK_K), seq<0, 1, 0>{}); + // append<3>(make_shape(_, _), BLK_K), seq<0, 1, 0>{}); + Tensor prefetch_iter_v = append_pvc_tensor<0>(prefetch_iter_2d_v, nblock_limit, BLK_K); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < k_tile_count; i++) { diff --git a/examples/sycl/pvc/flash_attention_v2/pvc_flash_attn_mma.hpp b/examples/sycl/pvc/flash_attention_v2/pvc_flash_attn_mma.hpp index 8bfa01f01..856784c4c 100644 --- a/examples/sycl/pvc/flash_attention_v2/pvc_flash_attn_mma.hpp +++ b/examples/sycl/pvc/flash_attention_v2/pvc_flash_attn_mma.hpp @@ -142,28 +142,22 @@ struct CollectiveMmaAttention< using PrefetchVTileSize = decltype(ceil_div(Shape, Int>{},PrefetchVThrShape{})); // 8x32 static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{}); - using traits_load_Q = Copy_Traits; + using traits_load_Q = Copy_Traits; using atom_load_Q = Copy_Atom; - using XE_Copy_Q = decltype(make_tiled_copy(atom_load_Q{} - .with(static_cast(nullptr), int32_t(0), int32_t(0), int32_t(0)), - Layout>>{}, - make_layout(make_shape(get<0>(typename traits_load_Q::Shape_MN{}), - get<1>(typename traits_load_Q::Shape_MN{}) / Int{})))); - using traits_load_K = Copy_Traits; + using XE_Copy_Q = decltype(make_xe_2d_copy(atom_load_Q{} + .with(make_tensor(make_gmem_ptr(static_cast(nullptr)), make_layout(make_shape(0, 0, 0), StrideQ{}))), + Layout>>{})); + using traits_load_K = Copy_Traits; using atom_load_K = Copy_Atom; - using XE_Copy_K = decltype(make_tiled_copy(atom_load_K{} - .with(static_cast(nullptr), int32_t(0), int32_t(0), int32_t(0)), - Layout>>{}, - make_layout(make_shape(get<0>(typename traits_load_K::Shape_MN{}), - get<1>(typename traits_load_K::Shape_MN{}) / Int{})))); + using XE_Copy_K = decltype(make_xe_2d_copy(atom_load_K{} + .with(make_tensor(make_gmem_ptr(static_cast(nullptr)), make_layout(make_shape(0, 0, 0), StrideK{}))), + Layout>>{})); - using traits_load_V = Copy_Traits; + using traits_load_V = Copy_Traits; using atom_load_V = Copy_Atom; - using XE_Copy_V = decltype(make_tiled_copy(atom_load_V{} - .with(static_cast(nullptr), int32_t(0), int32_t(0), int32_t(0)), - Layout>>{}, - make_layout(make_shape(get<0>(typename traits_load_K::Shape_MN{}), - get<1>(typename traits_load_K::Shape_MN{}) / Int{})))); + using XE_Copy_V = decltype(make_xe_2d_copy(atom_load_V{} + .with(make_tensor(make_gmem_ptr(static_cast(nullptr)), make_layout(make_shape(0, 0, 0), StrideV{}))), + Layout>>{})); using XE_Prefetch_Q = decltype(cute::detail::prefetch_selector()); using XE_Prefetch_K = decltype(cute::detail::prefetch_selector()); @@ -202,19 +196,15 @@ struct CollectiveMmaAttention< auto [batch, num_heads, seq_len, head_size] = problem_shape; - XE_Copy_Q 
copyQ = make_tiled_copy(Copy_Atom, ElementQ>{}.with(args.ptr_Q, head_size, seq_len, head_size), - Layout>>{}, - make_layout(make_shape(get<0>(typename traits_load_Q::Shape_MN{}), - get<1>(typename traits_load_Q::Shape_MN{}) / Int{}))); - XE_Copy_K copyK = make_tiled_copy(Copy_Atom, ElementK>{}.with(args.ptr_K, seq_len, head_size, seq_len), - Layout>>{}, - make_layout(make_shape(get<0>(typename traits_load_K::Shape_MN{}), - get<1>(typename traits_load_K::Shape_MN{}) / Int{}))); - - XE_Copy_V copyV = make_tiled_copy(Copy_Atom, ElementV>{}.with(args.ptr_V, head_size, seq_len, head_size), - Layout>>{}, - make_layout(make_shape(get<0>(typename traits_load_V::Shape_MN{}), - get<1>(typename traits_load_V::Shape_MN{}) / Int{}))); + XE_Copy_Q copyQ = make_xe_2d_copy(Copy_Atom, ElementQ>{}.with( + make_tensor(make_gmem_ptr(static_cast(args.ptr_Q)), make_layout(make_shape(seq_len, head_size, batch * num_heads), args.dQ))), + Layout>>{}); + XE_Copy_K copyK = make_xe_2d_copy(Copy_Atom, ElementK>{}.with( + make_tensor(make_gmem_ptr(static_cast(args.ptr_K)), make_layout(make_shape(seq_len, head_size, batch * num_heads), args.dK))), + Layout>>{}); + XE_Copy_V copyV = make_xe_2d_copy(Copy_Atom, ElementV>{}.with( + make_tensor(make_gmem_ptr(static_cast(args.ptr_V)), make_layout(make_shape(head_size, seq_len, batch * num_heads), args.dV))), + Layout>>{}); XE_Prefetch_Q prefetchQ = cute::detail::prefetch_selector((void *)args.ptr_Q, head_size, seq_len, head_size); XE_Prefetch_K prefetchK = cute::detail::prefetch_selector((void *)args.ptr_K, seq_len, head_size, seq_len); @@ -244,19 +234,16 @@ struct CollectiveMmaAttention< TiledMma tiled_mma; auto thread_mma = tiled_mma.get_slice(thread_idx); Tensor tCrA_partition = thread_mma.partition_fragment_A(gA(_, _, 0)); - Tensor tCrA = make_tensor(static_cast(tCrA_partition).data(), - tCrA_partition.shape()); Tensor tCrB_partition = thread_mma.partition_fragment_B(gB(_, _, 0)); - Tensor tCrB = make_tensor(static_cast(tCrB_partition).data(), - make_shape(size<0>(tCrB_partition.shape()), - size<2>(tCrB_partition.shape()), - size<1>(tCrB_partition.shape()))); // Partition the copying of A and B tiles across the threads auto gmem_thr_copy_A = params.gmem_tiled_copy_q.get_slice(thread_idx); auto gmem_thr_copy_B = params.gmem_tiled_copy_k.get_slice(thread_idx); - auto tCrA_copy_view = gmem_thr_copy_A.retile_D(tCrA); - auto tCrB_copy_view = gmem_thr_copy_B.retile_D(tCrB); + auto tCrA_copy_view = gmem_thr_copy_A.retile_D(tCrA_partition); + auto tCrB_copy_view = gmem_thr_copy_B.retile_D(tCrB_partition); + + Tensor tCrA = gmem_thr_copy_A.retile_MMA(thread_mma, tCrA_partition); + Tensor tCrB = gmem_thr_copy_B.retile_MMA(thread_mma, tCrB_partition); #if CUTLASS_ENABLE_DEBUG_PRINTS if (thread(LOG_THREAD, LOG_GROUP)) { @@ -286,21 +273,19 @@ struct CollectiveMmaAttention< // int sub_group_id = get_sub_group_id(); auto [m_coord, n_coord, k_coord, l_coord] = tile_coord; - Tensor iter_a = params.gmem_tiled_copy_q.get_pvc_tensor( - make_coord(m_coord, 0, l_coord), append<4>(tCrA_copy_view.shape(), k_tile_count), - append<3>(typename XE_Copy_Q::Shape_MN{}, BLK_K), seq<0,1,1>{}); - Tensor iter_b = params.gmem_tiled_copy_k.get_pvc_tensor( - make_coord(0, n_coord, l_coord), append<4>(tCrB_copy_view.shape(), k_tile_count), - append<3>(typename XE_Copy_K::Shape_MN{}, BLK_K), seq<0,1,0>{}); + Tensor iter_2d_a = params.gmem_tiled_copy_q.get_pvc_tensor( + make_coord(m_coord, 0, l_coord), tCrA_copy_view.shape()); + Tensor iter_a = append_pvc_tensor<1>(iter_2d_a, k_tile_count, BLK_K); + Tensor 
iter_2d_b = params.gmem_tiled_copy_k.get_pvc_tensor( + make_coord(n_coord, 0, l_coord), tCrB_copy_view.shape()); + Tensor iter_b = append_pvc_tensor<1>(iter_2d_b, k_tile_count, BLK_K); CUTLASS_PRAGMA_UNROLL for (int k_tile = 0; k_tile < k_tile_count; ++k_tile) { // Copy gmem to rmem for the first k_tile copy(params.gmem_tiled_copy_q, iter_a(_,_,_,k_tile), tCrA_copy_view); copy(params.gmem_tiled_copy_k, iter_b(_,_,_,k_tile), tCrB_copy_view); - for (int i = 0; i < SG_K / SubgroupSize; i++) { - cute::gemm(tiled_mma, accum, tCrA(_, _, i), tCrB(_, i, _), frag_src); - } + cute::gemm(tiled_mma, accum, tCrA, tCrB, frag_src); } } @@ -328,14 +313,12 @@ struct CollectiveMmaAttention< auto thread_mma = tiled_mma.get_slice(thread_idx); Tensor tCrB_partition = thread_mma.partition_fragment_B(gB(_, _, 0)); - Tensor tCrB = make_tensor(static_cast(tCrB_partition).data(), - make_shape(size<0>(tCrB_partition.shape()), - size<2>(tCrB_partition.shape()), - size<1>(tCrB_partition.shape()))); // Partition the copying of A and B tiles across the threads - auto gmem_thr_copy_B = params.gmem_tiled_copy_k.get_slice(thread_idx); + auto gmem_thr_copy_B = params.gmem_tiled_copy_v.get_slice(thread_idx); - auto tCrB_copy_view = gmem_thr_copy_B.retile_D(tCrB); + auto tCrB_copy_view = gmem_thr_copy_B.retile_D(tCrB_partition); + + Tensor tCrB = gmem_thr_copy_B.retile_MMA(thread_mma, tCrB_partition); #if CUTLASS_ENABLE_DEBUG_PRINTS if (thread(LOG_THREAD, LOG_GROUP)) { @@ -359,16 +342,13 @@ struct CollectiveMmaAttention< int sub_group_id = get_sub_group_id(); auto [m_coord, n_coord, k_coord, l_coord] = tile_coord; - Tensor iter_b = params.gmem_tiled_copy_v.get_pvc_tensor( - make_coord(0, n_coord, l_coord), append<4>(tCrB_copy_view.shape(), k_tile_count), - append<3>(typename XE_Copy_K::Shape_MN{}, BLK_K), seq<0,1,0>{}); + Tensor iter_2d_b = params.gmem_tiled_copy_v.get_pvc_tensor( + make_coord(n_coord, 0, l_coord), tCrB_copy_view.shape()); + Tensor iter_b = append_pvc_tensor<1>(iter_2d_b, k_tile_count, BLK_K); copy(params.gmem_tiled_copy_v, iter_b(_,_,_, load_idx), tCrB_copy_view); - for (int i = 0; i < SG_K / SubgroupSize; i++) { - cute::gemm(tiled_mma, accum, tPr(_, _, i), tCrB(_, i, _), frag_src); - } - + cute::gemm(tiled_mma, accum, tPr, tCrB, frag_src); } }; diff --git a/include/cute/arch/mma_xe.hpp b/include/cute/arch/mma_xe.hpp index 96e3228df..c510cb26e 100644 --- a/include/cute/arch/mma_xe.hpp +++ b/include/cute/arch/mma_xe.hpp @@ -61,10 +61,10 @@ SYCL_DEVICE_OCL(cute::intel::int4 intel_sub_group_u8_u8_matrix_mad_k32(cute::int SYCL_DEVICE_OCL(cute::intel::int2 intel_sub_group_u8_u8_matrix_mad_k32(cute::intel::ushort2 a, cute::intel::uint8 b, cute::intel::int2 acc)); SYCL_DEVICE_OCL(int intel_sub_group_u8_u8_matrix_mad_k32(ushort a, cute::intel::uint8 b, int acc)); // mma_tf32 -SYCL_DEVICE_OCL(cute::intel::float8 intel_sub_group_tf32_tf32_matrix_mad_k8_f32(cute::intel::float4 a, cute::intel::float8 b, cute::intel::float8 acc)); -SYCL_DEVICE_OCL(cute::intel::float4 intel_sub_group_tf32_tf32_matrix_mad_k8_f32(cute::intel::float2 a, cute::intel::float8 b, cute::intel::float4 acc)); -SYCL_DEVICE_OCL(cute::intel::float2 intel_sub_group_tf32_tf32_matrix_mad_k8_f32(float a, cute::intel::float8 b, cute::intel::float2 acc)); -SYCL_DEVICE_OCL(float intel_sub_group_tf32_tf32_matrix_mad_k8_f32(float a, cute::intel::float8 b, float acc)); +SYCL_DEVICE_OCL(cute::intel::float8 intel_sub_group_tf32_tf32_matrix_mad_k8(cute::intel::float4 a, cute::intel::float8 b, cute::intel::float8 acc)); +SYCL_DEVICE_OCL(cute::intel::float4 
intel_sub_group_tf32_tf32_matrix_mad_k8(cute::intel::float2 a, cute::intel::float8 b, cute::intel::float4 acc)); +SYCL_DEVICE_OCL(cute::intel::float2 intel_sub_group_tf32_tf32_matrix_mad_k8(float a, cute::intel::float8 b, cute::intel::float2 acc)); +SYCL_DEVICE_OCL(float intel_sub_group_tf32_tf32_matrix_mad_k8(float a, cute::intel::float8 b, float acc)); #undef SYCL_DEVICE_OCL @@ -430,7 +430,7 @@ struct XE_8x16x8_F32TF32TF32F32_TT intel::float8 const& c) { #if defined(SYCL_INTEL_TARGET) - d = intel_sub_group_tf32_tf32_matrix_mad_k8_f32(a, b, c); + d = intel_sub_group_tf32_tf32_matrix_mad_k8(a, b, c); #else CUTE_INVALID_CONTROL_PATH("Attempting to use XE_8x16x8_F32TF32TF32F32_TT on non-PVC hardware"); #endif @@ -451,7 +451,7 @@ struct XE_4x16x8_F32TF32TF32F32_TT intel::float4 const& c) { #if defined(SYCL_INTEL_TARGET) - d = intel_sub_group_tf32_tf32_matrix_mad_k8_f32(a, b, c); + d = intel_sub_group_tf32_tf32_matrix_mad_k8(a, b, c); #else CUTE_INVALID_CONTROL_PATH("Attempting to use XE_4x16x8_F32TF32TF32F32_TT on non-PVC hardware"); #endif @@ -472,7 +472,7 @@ struct XE_2x16x8_F32TF32TF32F32_TT intel::float2 const& c) { #if defined(SYCL_INTEL_TARGET) - d = intel_sub_group_tf32_tf32_matrix_mad_k8_f32(a, b, c); + d = intel_sub_group_tf32_tf32_matrix_mad_k8(a, b, c); #else CUTE_INVALID_CONTROL_PATH("Attempting to use XE_2x16x8_F32TF32TF32F32_TT on non-PVC hardware"); #endif @@ -493,7 +493,7 @@ struct XE_1x16x8_F32TF32TF32F32_TT float const& c) { #if defined(SYCL_INTEL_TARGET) - d = intel_sub_group_tf32_tf32_matrix_mad_k8_f32(a, b, c); + d = intel_sub_group_tf32_tf32_matrix_mad_k8(a, b, c); #else CUTE_INVALID_CONTROL_PATH("Attempting to use XE_1x16x8_F32TF32TF32F32_TT on non-PVC hardware"); #endif diff --git a/include/cute/arch/xe_copy_1B.hpp b/include/cute/arch/xe_copy_1B.hpp index 8323abb9a..ce8193d35 100644 --- a/include/cute/arch/xe_copy_1B.hpp +++ b/include/cute/arch/xe_copy_1B.hpp @@ -241,6 +241,8 @@ struct XE_2D_U8x1x32_LD_N { }; struct XE_2D_U8x2x32_LD_N { + using BlockShape = Shape<_2, _32>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -257,6 +259,8 @@ struct XE_2D_U8x2x32_LD_N { }; struct XE_2D_U8x2x32_ST_N { + using BlockShape = Shape<_2, _32>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -273,6 +277,8 @@ struct XE_2D_U8x2x32_ST_N { }; struct XE_2D_U8x4x32_LD_N { + using BlockShape = Shape<_4, _32>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -289,6 +295,8 @@ struct XE_2D_U8x4x32_LD_N { }; struct XE_2D_U8x8x32_LD_N { + using BlockShape = Shape<_8, _32>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -305,6 +313,8 @@ struct XE_2D_U8x8x32_LD_N { }; struct XE_2D_U8x16x32_LD_N { + using BlockShape = Shape<_16, _32>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -336,6 +346,8 @@ struct XE_2D_U8x16x32_LD_N { }; struct XE_2D_U8x32x32_LD_N { + using BlockShape = Shape<_32, _32>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -382,6 +394,8 @@ struct XE_2D_U8x1x64_LD_N { }; struct XE_2D_U8x2x64_LD_N { + using BlockShape = Shape<_2, _64>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int 
width, int height, int pitch, intel::coord_t coord, @@ -412,6 +426,8 @@ struct XE_2D_U8x2x64_LD_N { }; struct XE_2D_U8x4x64_LD_N { + using BlockShape = Shape<_4, _64>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -442,6 +458,8 @@ struct XE_2D_U8x4x64_LD_N { }; struct XE_2D_U8x8x64_LD_N { + using BlockShape = Shape<_8, _64>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -472,6 +490,8 @@ struct XE_2D_U8x8x64_LD_N { }; struct XE_2D_U8x16x64_LD_N { + using BlockShape = Shape<_16, _64>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -503,6 +523,8 @@ struct XE_2D_U8x16x64_LD_N { }; struct XE_2D_U8x32x64_LD_N { + using BlockShape = Shape<_32, _64>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -536,6 +558,8 @@ struct XE_2D_U8x32x64_LD_N { struct XE_2D_U8x32x16_LD_V { + using BlockShape = Shape<_32, _16>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -566,6 +590,8 @@ struct XE_2D_U8x32x16_LD_V { }; struct XE_2D_U8x32x32_LD_V { + using BlockShape = Shape<_32, _32>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -582,6 +608,8 @@ struct XE_2D_U8x32x32_LD_V { }; struct XE_2D_U8x32x64_LD_V { + using BlockShape = Shape<_32, _64>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -598,6 +626,8 @@ struct XE_2D_U8x32x64_LD_V { }; struct XE_2D_U8x1x16_ST_N { + using BlockShape = Shape<_1, _16>; + template CUTE_HOST_DEVICE static void copy(void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -614,6 +644,8 @@ struct XE_2D_U8x1x16_ST_N { }; struct XE_2D_U8x2x16_ST_N { + using BlockShape = Shape<_2, _16>; + template CUTE_HOST_DEVICE static void copy(void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -630,6 +662,8 @@ struct XE_2D_U8x2x16_ST_N { }; struct XE_2D_U8x4x16_ST_N { + using BlockShape = Shape<_4, _16>; + template CUTE_HOST_DEVICE static void copy(void *baseoffset, int width, int height, int pitch, intel::coord_t coord, diff --git a/include/cute/arch/xe_copy_2B.hpp b/include/cute/arch/xe_copy_2B.hpp index ac8d7d3ed..d3560f1fe 100644 --- a/include/cute/arch/xe_copy_2B.hpp +++ b/include/cute/arch/xe_copy_2B.hpp @@ -257,6 +257,8 @@ SYCL_DEVICE_OCL(void intel_sub_group_2d_block_prefetch_16b_16r16x1c( namespace cute { struct XE_2D_U16x1x16_LD_N { + using BlockShape = Shape<_1, _16>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -273,6 +275,8 @@ struct XE_2D_U16x1x16_LD_N { }; struct XE_2D_U16x2x16_LD_N { + using BlockShape = Shape<_2, _16>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -289,6 +293,8 @@ struct XE_2D_U16x2x16_LD_N { }; struct XE_2D_U16x4x16_LD_N { + using BlockShape = Shape<_4, _16>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -305,6 +311,8 @@ struct XE_2D_U16x4x16_LD_N { }; struct XE_2D_U16x8x16_LD_N { + using BlockShape = Shape<_8, _16>; + 
template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -336,6 +344,8 @@ struct XE_2D_U16x8x16_LD_N { }; struct XE_2D_U16x16x16_LD_N { + using BlockShape = Shape<_16, _16>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -367,6 +377,8 @@ struct XE_2D_U16x16x16_LD_N { }; struct XE_2D_U16x32x16_LD_N { + using BlockShape = Shape<_32, _16>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -398,6 +410,8 @@ struct XE_2D_U16x32x16_LD_N { }; struct XE_2D_U16x1x32_LD_N { + using BlockShape = Shape<_1, _32>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -428,6 +442,8 @@ struct XE_2D_U16x1x32_LD_N { }; struct XE_2D_U16x2x32_LD_N { + using BlockShape = Shape<_2, _32>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -458,6 +474,8 @@ struct XE_2D_U16x2x32_LD_N { }; struct XE_2D_U16x4x32_LD_N { + using BlockShape = Shape<_4, _32>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -488,6 +506,8 @@ struct XE_2D_U16x4x32_LD_N { }; struct XE_2D_U16x8x32_LD_N { + using BlockShape = Shape<_8, _32>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -519,6 +539,8 @@ struct XE_2D_U16x8x32_LD_N { }; struct XE_2D_U16x16x32_LD_N { + using BlockShape = Shape<_16, _32>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -550,6 +572,8 @@ struct XE_2D_U16x16x32_LD_N { }; struct XE_2D_U16x32x32_LD_N { + using BlockShape = Shape<_32, _32>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -582,6 +606,8 @@ struct XE_2D_U16x32x32_LD_N { }; struct XE_2D_U16x16x16_LD_V { + using BlockShape = Shape<_16, _16>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -613,6 +639,8 @@ struct XE_2D_U16x16x16_LD_V { }; struct XE_2D_U16x32x16_LD_V { + using BlockShape = Shape<_32, _16>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -644,6 +672,8 @@ struct XE_2D_U16x32x16_LD_V { }; struct XE_2D_U16x16x32_LD_V { + using BlockShape = Shape<_16, _32>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -675,6 +705,8 @@ struct XE_2D_U16x16x32_LD_V { }; struct XE_2D_U16x32x32_LD_V { + using BlockShape = Shape<_32, _32>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -706,8 +738,11 @@ struct XE_2D_U16x32x32_LD_V { }; struct XE_2D_U16x16x8_LD_T { + using BlockShape = Shape<_8, _16>; using inst_dtype = uint32_t; + static constexpr bool is_transpose = true; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -724,7 +759,11 @@ struct XE_2D_U16x16x8_LD_T { }; struct XE_2D_U16x16x16_LD_T { + using BlockShape = Shape<_16, _16>; using inst_dtype = uint32_t; + + static constexpr 
bool is_transpose = true; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -741,6 +780,8 @@ struct XE_2D_U16x16x16_LD_T { }; struct XE_2D_U16x1x16_ST_N { + using BlockShape = Shape<_1, _16>; + template CUTE_HOST_DEVICE static void copy(void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -757,6 +798,8 @@ struct XE_2D_U16x1x16_ST_N { }; struct XE_2D_U16x2x16_ST_N { + using BlockShape = Shape<_2, _16>; + template CUTE_HOST_DEVICE static void copy(void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -773,6 +816,8 @@ struct XE_2D_U16x2x16_ST_N { }; struct XE_2D_U16x4x16_ST_N { + using BlockShape = Shape<_4, _16>; + template CUTE_HOST_DEVICE static void copy(void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -789,6 +834,8 @@ struct XE_2D_U16x4x16_ST_N { }; struct XE_2D_U16x8x16_ST_N { + using BlockShape = Shape<_8, _16>; + template CUTE_HOST_DEVICE static void copy(void *baseoffset, int width, int height, int pitch, intel::coord_t coord, diff --git a/include/cute/arch/xe_copy_4B.hpp b/include/cute/arch/xe_copy_4B.hpp index 78cd1471c..2198b9dcc 100644 --- a/include/cute/arch/xe_copy_4B.hpp +++ b/include/cute/arch/xe_copy_4B.hpp @@ -272,6 +272,8 @@ SYCL_DEVICE_OCL(void intel_sub_group_2d_block_prefetch_32b_16r8x1c( namespace cute { struct XE_2D_U32x1x16_LD_N { + using BlockShape = Shape<_1, _16>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -288,6 +290,8 @@ struct XE_2D_U32x1x16_LD_N { }; struct XE_2D_U32x2x16_LD_N { + using BlockShape = Shape<_2, _16>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -304,6 +308,8 @@ struct XE_2D_U32x2x16_LD_N { }; struct XE_2D_U32x4x16_LD_N { + using BlockShape = Shape<_4, _16>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -320,6 +326,8 @@ struct XE_2D_U32x4x16_LD_N { }; struct XE_2D_U32x8x16_LD_N { + using BlockShape = Shape<_8, _16>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -336,6 +344,8 @@ struct XE_2D_U32x8x16_LD_N { }; struct XE_2D_U32x16x16_LD_N { + using BlockShape = Shape<_16, _16>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -352,6 +362,8 @@ struct XE_2D_U32x16x16_LD_N { }; struct XE_2D_U32x32x16_LD_N { + using BlockShape = Shape<_32, _16>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -368,6 +380,8 @@ struct XE_2D_U32x32x16_LD_N { }; struct XE_2D_TF32x1x8_LD_N { + using BlockShape = Shape<_32, _16>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -384,6 +398,9 @@ struct XE_2D_TF32x1x8_LD_N { }; struct XE_2D_TF32x2x8_LD_N { + using BlockShape = Shape<_2, _8>; + using ValueShape = Shape<_1, _16>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -400,6 +417,9 @@ struct XE_2D_TF32x2x8_LD_N { }; struct XE_2D_TF32x4x8_LD_N { + using BlockShape = Shape<_4, _8>; + using ValueShape = Shape<_2, _16>; + template CUTE_HOST_DEVICE static void copy(const void 
*baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -416,6 +436,9 @@ struct XE_2D_TF32x4x8_LD_N { }; struct XE_2D_TF32x8x8_LD_N { + using BlockShape = Shape<_8, _8>; + using ValueShape = Shape<_4, _16>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -432,6 +455,9 @@ struct XE_2D_TF32x8x8_LD_N { }; struct XE_2D_TF32x16x8_LD_N { + using BlockShape = Shape<_16, _8>; + using ValueShape = Shape<_8, _16>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -448,6 +474,9 @@ struct XE_2D_TF32x16x8_LD_N { }; struct XE_2D_TF32x32x8_LD_N { + using BlockShape = Shape<_32, _8>; + using ValueShape = Shape<_16, _16>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -464,6 +493,8 @@ struct XE_2D_TF32x32x8_LD_N { }; struct XE_2D_TF32x1x16_LD_N { + using BlockShape = Shape<_1, _16>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -480,6 +511,9 @@ struct XE_2D_TF32x1x16_LD_N { }; struct XE_2D_TF32x2x16_LD_N { + using BlockShape = Shape<_2, _16>; + using ValueShape = Shape<_1, _32>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -496,6 +530,9 @@ struct XE_2D_TF32x2x16_LD_N { }; struct XE_2D_TF32x4x16_LD_N { + using BlockShape = Shape<_4, _16>; + using ValueShape = Shape<_2, _32>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -512,6 +549,9 @@ struct XE_2D_TF32x4x16_LD_N { }; struct XE_2D_TF32x8x16_LD_N { + using BlockShape = Shape<_8, _16>; + using ValueShape = Shape<_4, _32>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -528,6 +568,9 @@ struct XE_2D_TF32x8x16_LD_N { }; struct XE_2D_TF32x16x16_LD_N { + using BlockShape = Shape<_16, _16>; + using ValueShape = Shape<_8, _32>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -544,6 +587,9 @@ struct XE_2D_TF32x16x16_LD_N { }; struct XE_2D_TF32x32x16_LD_N { + using BlockShape = Shape<_32, _16>; + using ValueShape = Shape<_16, _32>; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -561,6 +607,8 @@ struct XE_2D_TF32x32x16_LD_N { struct XE_2D_U32x16x1_LD_T { + static constexpr bool is_transpose = true; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -577,6 +625,10 @@ struct XE_2D_U32x16x1_LD_T { }; struct XE_2D_U32x16x2_LD_T { + using BlockShape = Shape<_2, _16>; + + static constexpr bool is_transpose = true; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -593,6 +645,10 @@ struct XE_2D_U32x16x2_LD_T { }; struct XE_2D_U32x16x4_LD_T { + using BlockShape = Shape<_4, _16>; + + static constexpr bool is_transpose = true; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -609,6 +665,10 @@ struct XE_2D_U32x16x4_LD_T { }; struct XE_2D_U32x16x8_LD_T { + using BlockShape = Shape<_8, _16>; + + static constexpr bool 
is_transpose = true; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -639,6 +699,8 @@ struct XE_2D_U32x16x8_LD_T { }; struct XE_2D_U32x1x16_ST_N { + using BlockShape = Shape<_1, _16>; + template CUTE_HOST_DEVICE static void copy(void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -655,6 +717,8 @@ struct XE_2D_U32x1x16_ST_N { }; struct XE_2D_U32x2x16_ST_N { + using BlockShape = Shape<_2, _16>; + template CUTE_HOST_DEVICE static void copy(void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -671,6 +735,8 @@ struct XE_2D_U32x2x16_ST_N { }; struct XE_2D_U32x4x16_ST_N { + using BlockShape = Shape<_4, _16>; + template CUTE_HOST_DEVICE static void copy(void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -687,6 +753,8 @@ struct XE_2D_U32x4x16_ST_N { }; struct XE_2D_U32x8x16_ST_N { + using BlockShape = Shape<_8, _16>; + template CUTE_HOST_DEVICE static void copy(void *baseoffset, int width, int height, int pitch, intel::coord_t coord, diff --git a/include/cute/arch/xe_copy_8B.hpp b/include/cute/arch/xe_copy_8B.hpp index f340fb5ac..681263e25 100644 --- a/include/cute/arch/xe_copy_8B.hpp +++ b/include/cute/arch/xe_copy_8B.hpp @@ -89,6 +89,8 @@ SYCL_DEVICE_OCL(intel::ulong4 intel_sub_group_block_read_transpose_64b_8r4c( namespace cute { struct XE_2D_U64x8x1_LD_T { + static constexpr bool is_transpose = true; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -105,6 +107,8 @@ struct XE_2D_U64x8x1_LD_T { }; struct XE_2D_U64x8x2_LD_T { + static constexpr bool is_transpose = true; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, @@ -121,6 +125,8 @@ struct XE_2D_U64x8x2_LD_T { }; struct XE_2D_U64x8x4_LD_T { + static constexpr bool is_transpose = true; + template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, diff --git a/include/cute/atom/copy_traits_xe.hpp b/include/cute/atom/copy_traits_xe.hpp index daf78eb04..93d86b51a 100644 --- a/include/cute/atom/copy_traits_xe.hpp +++ b/include/cute/atom/copy_traits_xe.hpp @@ -38,58 +38,129 @@ namespace cute { -namespace detail -{ - template - struct is_transpose : bool_constant {}; +namespace detail { + +static constexpr auto subgroup_size = 16; + +// ========== size_of_inst ========== +template +static constexpr auto size_of_inst = sizeof(dtype); + +template +static constexpr auto size_of_inst> = sizeof(typename T::inst_dtype); + - template<> - struct is_transpose : bool_constant{}; +// ========== value_layout_t ========== +template +struct value_layout_t { + using type = decltype(make_layout(make_shape(get<0>(typename T::BlockShape{}), + get<1>(typename T::BlockShape{}) + / Int{}))); +}; + +template +struct value_layout_t> { + using type = decltype(make_layout(make_shape(get<0>(typename T::ValueShape{}), + get<1>(typename T::ValueShape{}) + / Int{}))); +}; - template<> - struct is_transpose : bool_constant{}; - template<> - struct is_transpose : bool_constant{}; +// ========== is_transpose_load ========== +template +static constexpr bool is_transpose_load = false; - template<> - struct is_transpose : bool_constant{}; +template +static constexpr bool is_transpose_load>> = T::is_transpose; - template<> - struct is_transpose : bool_constant{}; - template<> - struct is_transpose : bool_constant{}; 
+// ========== is_stride_leftmost ========== +template +static constexpr bool is_stride_leftmost = std::is_same_v<_1, decltype(get<0>(T{}))>; - template<> - struct is_transpose : bool_constant{}; +template +static constexpr bool is_stride_leftmost> = std::is_same_v<_1, decltype(get<0>(T{}.stride()))>; - template<> - struct is_transpose : bool_constant{}; - template constexpr bool has_inst_dtype = false; +} // end namespace detail + - template - constexpr bool has_inst_dtype> = true; -} // namespace detail end +template +static constexpr auto append_pvc_tensor(Tensor_t const &t0, uint32_t shape, uint32_t stride) { + return make_tensor(make_inttuple_iter(t0.data()), append(t0.layout(), make_layout(shape, E{} * stride))); +} -template struct XE_2D_LD_Unpack { +template , int64_t>> +struct XE_2D_LD_Unpack { + + using BlockShape = CopyOp::BlockShape; + using Value_Layout = typename detail::value_layout_t::type; + using Traits_LD_t = Copy_Traits; + + static constexpr auto stride_rank = rank(StrideIndicator{}); + static_assert(stride_rank == 2 || stride_rank == 3); + + // Assume LD_T/LD_N will indicate ColumnMajor and RowMajor + static constexpr bool is_column_major = detail::is_transpose_load; + + // We need reverse some parameters becasue intel xe 2d copy intrinsic always assume the matrix is (M, N):(N, 1) convention + static constexpr bool is_need_reversed = detail::is_stride_leftmost; + + // For a logic matrix M-rows and N-columns, user can pass it with the convention (M, N):(N, 1), also can pass it with convention (N, M):(1, N). + // It mean (M, N):(N, 1) convention if 'is_convention_MN' is true, (N, M):(1, N) convention otherwise. + static constexpr bool is_convention_MN = !(is_need_reversed ^ is_column_major); + + // 2d copy parameters const void *base_ptr; uint32_t width; uint32_t height; uint32_t pitch; + uint32_t stride_l = 0; + + + + XE_2D_LD_Unpack(const void *ptr, uint32_t y, + uint32_t x, uint32_t p = 0) : base_ptr(ptr) { + if constexpr (is_need_reversed) { + width = y; + height = x; + } + else { + width = x; + height = y; + } - XE_2D_LD_Unpack(const void *ptr, uint32_t const &w, - uint32_t const &h, uint32_t const &p) - : base_ptr(ptr), width(w), height(h), pitch(p) {} + pitch = (p == 0 ? 
width : p); + } + + template + XE_2D_LD_Unpack(Tensor const &tensor) { + base_ptr = tensor.data().get(); + + if constexpr (is_need_reversed) + { + width = size<0>(tensor.shape()); + height = size<1>(tensor.shape()); + pitch = size<1>(tensor.stride()); + } + else + { + width = size<1>(tensor.shape()); + height = size<0>(tensor.shape()); + pitch = size<0>(tensor.stride()); + } + + if constexpr (stride_rank == 3) { + stride_l = size<2>(tensor.stride()); + } + } - template - XE_2D_LD_Unpack(TraitsArgs const &traits) : base_ptr(traits.base_ptr), - width(traits.width), height(traits.height), pitch(traits.pitch) {} + XE_2D_LD_Unpack(Traits_LD_t const &traits) : base_ptr(traits.base_ptr), + width(traits.width), height(traits.height), pitch(traits.pitch), + stride_l(traits.stride_l) {} XE_2D_LD_Unpack() {} - using Traits_LD_t = Copy_Traits; template CUTE_HOST_DEVICE friend constexpr void @@ -100,19 +171,23 @@ template struct XE_2D_LD_Unpack { using dtype = typename Tensor::value_type; dtype *base_addr = (dtype *)traits.base_ptr; - + + int x, y; auto [m, n, l] = src.data().coord_; - - auto inst_size = sizeof(dtype); - - if constexpr (detail::has_inst_dtype) { - inst_size = sizeof(typename CopyOp::inst_dtype); + if constexpr (is_need_reversed) { + x = m; + y = n; + } else { + x = n; + y = m; } - CopyOp::copy(base_addr + l * traits.width * traits.height, + static constexpr auto inst_size = detail::size_of_inst; + + CopyOp::copy(base_addr + l * traits.stride_l, traits.width * sizeof(dtype), traits.height, traits.pitch * sizeof(dtype), - intel::coord_t{(int)(n * sizeof(dtype) / inst_size), (int)(m)}, + intel::coord_t{(int)(x * sizeof(dtype) / inst_size), y}, &*dst.data()); } @@ -128,12 +203,32 @@ template struct XE_2D_LD_Unpack { auto [m, n, l] = src.data().coord_; - CopyOp::PREFETCH::copy((void *)(base_addr + l * atom.width * atom.height), + CopyOp::PREFETCH::copy((void *)(base_addr + l * atom.stride_l), atom.width * sizeof(dtype), atom.height, atom.pitch * sizeof(dtype), intel::coord_t{(int)n, (int)m}); } + template + CUTE_HOST_DEVICE constexpr auto get_pvc_tensor(Coord const &coord, + GShape const &shape) const { + + auto R = rank(GShape{}); + static_assert(R == 3, "mismatch rank"); + + using basis_t = make_seq; + + using shape_mn = std::conditional_t; + + auto new_shape = cute::tuple_cat(make_shape(_1{}), take(shape)); + auto new_stride = cute::tuple_cat(make_stride(_1{}), transform(basis_t{}, shape_mn{}, + [&](auto i, auto s){ + return E{} * s; + })); + return make_tensor(make_inttuple_iter(coord), + make_layout(new_shape, new_stride)); + } + template {})> CUTE_HOST_DEVICE constexpr auto get_pvc_tensor(GCoord const &coord, GShape const &shape, @@ -150,29 +245,60 @@ template struct XE_2D_LD_Unpack { make_layout(t_shape, t_stride)); } - template - static constexpr auto with(T1 && arg1, T2 && arg2, TraitsArgs&&... args) { - return Traits_LD_t{arg1, arg2, args...}; + template + static constexpr auto with(Tensor const &tensor) { + return Traits_LD_t{tensor}; + } + + template + static constexpr auto with(T0 && arg0, T1 && arg1, Ts&&... 
args) { + return Traits_LD_t{arg0, arg1, args...}; } }; -template struct XE_2D_ST_Unpack { +template , int64_t>> struct XE_2D_ST_Unpack { + using Traits_ST_t = Copy_Traits; + using BlockShape = CopyOp::BlockShape; + using Value_Layout = decltype(make_layout(make_shape(get<0>(BlockShape{}), + get<1>(BlockShape{}) + / Int{}))); + + static constexpr auto stride_rank = rank(StrideIndicator{}); + static_assert(stride_rank == 2 || stride_rank == 3); + + static constexpr bool is_convention_MN = true; + const void *base_ptr; uint32_t width; uint32_t height; uint32_t pitch; + uint32_t stride_l = 0; + + XE_2D_ST_Unpack(const void *ptr, uint32_t y, + uint32_t x, uint32_t p = 0) : base_ptr(ptr) { + width = x; + height = y; + pitch = (p == 0 ? width : p); + } + + template + XE_2D_ST_Unpack(Tensor const &tensor) { + base_ptr = tensor.data().get(); + width = size<1>(tensor.shape()); + height = size<0>(tensor.shape()); + pitch = size<0>(tensor.stride()); - XE_2D_ST_Unpack(const void *ptr, uint32_t const &w, - uint32_t const &h, uint32_t const &p) - : base_ptr(ptr), width(w), height(h), pitch(p) {} + if constexpr (stride_rank == 3) { + stride_l = size<2>(tensor.stride()); + } + } - template - XE_2D_ST_Unpack(TraitsArgs const &traits) : base_ptr(traits.base_ptr), - width(traits.width), height(traits.height), pitch(traits.pitch) {} + XE_2D_ST_Unpack(Traits_ST_t const &traits) : base_ptr(traits.base_ptr), + width(traits.width), height(traits.height), pitch(traits.pitch), + stride_l(traits.stride_l) {} XE_2D_ST_Unpack() {} - using Traits_ST_t = Copy_Traits; template CUTE_HOST_DEVICE friend constexpr void @@ -186,12 +312,30 @@ template struct XE_2D_ST_Unpack { auto [m, n, l] = dst.data().coord_; - CopyOp::copy(base_addr + l * traits.width * traits.height, + CopyOp::copy(base_addr + l * traits.stride_l, (int)(traits.width * sizeof(dtype)), (int)(traits.height), (int)(traits.pitch * sizeof(dtype)), intel::coord_t{(int)n, (int)m}, &*src.data()); } + template + CUTE_HOST_DEVICE constexpr auto get_pvc_tensor(Coord const &coord, + GShape const &shape) const { + + auto R = rank(GShape{}); + static_assert(R == 3, "mismatch rank"); + + using basis_t = make_seq; + + auto new_shape = cute::tuple_cat(make_shape(_1{}), take(shape)); + auto new_stride = cute::tuple_cat(make_stride(_1{}), transform(basis_t{}, BlockShape{}, + [&](auto i, auto s){ + return E{} * s; + })); + return make_tensor(make_inttuple_iter(coord), + make_layout(new_shape, new_stride)); + } + template {})> CUTE_HOST_DEVICE constexpr auto get_pvc_tensor(GCoord const &coord, GShape const &shape, @@ -208,10 +352,16 @@ template struct XE_2D_ST_Unpack { make_layout(t_shape, t_stride)); } - template - static constexpr auto with(T1 && arg1, T2 && arg2, TraitsArgs&&... args) { - return Traits_ST_t{arg1, arg2, args...}; + template + static constexpr auto with(Tensor const &tensor) { + return Traits_ST_t{tensor}; + } + + template + static constexpr auto with(T0 && arg0, T1 && arg1, Ts&&... 
args) { + return Traits_ST_t{arg0, arg1, args...}; } + }; // clang-format off @@ -219,7 +369,6 @@ template struct XE_2D_ST_Unpack { template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_1, _32>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout, @@ -238,7 +387,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_2, _32>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout, @@ -257,7 +405,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_4, _32>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout, @@ -276,7 +423,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_8, _32>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout, @@ -295,7 +441,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_16, _32>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout, @@ -314,7 +459,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_32, _32>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout, @@ -333,7 +477,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_1, _64>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout, @@ -352,7 +495,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_1, _64>; // Logical thread id to thread idx using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit @@ -369,7 +511,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_2, _64>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout, @@ -389,7 +530,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_2, _64>; // Logical thread id to thread idx using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit @@ -406,7 +546,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_4, _64>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout, @@ -425,7 +564,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_4, _64>; // Logical thread id to thread idx using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit @@ -442,7 +580,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_8, _64>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout, @@ -461,7 +598,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_8, _64>; // Logical thread id to thread idx using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit @@ -479,7 +615,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_16, _64>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout, @@ -497,7 +632,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_16, _64>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout>, @@ -512,7 +646,6 @@ struct 
Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_32, _64>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout, @@ -530,7 +663,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_32, _64>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout>, @@ -545,7 +677,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_1, _16>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout, @@ -564,7 +695,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_2, _16>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout, @@ -583,7 +713,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_4, _16>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout, @@ -602,7 +731,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_8, _16>; // Logical thread id to thread idx using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit @@ -622,7 +750,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_8, _16>; // Logical thread id to thread idx using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit @@ -638,7 +765,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_16, _16>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout, @@ -657,7 +783,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_16, _16>; // Logical thread id to thread idx using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit @@ -673,7 +798,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_32, _16>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout, @@ -692,7 +816,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_32, _16>; // Logical thread id to thread idx using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit @@ -708,7 +831,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_1, _32>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout, @@ -727,7 +849,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_1, _32>; // Logical thread id to thread idx using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit @@ -743,7 +864,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_2, _32>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout, @@ -762,7 +882,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_2, _32>; // Logical thread id to thread idx using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit @@ -778,7 +897,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_4, _32>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout, @@ -797,7 +915,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_4, _32>; // 
Logical thread id to thread idx using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit @@ -813,7 +930,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_8, _32>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout, @@ -832,7 +948,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_8, _32>; // Logical thread id to thread idx using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit @@ -849,7 +964,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_16, _32>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout, @@ -868,7 +982,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_16, _32>; // Logical thread id to thread idx using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit @@ -885,7 +998,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_32, _32>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout, @@ -904,7 +1016,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_32, _32>; // Logical thread id to thread idx using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit @@ -921,7 +1032,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_1, _8>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout, @@ -940,7 +1050,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_2, _8>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout, @@ -959,7 +1068,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_4, _8>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout, @@ -978,7 +1086,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_8, _8>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout, @@ -997,7 +1104,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_16, _8>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout, @@ -1016,7 +1122,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_32, _8>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout, @@ -1035,7 +1140,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_1, _16>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout, @@ -1054,7 +1158,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_2, _16>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout, @@ -1073,7 +1176,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_4, _16>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout, @@ -1092,7 +1194,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_8, _16>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout, @@ -1111,14 +1212,13 @@ 
struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_16, _16>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout, Stride< _0, _1>>; // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout,Shape <_32, _2, _16>>, - Stride,Stride< _1,_256,_1024>>>; + using DstLayout = Layout>, + Stride<_0,Stride< _512, _1>>>; // Reference map from (thr,val) to bit using RefLayout = DstLayout; @@ -1130,14 +1230,13 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_32, _16>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout, Stride< _0, _1>>; // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout,Shape <_32, _2, _32>>, - Stride,Stride< _1,_256,_1024>>>; + using DstLayout = Layout>, + Stride< _0, Stride<_512, Int<512 * 16>, _1>>>; // Reference map from (thr,val) to bit using RefLayout = DstLayout; @@ -1149,7 +1248,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_1, _16>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout, @@ -1168,7 +1266,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_2, _16>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout, @@ -1187,7 +1284,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_4, _16>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout, @@ -1206,7 +1302,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_8, _16>; // Logical thread id to thread idx using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit @@ -1225,7 +1320,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_16, _16>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout, @@ -1244,7 +1338,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_32, _16>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout, @@ -1263,7 +1356,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_32, _16>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout, @@ -1282,7 +1374,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_32, _16>; // Logical thread id to thread idx using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit @@ -1298,7 +1389,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_32, _32>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout, @@ -1317,7 +1407,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_32, _64>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout, @@ -1336,7 +1425,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_16, _16>; // Logical thread id to thread idx using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit @@ -1356,7 +1444,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_32, _16>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using 
SrcLayout = Layout, @@ -1375,7 +1462,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_32, _32>; // Logical thread id to thread idx using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit @@ -1395,7 +1481,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_16, _32>; // Logical thread id to thread idx using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit @@ -1415,7 +1500,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_8,_16>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout, @@ -1434,7 +1518,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_16,_16>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout, @@ -1468,7 +1551,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_2,_16>; // Logical thread id to thread idx using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit @@ -1488,7 +1570,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_4,_16>; // Logical thread id to thread idx using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit @@ -1508,7 +1589,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_8,_16>; // Logical thread id to thread idx using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit @@ -1528,7 +1608,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_8,_16>; // Logical thread id to thread idx using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit @@ -1544,7 +1623,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_1,_8>; // Logical thread id to thread idx using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit @@ -1564,7 +1642,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_2,_8>; // Logical thread id to thread idx using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit @@ -1584,7 +1661,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_LD_Unpack { - using Shape_MN = Shape<_4,_8>; // Logical thread id to thread idx using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit @@ -1604,7 +1680,7 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_ST_Unpack { - using Shape_MN = Shape<_2,_32>; + using BlockShape = Shape<_2,_32>; using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout = Layout>, @@ -1623,7 +1699,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_ST_Unpack { - using Shape_MN = Shape<_1, _16>; // Logical thread id to thread idx using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit @@ -1643,7 +1718,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_ST_Unpack { - using Shape_MN = Shape<_2, _16>; // Logical thread id to thread idx using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit @@ -1663,7 +1737,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_ST_Unpack { - using Shape_MN = Shape<_4, _16>; // Logical thread id to thread idx using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit @@ -1679,7 +1752,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_ST_Unpack { - using Shape_MN = Shape<_8, _16>; // Logical thread id to thread idx using ThrID = Layout<_16>; // 
Map from (src-thr,src-val) to bit @@ -1700,7 +1772,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_ST_Unpack { - using Shape_MN = Shape<_8, _32>; // Logical thread id to thread idx using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit @@ -1721,7 +1792,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_ST_Unpack { - using Shape_MN = Shape<_1, _16>; // Logical thread id to thread idx using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit @@ -1741,7 +1811,7 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_ST_Unpack { - using Shape_MN = Shape<_2, _16>; + using BlockShape = Shape<_2, _16>; // Logical thread id to thread idx using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit @@ -1761,7 +1831,7 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_ST_Unpack { - using Shape_MN = Shape<_4, _16>; + using BlockShape = Shape<_4, _16>; // Logical thread id to thread idx using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit @@ -1781,7 +1851,7 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_ST_Unpack { - using Shape_MN = Shape<_8, _16>; + using BlockShape = Shape<_8, _16>; // Logical thread id to thread idx using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit @@ -1801,7 +1871,7 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_ST_Unpack { - using Shape_MN = Shape<_1, _16>; + using BlockShape = Shape<_1, _16>; // Logical thread id to thread idx using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit @@ -1821,7 +1891,7 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_ST_Unpack { - using Shape_MN = Shape<_2, _16>; + using BlockShape = Shape<_2, _16>; // Logical thread id to thread idx using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit @@ -1841,7 +1911,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_ST_Unpack { - using Shape_MN = Shape<_4, _16>; // Logical thread id to thread idx using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit @@ -1861,7 +1930,6 @@ struct Copy_Traits template struct Copy_Traits : XE_2D_ST_Unpack { - using Shape_MN = Shape<_8, _16>; // Logical thread id to thread idx using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit @@ -1878,14 +1946,6 @@ struct Copy_Traits : XE_2D_ST_Unpack(args...) 
{} }; -template -auto make_xe_2d_copy(Tensor gtensor) { - using GTensor = Tensor; - using Traits = Copy_Traits; - // Traits traits {gtensor}; - return Copy_Atom{gtensor}; -} - template struct Copy_Traits> { // Logical thread id to thread idx (one-thread) @@ -2022,4 +2082,136 @@ namespace detail } } // end namespace detail +template +class Xe2DThrCopy : ThrCopy { + +public: + + CUTE_HOST_DEVICE + Xe2DThrCopy(ThrIdx const& thr_idx) : ThrCopy (thr_idx) {} + + template + CUTE_HOST_DEVICE + auto + retile_D(DTensor&& dtensor) { + if constexpr (!TiledCopy::is_convention_MN) { + return retile_D_nkl(dtensor); + } else { + return retile_D_mkl(dtensor); + } + } + + template + CUTE_HOST_DEVICE + auto + retile_MMA(MMA const&, MMATensor&& mma_tensor) { + if constexpr (TiledCopy::is_convention_MN) { + static constexpr auto m = decltype(size<1>(mma_tensor.shape()))::value; + static constexpr auto k = decltype(size<2>(mma_tensor.shape()))::value; + static constexpr auto m_step = size<0>(typename TiledCopy::BlockShape{}) + / size<0>(typename MMA::Shape_MNK{}); + static constexpr auto k_step = size<1>(typename TiledCopy::BlockShape{}) + / size<2>(typename MMA::Shape_MNK{}); + + auto retiled_tensor = make_tensor(mma_tensor.data(), + make_shape(size<0>(mma_tensor.shape()), + Int{}, + Int{}, + Int{}, + Int{})); + return make_tensor(mma_tensor.data(),group<2, 4>(group<1, 3>(select<0, 1, 3, 2, 4>(retiled_tensor.layout())))); + } else { + static constexpr auto k = decltype(size<2>(mma_tensor.shape()))::value; + static constexpr auto k_step = size<0>(typename TiledCopy::BlockShape{}) + / size<2>(typename MMA::Shape_MNK{}); + + auto retiled_tensor = make_tensor(mma_tensor.data(), + make_shape(size<0>(mma_tensor.shape()), + Int{}, + size<1>(mma_tensor.shape()), + Int{})); + return make_tensor(mma_tensor.data(),group<2, 4>(select<0, 2, 1, 3>(retiled_tensor.layout()))); + } + } + +private: + + template + CUTE_HOST_DEVICE static + auto + retile_D_mkl(DTensor&& dtensor) { + auto tmp = ThrCopy::retile_D(dtensor); + return make_tensor(static_cast(tmp).data(), + tmp.shape()); + } + + template + CUTE_HOST_DEVICE static + auto + retile_D_nkl(DTensor&& dtensor) { + auto b_tensor = make_tensor(dtensor.data(), + make_shape(size<0>(dtensor.shape()), + size<2>(dtensor.shape()), + size<1>(dtensor.shape()))); + auto tmp = ThrCopy::retile_D(b_tensor); + return make_tensor(static_cast(tmp).data(), + make_shape(size<0>(tmp.shape()), + size<2>(tmp.shape()), + size<1>(tmp.shape()))); + } +}; + +template coord [Need not be 2D...] 
+ class ShapeTiler_MN> // coord space +struct Xe2DTiledCopy : TiledCopy{ + + template ::value)> + CUTE_HOST_DEVICE + auto + get_slice(ThrIdx const& thr_idx) const + { + return Xe2DThrCopy(thr_idx); + } +}; + +template ::Value_Layout> +CUTE_HOST_DEVICE +auto +make_xe_2d_copy(Copy_Atom const& copy_atom, + ThrLayout const& thr_layout = {}, // (m,n) -> thr_idx + ValLayout const& val_layout = {}) // (m,n) -> val_idx +{ + // Take the raked_products to compute the Layout_MN + // (M,N) -> (thr_idx, val_idx) + auto layout_mn = raked_product(thr_layout, val_layout); + // (thr_idx, val_idx) -> (M,N) + auto layout_tv = right_inverse(layout_mn).with_shape(make_shape(size(thr_layout), size(val_layout))); + // Tiler for extracting relevant elements + // (M,N) -> tensor coord + auto tiler = product_each(shape(layout_mn)); + +#if 0 + print("thr_layout: "); print(thr_layout); print("\n"); + print("val_layout: "); print(val_layout); print("\n"); + print("layout_mn : "); print(layout_mn); print("\n"); + print("layout_tv : "); print(layout_tv); print("\n"); + print("tiler : "); print(tiler); print("\n"); +#endif + + return Xe2DTiledCopy, decltype(layout_tv), decltype(tiler)>{copy_atom}; +} + +// The number of threads involved in a Xe2DTiledCopy +template +CUTE_HOST_DEVICE constexpr +auto +size(Xe2DTiledCopy const&) +{ + return typename Xe2DTiledCopy::TiledNumThr{}; +} + } // end namespace cute diff --git a/include/cutlass/epilogue/collective/xe_epilogue.hpp b/include/cutlass/epilogue/collective/xe_epilogue.hpp index 6d1e6528b..46171de4f 100644 --- a/include/cutlass/epilogue/collective/xe_epilogue.hpp +++ b/include/cutlass/epilogue/collective/xe_epilogue.hpp @@ -120,16 +120,16 @@ class CollectiveEpilogue< using Trait_C = Copy_Traits; using XE_Copy_C = decltype(make_tiled_copy(Copy_Atom{} - .with(static_cast(nullptr), int32_t(0), int32_t(0), int32_t(0)), + .with(static_cast(nullptr), int32_t(0), int32_t(0)), Layout>>{}, - make_layout(make_shape(get<0>(typename Trait_C::Shape_MN{}), - get<1>(typename Trait_C::Shape_MN{}) / Int{})))); + make_layout(make_shape(get<0>(typename Trait_C::BlockShape{}), + get<1>(typename Trait_C::BlockShape{}) / Int{})))); using Trait_D = Copy_Traits; using XE_Copy_D = decltype(make_tiled_copy(Copy_Atom{} - .with(static_cast(nullptr),int32_t(0), int32_t(0), int32_t(0)), + .with(static_cast(nullptr),int32_t(0), int32_t(0)), Layout>>{}, - make_layout(make_shape(get<0>(typename Trait_D::Shape_MN{}), - get<1>(typename Trait_D::Shape_MN{}) / Int{})))); + make_layout(make_shape(get<0>(typename Trait_D::BlockShape{}), + get<1>(typename Trait_D::BlockShape{}) / Int{})))); private: constexpr static bool is_source_supported = not cute::is_void_v; constexpr static bool is_destination_supported = not cute::is_void_v; @@ -188,19 +188,19 @@ class CollectiveEpilogue< XE_Copy_C xe_load_c = {}; if constexpr (is_source_supported) { xe_load_c = make_tiled_copy(Copy_Atom, ElementC>{}.with( - args.ptr_C, N, M, N), + args.ptr_C, M, N), Layout>>{}, - make_layout(make_shape(get<0>(typename Trait_C::Shape_MN{}), - get<1>(typename Trait_C::Shape_MN{}) / Int{}))); + make_layout(make_shape(get<0>(typename Trait_C::BlockShape{}), + get<1>(typename Trait_C::BlockShape{}) / Int{}))); } XE_Copy_D xe_store_d = {}; if constexpr (is_destination_supported) { xe_store_d = make_tiled_copy(Copy_Atom, ElementD>{}.with( - args.ptr_D, N, M, N), + args.ptr_D, M, N), Layout>>{}, - make_layout(make_shape(get<0>(typename Trait_D::Shape_MN{}), - get<1>(typename Trait_D::Shape_MN{}) / Int{}))); + 
make_layout(make_shape(get<0>(typename Trait_D::BlockShape{}), + get<1>(typename Trait_D::BlockShape{}) / Int{}))); } return { @@ -313,10 +313,9 @@ class CollectiveEpilogue< Tensor trC = make_tensor(Shape>{}); Tensor trD = make_tensor(Shape>{}); - Tensor tOuti = params.xe_store_d.get_pvc_tensor( - make_coord(m_offset, n_offset, 0), - make_shape(_, Int{}, Int{}, L), - make_stride(Int(MmaAtomShape{})>{}, Int(MmaAtomShape{})>{}, _1{})); + Tensor rw_coord = params.xe_store_d.get_pvc_tensor( + make_coord(m_offset, n_offset, l_offset), + make_shape(_, Int{}, Int{})); // Because Sm90 uses shared memory, they are not tied to using the same accumulator values // for MMA and Epilogue. But because we are operating directly in the accumulators, we need to be @@ -332,8 +331,6 @@ class CollectiveEpilogue< Tensor tRS_cD = make_counting_tensor(tRS_cD_mn.layout()); // (G2R,G2R_M,G2R_N,EPI_M,EPI_N) - Tensor rw_coord = tOuti(_,_,_,l_coord); - // Get the fusion callbacks // Arguments passed here relate to sub-group tiles, rather than CTA (work-group) tiles constexpr bool RefSrc = true; diff --git a/include/cutlass/epilogue/fusion/xe_visitor.hpp b/include/cutlass/epilogue/fusion/xe_visitor.hpp index 50b77dd02..fe8514ca7 100644 --- a/include/cutlass/epilogue/fusion/xe_visitor.hpp +++ b/include/cutlass/epilogue/fusion/xe_visitor.hpp @@ -68,8 +68,8 @@ struct XeAuxLoad { using XE_Copy_Aux = decltype(make_tiled_copy(Copy_Atom{} .with(static_cast(nullptr), int32_t(0), int32_t(0), int32_t(0)), Layout>{}, - make_layout(make_shape(get<0>(typename Trait_Aux::Shape_MN{}), - get<1>(typename Trait_Aux::Shape_MN{}) / SubgroupSize{})))); + make_layout(make_shape(get<0>(typename Trait_Aux::BlockShape{}), + get<1>(typename Trait_Aux::BlockShape{}) / SubgroupSize{})))); struct Params { XE_Copy_Aux xe_load_aux; Element null_default = Element(0); @@ -88,10 +88,10 @@ struct XeAuxLoad { auto N_AUX = get<0>(args.dAux); // dAux is a stride and N_AUX is a size auto M_AUX = size(M); XE_Copy_Aux xe_load_aux = make_tiled_copy(Copy_Atom{}.with( - args.ptr_aux, N_AUX, M_AUX, N_AUX), + args.ptr_aux, M_AUX, N_AUX), Layout>{}, - make_layout(make_shape(get<0>(typename Trait_Aux::Shape_MN{}), - get<1>(typename Trait_Aux::Shape_MN{}) / SubgroupSize{}))); + make_layout(make_shape(get<0>(typename Trait_Aux::BlockShape{}), + get<1>(typename Trait_Aux::BlockShape{}) / SubgroupSize{}))); bool use_default = false; if constexpr (EnableNullptr) { @@ -205,11 +205,9 @@ struct XeAuxLoad { auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl; auto m_offset = m_coord * SG_M; auto n_offset = n_coord * SG_N; - Tensor tOuti = args.tiled_copy.get_pvc_tensor( - make_coord(m_offset, n_offset, 0), - make_shape(_, Int{}, Int{}, L), - make_stride(Int(MmaAtomShape{})>{}, Int(MmaAtomShape{})>{}, _1{})); - Tensor rw_coord = tOuti(_,_,_,l_coord); + Tensor rw_coord = args.tiled_copy.get_pvc_tensor( + make_coord(m_offset, n_offset, l_coord), + make_shape(_, Int{}, Int{})); return ConsumerStoreCallbacks( rw_coord, xe_copy_aux, cute::move(trAux), params_ptr diff --git a/include/cutlass/gemm/collective/xe_mma.hpp b/include/cutlass/gemm/collective/xe_mma.hpp index 77cddf7df..7ab946f14 100644 --- a/include/cutlass/gemm/collective/xe_mma.hpp +++ b/include/cutlass/gemm/collective/xe_mma.hpp @@ -126,23 +126,19 @@ struct CollectiveMma< using PrefetchBTileSize = decltype(ceil_div(Shape, Int>{},PrefetchBThrShape{})); static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{}); - using traits_load_A = Copy_Traits; + + using traits_load_A = Copy_Traits; using 
atom_load_A = Copy_Atom; - using XE_Copy_A = decltype(make_tiled_copy(atom_load_A{} - .with(static_cast(nullptr), int32_t(0), int32_t(0), int32_t(0)), - Layout>>{}, - make_layout(make_shape(get<0>(typename traits_load_A::Shape_MN{}), - get<1>(typename traits_load_A::Shape_MN{}) / Int{})))); - using traits_load_B = Copy_Traits; + + using traits_load_B = Copy_Traits; using atom_load_B = Copy_Atom; - using XE_Copy_B = decltype(make_tiled_copy(atom_load_B{} - .with(static_cast(nullptr), int32_t(0), int32_t(0), int32_t(0)), - Layout>>{}, - make_layout(make_shape(get<0>(typename traits_load_B::Shape_MN{}), - get<1>(typename traits_load_B::Shape_MN{}) / Int{})))); using XE_Prefetch_A = decltype(cute::detail::prefetch_selector()); using XE_Prefetch_B = decltype(cute::detail::prefetch_selector()); + + using TensorMKL = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(0,0,0), StrideA{})); //(m, k) + using TensorNKL = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(0,0,0), StrideB{})); //(n, k) + // Host side kernel arguments struct Arguments { ElementA const* ptr_A; @@ -152,10 +148,8 @@ struct CollectiveMma< }; struct Params { - XE_Copy_A gmem_tiled_copy_a; - XE_Copy_B gmem_tiled_copy_b; - XE_Prefetch_A gmem_prefetch_a; - XE_Prefetch_B gmem_prefetch_b; + TensorMKL mA; + TensorNKL mB; }; // @@ -169,20 +163,15 @@ struct CollectiveMma< to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { (void) workspace; - auto problem_shape_MNKL = append<4>(problem_shape, 1); - auto [M,N,K,L] = problem_shape_MNKL; - - XE_Copy_A copyA = make_tiled_copy(Copy_Atom, ElementA>{}.with(args.ptr_A, K, M, K), - Layout>>{}, - make_layout(make_shape(get<0>(typename traits_load_A::Shape_MN{}), - get<1>(typename traits_load_A::Shape_MN{}) / Int{}))); - XE_Copy_B copyB = make_tiled_copy(Copy_Atom, ElementB>{}.with(args.ptr_B, N, K, N), - Layout>>{}, - make_layout(make_shape(get<0>(typename traits_load_B::Shape_MN{}), - get<1>(typename traits_load_B::Shape_MN{}) / Int{}))); - XE_Prefetch_A prefetchA = cute::detail::prefetch_selector((void *)args.ptr_A, K, M, K); - XE_Prefetch_B prefetchB = cute::detail::prefetch_selector((void *)args.ptr_B, N, K, N); - return Params{copyA, copyB, prefetchA, prefetchB}; + auto [M,N,K,L] = problem_shape; + + auto mA_mkl = make_tensor(make_gmem_ptr(static_cast(args.ptr_A)), + make_layout(make_shape(M, K, L), args.dA)); + + auto mB_nkl = make_tensor(make_gmem_ptr(static_cast(args.ptr_B)), + make_layout(make_shape(N, K, L), args.dB)); + + return Params{mA_mkl, mB_nkl}; } /// Perform a subgroup-scoped matrix multiply-accumulate @@ -206,7 +195,7 @@ struct CollectiveMma< KTileIterator k_tile_iter, int k_tile_count, ResidueMNK residue_mnk, BlkCoord const &blk_coord, - int const &K, + int const &K_start, int thread_idx, char *smem_buf, Params const& mainloop) @@ -218,35 +207,43 @@ struct CollectiveMma< (void)thread_idx; (void)smem_buf; - // Instantiate the MMA object - TiledMma tiled_mma; - auto thread_mma = tiled_mma.get_slice(thread_idx); - Tensor tCrA_partition = thread_mma.partition_fragment_A(gA(_, _, 0)); - Tensor tCrA = make_tensor(static_cast(tCrA_partition).data(), - tCrA_partition.shape()); - Tensor tCrB_partition = thread_mma.partition_fragment_B(gB(_, _, 0)); - Tensor tCrB = make_tensor(static_cast(tCrB_partition).data(), - make_shape(size<0>(tCrB_partition.shape()), - size<2>(tCrB_partition.shape()), - size<1>(tCrB_partition.shape()))); + auto tiled_copy_a = make_xe_2d_copy(atom_load_A{}.with(mainloop.mA), + 
Layout>>{}); + auto tiled_copy_b = make_xe_2d_copy(atom_load_B{}.with(mainloop.mB), + Layout>>{}); + // Partition the copying of A and B tiles across the threads - auto gmem_thr_copy_A = mainloop.gmem_tiled_copy_a.get_slice(thread_idx); - auto gmem_thr_copy_B = mainloop.gmem_tiled_copy_b.get_slice(thread_idx); + auto thr_copy_A = tiled_copy_a.get_slice(thread_idx); + auto thr_copy_B = tiled_copy_b.get_slice(thread_idx); + + // Instantiate the MMA object and get thread slice + TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_slice(thread_idx); + + // Partition fragment + Tensor fragment_A = thr_mma.partition_fragment_A(gA(_, _, 0)); + Tensor fragment_B = thr_mma.partition_fragment_B(gB(_, _, 0)); + + // Retile for copy + Tensor copy_tCrA = thr_copy_A.retile_D(fragment_A); + Tensor copy_tCrB = thr_copy_B.retile_D(fragment_B); + + // Retile for cute::gemm + Tensor mma_tCrA = thr_copy_A.retile_MMA(thr_mma, fragment_A); + Tensor mma_tCrB = thr_copy_B.retile_MMA(thr_mma, fragment_B); - auto tCrA_copy_view = gmem_thr_copy_A.retile_D(tCrA); - auto tCrB_copy_view = gmem_thr_copy_B.retile_D(tCrB); #if CUTLASS_ENABLE_DEBUG_PRINTS if (thread(LOG_THREAD, LOG_GROUP)) { print("======================= A: \n"); print(" gA : "); print(gA); print("\n"); - print("tCrA_copy_view : "); print(tCrA_copy_view); print("\n"); - print(" tCrA : "); print(tCrA); print("\n"); + print("copy_tCrA : "); print(copy_tCrA); print("\n"); + print(" mma_tCrA : "); print(mma_tCrA); print("\n"); print("===================== B :\n"); print(" gB : "); print(gB); print("\n"); - print("tCrB_copy_view : "); print(tCrB_copy_view); print("\n"); - print(" tCrB : "); print(tCrB); print("\n"); + print("copy_tCrB : "); print(copy_tCrB); print("\n"); + print(" mma_tCrB : "); print(mma_tCrB); print("\n"); print("===================== Config: \n"); print(" threads per workgroup : "); print(MaxThreadsPerBlock); print("\n"); @@ -257,7 +254,7 @@ struct CollectiveMma< print(" PrefetchATileSize : ");print(PrefetchATileSize{});print("\n"); print(" PrefetchBTileSize : ");print(PrefetchBTileSize{});print("\n"); } - #endif + #endif // // Mainloop @@ -271,55 +268,56 @@ struct CollectiveMma< const int n_coord = n_idx * BLK_N + (get_sub_group_id() % ATOM_N) * SG_N; #endif const int l_coord = l_idx; - Tensor iter_a = mainloop.gmem_tiled_copy_a.get_pvc_tensor( - make_coord(m_coord, 0, l_coord), append<4>(tCrA_copy_view.shape(), k_tile_count), - append<3>(typename XE_Copy_A::Shape_MN{}, BLK_K), seq<0,1,1>{}); - Tensor iter_b = mainloop.gmem_tiled_copy_b.get_pvc_tensor( - make_coord(0, n_coord, l_coord), append<4>(tCrB_copy_view.shape(), k_tile_count), - append<3>(typename XE_Copy_B::Shape_MN{}, BLK_K), seq<0,1,0>{}); - - const int k_start_idx = crd2idx((*k_tile_iter), make_shape(K)); + + Tensor block2d_copy_iter_a = tiled_copy_a.get_pvc_tensor(make_coord(m_coord, 0, l_coord), copy_tCrA.shape()); + auto copy_iter_a = append_pvc_tensor<1>(block2d_copy_iter_a, k_tile_count, BLK_K); + + Tensor block2d_copy_iter_b = tiled_copy_b.get_pvc_tensor(make_coord(n_coord, 0, l_coord), copy_tCrB.shape()); + auto copy_iter_b = append_pvc_tensor<1>(block2d_copy_iter_b, k_tile_count, BLK_K); + + const int k_start_idx = crd2idx((*k_tile_iter), make_shape(K_start)); int prefetch_k = 0; - Tensor prefetch_iter_a = mainloop.gmem_prefetch_a.get_pvc_tensor( - make_coord(m_coord + (get_sub_group_id() % ATOM_N) / get<1>(PrefetchAThrShape{}) * get<0>(PrefetchATileSize{}), - (k_start_idx + (get_sub_group_id() % ATOM_N) % get<1>(PrefetchAThrShape{})) * PrefetchStrideA, l_coord), - 
append<4>(make_shape(_1{}, _1{}, _1{}), k_tile_count), - append<3>(make_shape(SG_M, SG_K), BLK_K), seq<0, 1, 1>{}); - Tensor prefetch_iter_b = mainloop.gmem_prefetch_b.get_pvc_tensor( - make_coord(((get_sub_group_id() / ATOM_N) / get<1>(PrefetchBThrShape{}) + k_start_idx) * PrefetchStrideB, - n_coord + (get_sub_group_id() / ATOM_N) % get<1>(PrefetchBThrShape{}) * get<1>(PrefetchBTileSize{}), l_coord), - append<4>(make_shape(_1{}, _1{}, _1{}), k_tile_count), - append<3>(make_shape(SG_K, SG_N), BLK_K), seq<0,1,0>{}); + Tensor block2d_prefetch_iter_a = XE_Prefetch_A{}.get_pvc_tensor( + make_coord(m_coord + (get_sub_group_id() % ATOM_N) / get<1>(PrefetchAThrShape{}) * get<0>(PrefetchATileSize{}), + (k_start_idx + (get_sub_group_id() % ATOM_N) % get<1>(PrefetchAThrShape{})) * PrefetchStrideA, + l_coord), + make_shape(_1{}, _1{}, _1{})); + auto prefetch_iter_a = append_pvc_tensor<1>(block2d_prefetch_iter_a, k_tile_count, BLK_K); + + Tensor block2d_prefetch_iter_b = XE_Prefetch_B{}.get_pvc_tensor( + make_coord((get_sub_group_id() / ATOM_N / get<1>(PrefetchBThrShape{}) + k_start_idx) * PrefetchStrideB, + n_coord + (get_sub_group_id() / ATOM_N) % get<1>(PrefetchBThrShape{}) * get<1>(PrefetchBTileSize{}), + l_coord), + make_shape(_1{}, _1{}, _1{})); + auto prefetch_iter_b = append_pvc_tensor<0>(block2d_prefetch_iter_b, k_tile_count, BLK_K); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < DispatchPolicy::Stages; i++, prefetch_k++) { if constexpr(cute::detail::has_prefetch) { - prefetch(mainloop.gmem_tiled_copy_a, prefetch_iter_a(_,_,_,prefetch_k)); + prefetch(tiled_copy_a, prefetch_iter_a(_,_,_,prefetch_k)); } if constexpr(cute::detail::has_prefetch) { - prefetch(mainloop.gmem_tiled_copy_b, prefetch_iter_b(_,_,_,prefetch_k)); + prefetch(tiled_copy_b, prefetch_iter_b(_,_,_,prefetch_k)); } } CUTLASS_PRAGMA_UNROLL for (int k_tile = 0, k = k_start_idx; k_tile < k_tile_count; ++k_tile, ++k, ++prefetch_k) { // Copy gmem to rmem for the first k_tile - copy(mainloop.gmem_tiled_copy_a, iter_a(_,_,_,k), tCrA_copy_view); - copy(mainloop.gmem_tiled_copy_b, iter_b(_,_,_,k), tCrB_copy_view); + copy(tiled_copy_a, copy_iter_a(_,_,_,k), copy_tCrA); + copy(tiled_copy_b, copy_iter_b(_,_,_,k), copy_tCrB); if(prefetch_k < k_tile_count) { if constexpr(cute::detail::has_prefetch) { - prefetch(mainloop.gmem_tiled_copy_a, prefetch_iter_a(_,_,_,prefetch_k)); + prefetch(tiled_copy_a, prefetch_iter_a(_,_,_,prefetch_k)); } if constexpr(cute::detail::has_prefetch) { - prefetch(mainloop.gmem_tiled_copy_b, prefetch_iter_b(_,_,_,prefetch_k)); + prefetch(tiled_copy_b, prefetch_iter_b(_,_,_,prefetch_k)); } } - for (int i = 0; i < SG_K / SubgroupSize; i++) { - cute::gemm(tiled_mma, accum, tCrA(_, _, i), tCrB(_, i, _), src_accum); - } + cute::gemm(tiled_mma, mma_tCrA, mma_tCrB, accum); } } }; diff --git a/include/cutlass/gemm/kernel/xe_gemm.hpp b/include/cutlass/gemm/kernel/xe_gemm.hpp index 64e5b76c6..ca4810d83 100644 --- a/include/cutlass/gemm/kernel/xe_gemm.hpp +++ b/include/cutlass/gemm/kernel/xe_gemm.hpp @@ -110,6 +110,12 @@ class GemmUniversal< static constexpr int PrefetchStrideA = static_cast(get<1>(PrefetchATileSize{})); static constexpr int PrefetchStrideB = static_cast(get<0>(PrefetchBTileSize{})); + using TensorMKL = typename CollectiveMainloop::TensorMKL; + using TensorNKL = typename CollectiveMainloop::TensorNKL; + + using TensorMK = decltype(TensorMKL{}(_, _, 0)); + using TensorNK = decltype(TensorNKL{}(_, _, 0)); + // Kernel level shared memory storage struct SharedStorage { using EpilogueTensorStorage = typename 
CollectiveEpilogue::TensorStorage; @@ -130,6 +136,8 @@ class GemmUniversal< struct Params { GemmUniversalMode mode; ProblemShape problem_shape; + TensorMK mA_mk; + TensorNK mB_nk; MainloopParams mainloop; EpilogueParams epilogue; }; @@ -143,10 +151,19 @@ class GemmUniversal< Params to_underlying_arguments(Arguments const& args, void* workspace) { (void) workspace; + + auto mainloop_args = CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace); + + auto l_coord = BlockIdxZ(); + Tensor mA_mk = mainloop_args.mA(_,_,l_coord); + Tensor mB_nk = mainloop_args.mB(_,_,l_coord); + return { args.mode, args.problem_shape, - CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace), + mA_mk, + mB_nk, + mainloop_args, CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace) }; } @@ -234,18 +251,14 @@ class GemmUniversal< auto n_coord = BlockIdxX(); #endif auto l_coord = BlockIdxZ(); - auto blk_coord_mnkl = make_coord(m_coord, n_coord, _, l_coord); + + auto blk_coord_mnkl = make_coord(m_coord, n_coord, _, l_coord); int sub_group_id = thread_idx / SubgroupSize; constexpr auto workgroup_shape = WorkgroupTileShape{}; // (SUB_M,SUB_N,SUB_K) constexpr auto subgroup_shape = SubgroupTileShape{}; - Tensor mA_mkl = make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(M,K,L), StrideA{}); //(m,k,l) - Tensor mB_nkl = make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(N,K,L), StrideB{}); //(n,k,l) - Tensor mA_mk = mA_mkl(_,_,l_coord); // (m,k) - Tensor mB_nk = mB_nkl(_,_,l_coord); // (n,k) - - auto gA = local_tile(mA_mk, blk_shape, take<0, 3>(blk_coord_mnkl), Step<_1, X, _1>{}); - auto gB = local_tile(mB_nk, blk_shape, take<0, 3>(blk_coord_mnkl), Step< X, _1, _1>{}); + auto gA = local_tile(params.mA_mk, blk_shape, take<0, 3>(blk_coord_mnkl), Step<_1, X, _1>{}); + auto gB = local_tile(params.mB_nk, blk_shape, take<0, 3>(blk_coord_mnkl), Step< X, _1, _1>{}); // Compute tile residues for predication auto m_max_coord = M - get<0>(subgroup_shape) * m_coord; // M - SUB_M * m_coord diff --git a/test/unit/cute/intel_xe/CMakeLists.txt b/test/unit/cute/intel_xe/CMakeLists.txt index 807c84743..a13818668 100755 --- a/test/unit/cute/intel_xe/CMakeLists.txt +++ b/test/unit/cute/intel_xe/CMakeLists.txt @@ -37,9 +37,8 @@ cutlass_test_unit_add_executable( gemm_partition_src_dst.cpp gemm_partition_fragment_abc.cpp gemm_tiled_copy_abc.cpp - gemm_row_col.cpp - gemm_col_row.cpp - gemm_col_col.cpp + gemm_layout.cpp + gemm_data_type.cpp ) else() cutlass_test_unit_add_executable( diff --git a/test/unit/cute/intel_xe/copy_1d.cpp b/test/unit/cute/intel_xe/copy_1d.cpp index 4ef31c991..3db8b090f 100644 --- a/test/unit/cute/intel_xe/copy_1d.cpp +++ b/test/unit/cute/intel_xe/copy_1d.cpp @@ -29,6 +29,8 @@ * **************************************************************************************************/ +#include "cutlass/detail/layout.hpp" + #include #include #include @@ -82,26 +84,24 @@ void copy_kernel_vectorized(TensorS tile_S, TensorD tile_D) { make_shape(_1{}, Int{}), Stride, _1>{}); auto ThreadLayout = make_layout(make_shape(_1{}, _16{})); - auto tiled_copy_load = make_tiled_copy(Atom_load{}, // access size + auto tiled_copy_load = make_xe_2d_copy(Atom_load{}, // access size ThreadLayout, // thread layout VecLayout); // vector layout (e.g. 4x1) auto tiled_copy_store = - make_tiled_copy(Atom_store{}, // access size + make_xe_2d_copy(Atom_store{}, // access size ThreadLayout, // thread layout VecLayout); // vector layout (e.g. 
4x1) - auto tiled_ldsm = make_tiled_copy(Atom_ldsm{}, // access size + auto tiled_ldsm = make_xe_2d_copy(Atom_ldsm{}, // access size ThreadLayout, // thread layout VecLayout); // vector layout (e.g. 4x1) - auto tiled_stsm = make_tiled_copy(Atom_stsm{}, // access size + auto tiled_stsm = make_xe_2d_copy(Atom_stsm{}, // access size ThreadLayout, // thread layout VecLayout); // vector layout (e.g. 4x1) // Construct a Tensor corresponding to each thread's slice. - auto thr_copy_load = - tiled_copy_load.get_thread_slice(ThreadIdxX()); - auto thr_copy_store = - tiled_copy_store.get_thread_slice(ThreadIdxX()); + auto thr_copy_load = tiled_copy_load.get_thread_slice(ThreadIdxX()); + auto thr_copy_store = tiled_copy_store.get_thread_slice(ThreadIdxX()); auto thr_copy_ldsm = tiled_ldsm.get_thread_slice(ThreadIdxX()); auto thr_copy_stsm = tiled_stsm.get_thread_slice(ThreadIdxX()); @@ -172,7 +172,6 @@ TEST(PVC_1d_copy, copy_double) { cutlass::device_vector device_src = host_src; cutlass::device_vector device_output(M * N); - Tensor S = make_tensor(make_gmem_ptr(device_src.data()), make_layout(Shape, Int>{}, Stride, _1>{})); diff --git a/test/unit/cute/intel_xe/copy_block.cpp b/test/unit/cute/intel_xe/copy_block.cpp index 3eed445c0..d3088e02c 100644 --- a/test/unit/cute/intel_xe/copy_block.cpp +++ b/test/unit/cute/intel_xe/copy_block.cpp @@ -29,6 +29,8 @@ * **************************************************************************************************/ +#include "cutlass/detail/layout.hpp" + #include #include #include @@ -41,7 +43,8 @@ using namespace syclcompat::experimental; #define SUBGROUP_SIZE (16) -template +template void copy_kernel_vectorized(TensorS S, TensorD D, TiledLoad load, TiledStore store) { const int m_coord = 0; @@ -53,9 +56,9 @@ void copy_kernel_vectorized(TensorS S, TensorD D, TiledLoad load, auto thr_tile_load_D = thr_copy_load.partition_D(S); auto fragment = make_fragment_like(thr_tile_load_D); auto ld_tensor = - load.get_pvc_tensor(make_coord(m_coord, n_coord, l_coord), - fragment.shape(), typename TiledLoad::Shape_MN{}); - if constexpr (cute::detail::has_prefetch) prefetch(load, ld_tensor); + load.get_pvc_tensor(make_coord(m_coord, n_coord, l_coord), fragment.shape()); + if constexpr (cute::detail::has_prefetch) + prefetch(load, ld_tensor); copy(load, ld_tensor, fragment); // ========== store ========== @@ -63,9 +66,8 @@ void copy_kernel_vectorized(TensorS S, TensorD D, TiledLoad load, Tensor frag_view = make_tensor(static_cast(fragment).data(), thr_copy_store.partition_S(D).shape()); - auto st_tensor = store.get_pvc_tensor(make_coord(m_coord, n_coord, l_coord), - frag_view.shape(), - typename TiledStore::Shape_MN{}); + auto st_tensor = + store.get_pvc_tensor(make_coord(m_coord, n_coord, l_coord), frag_view.shape()); copy(store, frag_view, st_tensor); #if 0 @@ -88,10 +90,11 @@ void copy_kernel_vectorized(TensorS S, TensorD D, TiledLoad load, } #endif } -template +template struct copy_op; -template +template struct copy_op { void operator()() { // @@ -99,7 +102,7 @@ struct copy_op { // cutlass::host_vector host_src(M * N); cutlass::host_vector host_output(M * N); - + for (size_t i = 0; i < host_src.size(); ++i) { host_src[i] = static_cast(i); } @@ -110,29 +113,28 @@ struct copy_op { Tensor S = make_tensor(make_gmem_ptr(device_src.data()), make_layout(Shape, Int>{}, Stride, _1>{})); - Tensor D = + Tensor D = make_tensor(make_gmem_ptr(device_output.data()), - make_layout(Shape, Int>{}, Stride, _1>{})); - - auto tiled_load = - make_tiled_copy( - Copy_Atom, 
dtype>{}.with(device_src.data(), N, M, - N), - Layout>, Stride<_0, _1>>{}, - Layout(typename Copy_Traits::Shape_MN{})), _1>, Stride<_1, _0>>{}); - auto tiled_store = make_tiled_copy( - Copy_Atom, dtype>{}.with(device_output.data(), N, - M, N), - Layout>, Stride<_0, _1>>{}, - Layout(typename Copy_Traits::Shape_MN{})), _1>, Stride<_1, _0>>{}); + make_layout(Shape, Int>{}, Stride, _1>{})); + + auto tiled_load = make_xe_2d_copy( + Copy_Atom, dtype>{}.with(device_src.data(), M, N), + Layout>>{}); + + auto tiled_store = make_xe_2d_copy( + Copy_Atom, dtype>{}.with(device_output.data(), M, N), + Layout>>{}); + auto blockDim = syclcompat::dim3(size(tiled_load)); // // Launch the kernel // - launch>( - launch_policy{syclcompat::dim3(1), blockDim, - kernel_properties{sycl_exp::sub_group_size}}, + launch_policy{ + syclcompat::dim3(1), blockDim, + kernel_properties{sycl_exp::sub_group_size}}, S, D, tiled_load, tiled_store); syclcompat::wait_and_throw(); @@ -143,7 +145,7 @@ struct copy_op { } }; -template +template struct copy_op { void operator()() { // @@ -152,7 +154,7 @@ struct copy_op { using dtype = char; cutlass::host_vector host_src(M * N); cutlass::host_vector host_output(M * N); - + for (size_t i = 0; i < host_src.size(); ++i) { host_src[i] = static_cast(i); } @@ -167,24 +169,22 @@ struct copy_op { make_tensor(make_gmem_ptr(device_output.data()), make_layout(Shape, Int>{}, Stride, _1>{})); - auto tiled_load = make_tiled_copy( - Copy_Atom, dtype>{}.with(device_src.data(), N, M, - N), - Layout, Stride<_0, _1>>{}, - make_layout(shape<1>(typename Copy_Atom, dtype>::ValLayoutDst{}))); - auto tiled_store = make_tiled_copy( - Copy_Atom, dtype>{}.with(device_output.data(), N, M, - N), - Layout, Stride<_0, _1>>{}, - Layout, Stride<_2, _1>>{}); + auto tiled_load = make_xe_2d_copy( + Copy_Atom, dtype>{}.with(S), Layout>{}); + + auto tiled_store = make_xe_2d_copy( + Copy_Atom, dtype>{}.with(D), Layout>{}); + auto blockDim = syclcompat::dim3(size(tiled_load)); // // Launch the kernel // - launch>( - launch_policy{syclcompat::dim3(1), blockDim, - kernel_properties{sycl_exp::sub_group_size}}, + launch_policy{ + syclcompat::dim3(1), blockDim, + kernel_properties{sycl_exp::sub_group_size}}, S, D, tiled_load, tiled_store); syclcompat::wait_and_throw(); @@ -195,8 +195,8 @@ struct copy_op { } }; -template -struct copy_op{ +template +struct copy_op { void operator()() { // // Allocate and initialize @@ -204,7 +204,7 @@ struct copy_op{ using dtype = uint16_t; cutlass::host_vector host_src(M * N); cutlass::host_vector host_output(M * N); - + for (size_t i = 0; i < host_src.size(); ++i) { host_src[i] = static_cast(i); } @@ -215,28 +215,26 @@ struct copy_op{ Tensor S = make_tensor(make_gmem_ptr(device_src.data()), make_layout(Shape, Int>{}, Stride, _1>{})); - Tensor D = - make_tensor(make_gmem_ptr(device_output.data()), - make_layout(Shape, Int>{}, Stride, _1>{})); - - auto tiled_load = make_tiled_copy( - Copy_Atom, dtype>{}.with(device_src.data(), N, M, - N), - Layout>, Stride<_0, _1>>{}, - Layout, _2>, Stride<_1, _2>>{}); - auto tiled_store = make_tiled_copy( - Copy_Atom, uint16_t>{}.with(device_output.data(), N / 2, - M * 2, N / 2), - Layout, Stride<_0, _1>>{}, - Layout, Stride<_1, _0>>{}); + Tensor D = make_tensor( + make_gmem_ptr(device_output.data()), + make_layout(Shape, Int>{}, Stride, _1>{})); + + auto tiled_load = make_xe_2d_copy( + Copy_Atom, dtype>{}.with(device_src.data(), M, N), + Layout>>{}); + auto tiled_store = make_xe_2d_copy( + Copy_Atom, uint16_t>{}.with( + device_output.data(), M * 2, N / 2), 
Layout>{}); auto blockDim = syclcompat::dim3(size(tiled_load)); // // Launch the kernel // - launch>( - launch_policy{syclcompat::dim3(1), blockDim, - kernel_properties{sycl_exp::sub_group_size}}, + launch_policy{ + syclcompat::dim3(1), blockDim, + kernel_properties{sycl_exp::sub_group_size}}, S, D, tiled_load, tiled_store); syclcompat::wait_and_throw(); @@ -247,10 +245,10 @@ struct copy_op{ host_src[(i % M) * N + j + (i / M) * N / 2]); } } - } + } }; -template +template struct copy_op { void operator()() { // @@ -259,7 +257,7 @@ struct copy_op { using dtype = uint32_t; cutlass::host_vector host_src(M * N); cutlass::host_vector host_output(M * N); - + for (size_t i = 0; i < host_src.size(); ++i) { host_src[i] = static_cast(i); } @@ -274,25 +272,22 @@ struct copy_op { make_tensor(make_gmem_ptr(device_output.data()), make_layout(Shape, Int>{}, Stride, _1>{})); - auto tiled_load = - make_tiled_copy( - Copy_Atom, dtype>{}.with(device_src.data(), N, M, - N), - Layout, _1>, Stride<_1, _0>>{}, - Layout(typename Copy_Traits::Shape_MN{}))>, Stride<_0, _1>>{}); - auto tiled_store = make_tiled_copy( - Copy_Atom, dtype>{}.with(device_output.data(), M, N, - M), - Layout>, Stride<_0, _1>>{}, - Layout(typename Copy_Traits::Shape_MN{})), _1>, Stride<_1, _0>>{}); + auto tiled_load = make_xe_2d_copy( + Copy_Atom, dtype>{}.with(device_src.data(), M, N), + Layout, _1>>{}); + auto tiled_store = make_xe_2d_copy( + Copy_Atom, dtype>{}.with(device_output.data(), N, M), + Layout>>{}); auto blockDim = syclcompat::dim3(size(tiled_load)); // // Launch the kernel // - launch>( - launch_policy{syclcompat::dim3(1), blockDim, - kernel_properties{sycl_exp::sub_group_size}}, + launch_policy{ + syclcompat::dim3(1), blockDim, + kernel_properties{sycl_exp::sub_group_size}}, S, D, tiled_load, tiled_store); syclcompat::wait_and_throw(); diff --git a/test/unit/cute/intel_xe/copy_scatter.cpp b/test/unit/cute/intel_xe/copy_scatter.cpp index 4373e884c..dce587f15 100644 --- a/test/unit/cute/intel_xe/copy_scatter.cpp +++ b/test/unit/cute/intel_xe/copy_scatter.cpp @@ -29,6 +29,8 @@ * **************************************************************************************************/ +#include "cutlass/detail/layout.hpp" + #include #include #include @@ -42,7 +44,8 @@ using namespace syclcompat::experimental; #define SUBGROUP_SIZE (16) template -void copy_kernel_global(TensorS S, TensorD D, TiledLoad load, TiledStore store) { +void copy_kernel_global(TensorS S, TensorD D, TiledLoad load, + TiledStore store) { auto thr_copy_load = load.get_thread_slice(ThreadIdxX()); Tensor thr_tile_load_S = thr_copy_load.partition_S(S); @@ -117,9 +120,10 @@ TEST(PVC_2d_copy, load_store_global) { make_tensor(make_gmem_ptr(device_output.data()), make_layout(Shape, Int>{}, Stride, _1>{})); - auto tiled_copy = make_tiled_copy(Copy_Atom, Element>{}, - Layout, Stride<_16, _1>>{}, - Layout, Stride<_1, _8>>{}); + auto tiled_copy = + make_tiled_copy(Copy_Atom, Element>{}, + Layout, Stride<_16, _1>>{}, + Layout, Stride<_1, _8>>{}); static constexpr auto subgroup_size = 16; auto blockDim = syclcompat::dim3(size(tiled_copy)); // @@ -165,9 +169,10 @@ TEST(PVC_2d_copy, load_store_global_V) { make_tensor(make_gmem_ptr(device_output.data()), make_layout(Shape, Int>{}, Stride, _1>{})); - auto tiled_copy = make_tiled_copy(Copy_Atom, Element>{}, - Layout, Stride<_16, _1>>{}, - Layout, Stride<_1, _8>>{}); + auto tiled_copy = + make_tiled_copy(Copy_Atom, Element>{}, + Layout, Stride<_16, _1>>{}, + Layout, Stride<_1, _8>>{}); static constexpr auto subgroup_size = 16; auto 
blockDim = syclcompat::dim3(size(tiled_copy)); // @@ -240,9 +245,10 @@ TEST(PVC_2d_copy, load_store_local) { make_tensor(make_gmem_ptr(device_output.data()), make_layout(Shape, Int>{}, Stride, _1>{})); - auto tiled_copy = make_tiled_copy(Copy_Atom, Element>{}, - Layout, Stride<_16, _1>>{}, - Layout, Stride<_1, _8>>{}); + auto tiled_copy = + make_tiled_copy(Copy_Atom, Element>{}, + Layout, Stride<_16, _1>>{}, + Layout, Stride<_1, _8>>{}); static constexpr auto subgroup_size = 16; auto blockDim = syclcompat::dim3(size(tiled_copy)); // @@ -263,7 +269,8 @@ TEST(PVC_2d_copy, load_store_local) { } template -void copy_kernel_atomic(TensorS S, TensorD D, TiledLoad load, TiledStore store) { +void copy_kernel_atomic(TensorS S, TensorD D, TiledLoad load, + TiledStore store) { auto thr_copy_load = load.get_thread_slice(ThreadIdxX()); Tensor thr_tile_load_S = thr_copy_load.partition_S(S); @@ -331,9 +338,10 @@ TEST(PVC_2d_copy, load_store_stomic_float) { make_tensor(make_gmem_ptr(device_output.data()), make_layout(Shape, Int>{}, Stride, _1>{})); - auto tiled_load = make_tiled_copy(Copy_Atom, Element>{}, - Layout, Stride<_16, _1>>{}, - Layout, Stride<_1, _8>>{}); + auto tiled_load = + make_tiled_copy(Copy_Atom, Element>{}, + Layout, Stride<_16, _1>>{}, + Layout, Stride<_1, _8>>{}); auto tiled_atom = make_tiled_copy(Copy_Atom, Element>{}, Layout, Stride<_16, _1>>{}, Layout, Stride<_1, _8>>{}); @@ -382,9 +390,10 @@ TEST(PVC_2d_copy, load_store_stomic_int) { make_tensor(make_gmem_ptr(device_output.data()), make_layout(Shape, Int>{}, Stride, _1>{})); - auto tiled_load = make_tiled_copy(Copy_Atom, Element>{}, - Layout, Stride<_16, _1>>{}, - Layout, Stride<_1, _8>>{}); + auto tiled_load = + make_tiled_copy(Copy_Atom, Element>{}, + Layout, Stride<_16, _1>>{}, + Layout, Stride<_1, _8>>{}); auto tiled_atom = make_tiled_copy(Copy_Atom, Element>{}, Layout, Stride<_16, _1>>{}, Layout, Stride<_1, _8>>{}); diff --git a/test/unit/cute/intel_xe/copy_subgroup_block.cpp b/test/unit/cute/intel_xe/copy_subgroup_block.cpp index 1cafa48c5..844e7b8d6 100644 --- a/test/unit/cute/intel_xe/copy_subgroup_block.cpp +++ b/test/unit/cute/intel_xe/copy_subgroup_block.cpp @@ -29,6 +29,8 @@ * **************************************************************************************************/ +#include "cutlass/detail/layout.hpp" + #include #include #include @@ -52,12 +54,8 @@ void copy_kernel_vectorized(TensorS S, TensorD D, uint32_t M, uint32_t N) { D, Shape, Int>{}); // ((M, N), m', n') // Slice work group. - Tensor tile_wg_S = - tiled_tensor_S(make_coord(_, _), BlockIdxX(), - BlockIdxY()); - Tensor tile_wg_D = - tiled_tensor_D(make_coord(_, _), BlockIdxX(), - BlockIdxY()); + Tensor tile_wg_S = tiled_tensor_S(make_coord(_, _), BlockIdxX(), BlockIdxY()); + Tensor tile_wg_D = tiled_tensor_D(make_coord(_, _), BlockIdxX(), BlockIdxY()); // Slice subgroup. auto SubgroupShape = Shape, Int>{}; @@ -77,14 +75,10 @@ void copy_kernel_vectorized(TensorS S, TensorD D, uint32_t M, uint32_t N) { } #endif - using traits_load = Copy_Traits; + using traits_load = Copy_Traits; using Atom_load = Copy_Atom; - auto VecLayout = make_layout( - make_shape(get<0>(typename traits_load::Shape_MN{}), - get<1>(typename traits_load::Shape_MN{}) / _16{}), - Stride<_1, _0>{}); - auto tiled_copy_load = make_tiled_copy(Atom_load{}.with(&*S.data(), N, M, N), - Layout>{}, VecLayout); + auto tiled_copy_load = make_xe_2d_copy(Atom_load{}.with(S), + Layout>{}); // Construct a Tensor corresponding to each thread's slice. 
auto thr_copy_load = @@ -114,26 +108,23 @@ void copy_kernel_vectorized(TensorS S, TensorD D, uint32_t M, uint32_t N) { #endif static constexpr auto sg_per_wg_x = wg_tile_n / sg_tile_n; - const int m_coord = - BlockIdxX() * wg_tile_m + (cutlass::get_sub_group_id() / sg_per_wg_x) * sg_tile_m; - const int n_coord = - BlockIdxY() * wg_tile_n + (cutlass::get_sub_group_id() % sg_per_wg_x) * sg_tile_n; + const int m_coord = BlockIdxX() * wg_tile_m + + (cutlass::get_sub_group_id() / sg_per_wg_x) * sg_tile_m; + const int n_coord = BlockIdxY() * wg_tile_n + + (cutlass::get_sub_group_id() % sg_per_wg_x) * sg_tile_n; const int l_coord = BlockIdxZ(); // Copy from GMEM to RMEM and from RMEM to GMEM - auto blk_load_S = tiled_copy_load.get_pvc_tensor( - make_coord(m_coord, n_coord, l_coord), fragment.shape(), - typename traits_load::Shape_MN{}); + auto blk_load_S = tiled_copy_load.get_pvc_tensor(make_coord(m_coord, n_coord, l_coord), + fragment.shape()); copy(tiled_copy_load, blk_load_S, fragment); - using traits_store = Copy_Traits; + using traits_store = Copy_Traits; using Atom_store = Copy_Atom; auto tiled_copy_store = - make_tiled_copy(Atom_store{}.with(&*D.data(), N, M, N), - Layout, Stride<_0, _1>>{}, VecLayout); - auto thr_copy_store = - tiled_copy_store.get_thread_slice(ThreadIdxX()); + make_xe_2d_copy(Atom_store{}.with(D), Layout>{}); + auto thr_copy_store = tiled_copy_store.get_thread_slice(ThreadIdxX()); Tensor thr_tile_store_D = thr_copy_store.partition_D(tile_sg_D); @@ -150,9 +141,8 @@ void copy_kernel_vectorized(TensorS S, TensorD D, uint32_t M, uint32_t N) { } #endif - auto blk_store_D = tiled_copy_store.get_pvc_tensor( - make_coord(m_coord, n_coord, l_coord), fragment.shape(), - typename traits_store::Shape_MN{}); + auto blk_store_D = tiled_copy_store.get_pvc_tensor(make_coord(m_coord, n_coord, l_coord), + fragment.shape()); // onlt run first subgroup if (syclcompat::global_id::x() < 16 && !syclcompat::global_id::y() && @@ -186,7 +176,6 @@ bool copy(uint32_t M, uint32_t N) { cutlass::device_vector device_src = host_src; cutlass::device_vector device_output = host_output; - // // Make tensors // diff --git a/test/unit/cute/intel_xe/gemm_col_col.cpp b/test/unit/cute/intel_xe/gemm_col_col.cpp deleted file mode 100644 index d9a0c4c5b..000000000 --- a/test/unit/cute/intel_xe/gemm_col_col.cpp +++ /dev/null @@ -1,237 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. 
- * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#include "gemm_utils.hpp" - -template -struct gemm_device_col_col { - using TA = dtype_a; - using TB = dtype_b; - using TC = dtype_c; - - static constexpr bool is_a_row_major = false; - static constexpr bool is_b_row_major = false; - - static constexpr uint32_t wg_tile_m = wg_m; - static constexpr uint32_t wg_tile_n = wg_n; - static constexpr uint32_t sg_tile_m = sg_m; - static constexpr uint32_t sg_tile_n = sg_n; - static constexpr uint32_t sg_tile_k = sg_k; - - static void func(TA const *A, TB const *B, TC *C, uint32_t m, uint32_t n, - uint32_t k) { - - // Represent the full tensors - Tensor mA = make_tensor(make_gmem_ptr(A), - make_layout(make_shape(m, k), make_stride(1, m))); - Tensor mB = make_tensor(make_gmem_ptr(B), - make_layout(make_shape(k, n), make_stride(1, k))); - Tensor mC = make_tensor(make_gmem_ptr(C), - make_layout(make_shape(m, n), make_stride(n, 1))); - - // Get the appropriate blocks for this thread block - auto cta_coord = make_coord(BlockIdxX(), - BlockIdxY(), _); - - auto cta_tiler = - make_shape(Int{}, Int{}, Int{}); - Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X, _1>{}); - Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step{}); - Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1, _1, X>{}); - - using traits_load_A = Copy_Traits; - using atom_load_A = Copy_Atom; - TiledCopy copy_a = make_tiled_copy( - atom_load_A{}.with(A, m, k, m), Layout, _1>>{}, - make_layout(make_shape(get<1>(typename traits_load_A::Shape_MN{})/ Int{}, - get<0>(typename traits_load_A::Shape_MN{})))); - - using traits_load_B = Copy_Traits; - using atom_load_B = Copy_Atom; - TiledCopy copy_b = make_tiled_copy( - atom_load_B{}.with(B, k, n, k), Layout, _1>>{}, - make_layout(make_shape(get<1>(typename traits_load_B::Shape_MN{})/ Int{}, - get<0>(typename traits_load_B::Shape_MN{})))); - - using traits_store_C = Copy_Traits; - using atom_store_C = Copy_Atom; - TiledCopy copy_c = make_tiled_copy( - atom_store_C{}.with(C, n, m, n), - Layout>>{}, - make_layout(make_shape(get<0>(typename traits_store_C::Shape_MN{}), - get<1>(typename traits_store_C::Shape_MN{})/ Int{}))); - auto thread_idx = ThreadIdxX(); - auto mma = make_tiled_mma( - MMA_Atom{}, - Layout, - Int, _1>>{}); - auto thr_mma = mma.get_thread_slice(thread_idx); - auto tCrA = thr_mma.partition_fragment_A(gA(_, _, 0)); - auto tCrB = thr_mma.partition_fragment_B(gB(_, _, 0)); - auto tCrC = thr_mma.partition_fragment_C(gC); - - auto tiled_copy_A = make_tiled_copy_A(copy_a, mma); - auto thr_copy_A = tiled_copy_A.get_thread_slice(thread_idx); - auto tCrA_copy_view = thr_copy_A.retile_D(tCrA); - - auto tiled_copy_B = 
make_tiled_copy_B(copy_b, mma); - auto thr_copy_B = tiled_copy_B.get_thread_slice(thread_idx); - auto tCrB_copy_view = thr_copy_B.retile_D(tCrB); - - auto tiled_copy_C = make_tiled_copy_C(copy_c, mma); - auto thr_copy_C = tiled_copy_C.get_thread_slice(thread_idx); - auto tCrC_copy_view = thr_copy_C.retile_D(tCrC); - - clear(tCrC); - -#if CUTLASS_ENABLE_DEBUG_PRINTS - if (thread(LOG_THREAD, LOG_GROUP)) { - print("===================== A :\n"); - print(" mA : "); print(mA); print("\n"); - print(" gA : "); print(gA); print("\n"); - print("tCrA_copy_view : "); print(tCrA_copy_view); print("\n"); - print(" tCrA : "); print(tCrA); print("\n"); - - print("===================== B :\n"); - print(" mB : "); print(mB); print("\n"); - print(" gB : "); print(gB); print("\n"); - print("tCrB_copy_view : "); print(tCrB_copy_view); print("\n"); - print(" tCrB : "); print(tCrB); print("\n"); - - print("===================== C :\n"); - print(" mC : "); print(mC); print("\n"); - print(" gC : "); print(gC); print("\n"); - print("tCrC_copy_view : "); print(tCrC_copy_view); print("\n"); - print(" tCrC : "); print(tCrC); print("\n"); - } -#endif - - auto sg_per_wg_x = wg_tile_n / sg_tile_n; - const int m_coord = BlockIdxX() * wg_tile_m + - (get_sub_group_id() / sg_per_wg_x) * sg_tile_m; - const int n_coord = BlockIdxY() * wg_tile_n + - (get_sub_group_id() % sg_per_wg_x) * sg_tile_n; - const int l_coord = BlockIdxZ(); - - auto k_tile_max = size<2>(gA); - for (int k_tile = 0; k_tile < k_tile_max; ++k_tile) { - Tensor blk_tgA = tiled_copy_A.get_pvc_tensor( - make_coord(k_tile * sg_tile_k, m_coord, l_coord), - tCrA_copy_view.shape(), - typename traits_load_A::Shape_MN{}, seq<1,0>{}); - Tensor blk_tgB = tiled_copy_B.get_pvc_tensor( - make_coord(n_coord, k_tile * sg_tile_k, l_coord), - tCrB_copy_view.shape(), - typename traits_load_B::Shape_MN{}); - - copy(tiled_copy_A, blk_tgA, tCrA_copy_view); - copy(tiled_copy_B, blk_tgB, tCrB_copy_view); - - // Compute gemm on mma-partitioned smem - for (int i = 0; i < sg_tile_k / SUBGROUP_SIZE; i++) { - gemm(mma, tCrA(_, _, i), tCrB(_, _, i), tCrC); - } - } - - Tensor blk_tgC = tiled_copy_C.get_pvc_tensor( - make_coord(m_coord, n_coord, l_coord), tCrC_copy_view.shape(), - typename traits_store_C::Shape_MN{}); - copy(copy_c, tCrC_copy_view, blk_tgC); - } -}; - -TEST(PVC_CuTe_Xe, gemm_col_col_bf16_bf16_float_32x128x64) { - run>(32, 128, 64); -} - -TEST(PVC_CuTe_Xe, gemm_col_col_bf16_bf16_float_16x256x64) { - run>(16, 256, 64); -} - -TEST(PVC_CuTe_Xe, gemm_col_col_bf16_bf16_float_64x1024x64) { - run>(64, 1024, 64); -} - -TEST(PVC_CuTe_Xe, gemm_col_col_bf16_bf16_float_128x128x64) { - run>(128, 128, 64); -} -TEST(PVC_CuTe_Xe, gemm_col_col_bf16_bf16_float_32x1024x1024) { - run>(32, 1024, 1024); -} - -TEST(PVC_CuTe_Xe, gemm_col_col_bf16_bf16_float_4096x4096x256) { - run>(4096, 4096, 256); -} - -TEST(PVC_CuTe_Xe, gemm_col_col_bf16_bf16_float_1024x2048x512) { - run>(1024, 2048, 512); -} - -TEST(PVC_CuTe_Xe, gemm_col_col_bf16_bf16_float_1026x2048x512) { - run>(1026, 2048, 512); -} - -TEST(PVC_CuTe_Xe, gemm_col_col_bf16_bf16_float_1024x2050x512) { - run>(1024, 2050, 512); -} - -TEST(PVC_CuTe_Xe, gemm_col_col_bf16_bf16_float_1026x2050x256) { - run>(1026, 2050, 256); -} - -TEST(PVC_CuTe_Xe, gemm_col_col_bf16_bf16_float_512x1024x512) { - run>(512, 1024, 512); -} diff --git a/test/unit/cute/intel_xe/gemm_col_row.cpp b/test/unit/cute/intel_xe/gemm_col_row.cpp deleted file mode 100644 index a3c172247..000000000 --- a/test/unit/cute/intel_xe/gemm_col_row.cpp +++ /dev/null @@ -1,236 +0,0 @@ 
-/*************************************************************************************************** - * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#include "gemm_utils.hpp" - -template -struct gemm_device_col_row { - using TA = dtype_a; - using TB = dtype_b; - using TC = dtype_c; - - static constexpr bool is_a_row_major = false; - static constexpr bool is_b_row_major = true; - - static constexpr uint32_t wg_tile_m = wg_m; - static constexpr uint32_t wg_tile_n = wg_n; - static constexpr uint32_t sg_tile_m = sg_m; - static constexpr uint32_t sg_tile_n = sg_n; - static constexpr uint32_t sg_tile_k = sg_k; - - static void func(TA const *A, TB const *B, TC *C, uint32_t m, uint32_t n, - uint32_t k) { - - // Represent the full tensors - Tensor mA = make_tensor(make_gmem_ptr(A), - make_layout(make_shape(m, k), make_stride(1, m))); - Tensor mB = make_tensor(make_gmem_ptr(B), - make_layout(make_shape(k, n), make_stride(n, 1))); - Tensor mC = make_tensor(make_gmem_ptr(C), - make_layout(make_shape(m, n), make_stride(n, 1))); - - // Get the appropriate blocks for this thread block - auto cta_coord = make_coord(BlockIdxX(), - BlockIdxY(), _); - - auto cta_tiler = - make_shape(Int{}, Int{}, Int{}); - Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X, _1>{}); - Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step{}); - Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1, _1, X>{}); - - using traits_load_A = Copy_Traits; - using atom_load_A = Copy_Atom; - TiledCopy copy_a = make_tiled_copy( - atom_load_A{}.with(A, m, k, m), Layout, _1>>{}, - make_layout(make_shape(get<1>(typename traits_load_A::Shape_MN{}), - get<0>(typename traits_load_A::Shape_MN{}) / Int{}))); - - using traits_load_B = Copy_Traits; - using atom_load_B = Copy_Atom; - TiledCopy copy_b = make_tiled_copy( - atom_load_B{}.with(B, n, k, n), Layout>>{}, - 
make_layout(make_shape(get<0>(typename traits_load_B::Shape_MN{}), - get<1>(typename traits_load_B::Shape_MN{}) / Int{}))); - - using traits_store_C = Copy_Traits; - using atom_store_C = Copy_Atom; - TiledCopy copy_c = make_tiled_copy( - atom_store_C{}.with(C, n, m, n), - Layout>>{}, - make_layout(make_shape(get<0>(typename traits_store_C::Shape_MN{}), - get<1>(typename traits_store_C::Shape_MN{}) / Int{}))); - - auto thread_idx = ThreadIdxX(); - auto mma = make_tiled_mma( - MMA_Atom{}, - Layout, - Int, _1>>{}); - auto thr_mma = mma.get_thread_slice(thread_idx); - auto tCrA = thr_mma.partition_fragment_A(gA(_, _, 0)); - auto tCrB = thr_mma.partition_fragment_B(gB(_, _, 0)); - auto tCrC = thr_mma.partition_fragment_C(gC); - - auto tiled_copy_A = make_tiled_copy_A(copy_a, mma); - auto thr_copy_A = tiled_copy_A.get_thread_slice(thread_idx); - auto tCrA_copy_view = thr_copy_A.retile_D(tCrA); - - auto tiled_copy_B = make_tiled_copy_B(copy_b, mma); - auto thr_copy_B = tiled_copy_B.get_thread_slice(thread_idx); - auto tCrB_copy_view = thr_copy_B.retile_D(tCrB); - - auto tiled_copy_C = make_tiled_copy_C(copy_c, mma); - auto thr_copy_C = tiled_copy_C.get_thread_slice(thread_idx); - auto tCrC_copy_view = thr_copy_C.retile_D(tCrC); - - clear(tCrC); - -#if CUTLASS_ENABLE_DEBUG_PRINTS - if (thread(LOG_THREAD, LOG_GROUP)) { - print("===================== A :\n"); - print(" mA : "); print(mA); print("\n"); - print(" gA : "); print(gA); print("\n"); - print("tCrA_copy_view : "); print(tCrA_copy_view); print("\n"); - print(" tCrA : "); print(tCrA); print("\n"); - - print("===================== B :\n"); - print(" mB : "); print(mB); print("\n"); - print(" gB : "); print(gB); print("\n"); - print("tCrB_copy_view : "); print(tCrB_copy_view); print("\n"); - print(" tCrB : "); print(tCrB); print("\n"); - - print("===================== C :\n"); - print(" mC : "); print(mC); print("\n"); - print(" gC : "); print(gC); print("\n"); - print("tCrC_copy_view : "); print(tCrC_copy_view); print("\n"); - print(" tCrC : "); print(tCrC); print("\n"); - } -#endif - - auto sg_per_wg_x = wg_tile_n / sg_tile_n; - const int m_coord = BlockIdxX() * wg_tile_m + - (get_sub_group_id() / sg_per_wg_x) * sg_tile_m; - const int n_coord = BlockIdxY() * wg_tile_n + - (get_sub_group_id() % sg_per_wg_x) * sg_tile_n; - const int l_coord = BlockIdxZ(); - - auto k_tile_max = size<2>(gA); - for (int k_tile = 0; k_tile < k_tile_max; ++k_tile) { - Tensor blk_tgA = tiled_copy_A.get_pvc_tensor( - make_coord(k_tile * sg_tile_k, m_coord, l_coord), - tCrA_copy_view.shape(), typename traits_load_A::Shape_MN{}, seq<1,0>{}); - Tensor blk_tgB = tiled_copy_B.get_pvc_tensor( - make_coord(k_tile * sg_tile_k, n_coord, l_coord), - tCrB_copy_view.shape(), typename traits_load_B::Shape_MN{}, seq<1,0>{}); - - copy(tiled_copy_A, blk_tgA, tCrA_copy_view); - copy(tiled_copy_B, blk_tgB, tCrB_copy_view); - - // Compute gemm on mma-partitioned smem - for (int i = 0; i < sg_tile_k / SUBGROUP_SIZE; i++) { - gemm(mma, tCrA(_, _, i), tCrB(_, _, i), tCrC); - } - } - - Tensor blk_tgC = tiled_copy_C.get_pvc_tensor( - make_coord(m_coord, n_coord, l_coord), tCrC_copy_view.shape(), - typename traits_store_C::Shape_MN{}); - copy(copy_c, tCrC_copy_view, blk_tgC); - } -}; - -TEST(PVC_CuTe_Xe, gemm_col_row_bf16_bf16_float_32x128x64) { - run>(32, 128, 64); -} - -TEST(PVC_CuTe_Xe, gemm_col_row_bf16_bf16_float_16x256x64) { - run>(16, 256, 64); -} - -TEST(PVC_CuTe_Xe, gemm_col_row_bf16_bf16_float_64x1024x64) { - run>(64, 1024, 64); -} - -TEST(PVC_CuTe_Xe, 
gemm_col_row_bf16_bf16_float_128x128x64) { - run>(128, 128, 64); -} -TEST(PVC_CuTe_Xe, gemm_col_row_bf16_bf16_float_32x1024x1024) { - run>(32, 1024, 1024); -} - -TEST(PVC_CuTe_Xe, gemm_col_row_bf16_bf16_float_4096x4096x256) { - run>(4096, 4096, 256); -} - -TEST(PVC_CuTe_Xe, gemm_col_row_bf16_bf16_float_1024x2048x512) { - run>(1024, 2048, 512); -} - -TEST(PVC_CuTe_Xe, gemm_col_row_bf16_bf16_float_1026x2048x512) { - run>(1026, 2048, 512); -} - -TEST(PVC_CuTe_Xe, gemm_col_row_bf16_bf16_float_1024x2050x512) { - run>(1024, 2050, 512); -} - -TEST(PVC_CuTe_Xe, gemm_col_row_bf16_bf16_float_1026x2050x256) { - run>(1026, 2050, 256); -} - -TEST(PVC_CuTe_Xe, gemm_col_row_bf16_bf16_float_512x1024x512) { - run>(512, 1024, 512); -} diff --git a/test/unit/cute/intel_xe/gemm_common.hpp b/test/unit/cute/intel_xe/gemm_common.hpp new file mode 100755 index 000000000..a72b7cdae --- /dev/null +++ b/test/unit/cute/intel_xe/gemm_common.hpp @@ -0,0 +1,203 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "utils.hpp" + +using namespace cutlass::detail; +using namespace cute::detail; + +template +struct gemm_device_partition_fragment_abc { + using TA = dtype_a; + using TB = dtype_b; + using TC = dtype_c; + + static constexpr bool is_a_row_major = std::is_same_v; + static constexpr bool is_b_row_major = std::is_same_v;; + + static constexpr uint32_t wg_tile_m = wg_m; + static constexpr uint32_t wg_tile_n = wg_n; + static constexpr uint32_t sg_tile_m = sg_m; + static constexpr uint32_t sg_tile_n = sg_n; + static constexpr uint32_t sg_tile_k = sg_k; + + static void func(TA const *A, TB const *B, TC *C, uint32_t m, uint32_t n, + uint32_t k) { + using namespace cute; + + // Represent the full tensors + Tensor mA = make_tensor(make_gmem_ptr(A), make_shape(m, k), layout_a{}); + Tensor mB = make_tensor(make_gmem_ptr(B), make_shape(n, k), layout_b{}); + Tensor mC = make_tensor(make_gmem_ptr(C), make_shape(m, n), cute::LayoutRight{}); + + // Get the appropriate blocks for this thread block + auto cta_coord = make_coord(BlockIdxX(), BlockIdxY(), _); // (m,n,k) + + auto cta_tiler = + make_shape(Int{}, Int{}, Int{}); + Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X, _1>{}); + Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step{}); + Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1, _1, X>{}); + + using traits_load_A = Copy_Traits; + using atom_load_A = Copy_Atom; + auto copy_a = make_xe_2d_copy( + atom_load_A{}.with(mA), Layout>>{}); + + using traits_load_B = Copy_Traits; + using atom_load_B = Copy_Atom; + auto copy_b = make_xe_2d_copy( + atom_load_B{}.with(mB), Layout>>{}); + + using traits_store_C = Copy_Traits; + using atom_store_C = Copy_Atom; + auto copy_c = make_xe_2d_copy( + atom_store_C{}.with(mC), Layout>>{}); + + auto thread_idx = ThreadIdxX(); + + auto mma = make_tiled_mma( + MMA_Atom{}, + Layout, + Int>>{}); + auto thrd_mma = mma.get_thread_slice(thread_idx); + + Tensor fragment_A = thrd_mma.partition_fragment_A(gA(_, _, 0)); + Tensor fragment_B = thrd_mma.partition_fragment_B(gB(_, _, 0)); + Tensor fragment_C = thrd_mma.partition_fragment_C(gC); + + + auto thr_copy_a = copy_a.get_slice(thread_idx); + auto copy_view_A = thr_copy_a.retile_D(fragment_A); + + auto thr_copy_b = copy_b.get_slice(thread_idx); + auto copy_view_B = thr_copy_b.retile_D(fragment_B); + + auto thr_copy_c = copy_c.get_slice(thread_idx); + auto copy_view_C = thr_copy_c.retile_D(fragment_C); + + Tensor mma_A = thr_copy_a.retile_MMA(thrd_mma, fragment_A); + Tensor mma_B = thr_copy_b.retile_MMA(thrd_mma, fragment_B); + + clear(fragment_C); + +#if CUTLASS_ENABLE_DEBUG_PRINTS + if (thread(LOG_THREAD, LOG_GROUP)) { + print("===================== A :\n"); + print(" mA : "); + print(mA); + print("\n"); + print(" gA : "); + print(gA); + print("\n"); + print(" fragment_A : "); + print(fragment_A); + print("\n"); + print(" copy_view_A : "); + print(copy_view_A); + print("\n"); + + print("===================== B :\n"); + print(" mB : "); + print(mB); + print("\n"); + print(" gB : "); + print(gB); + print("\n"); + print(" fragment_B : "); + print(fragment_B); + print("\n"); + print(" copy_view_B : "); + print(copy_view_B); + print("\n"); + + print("===================== C :\n"); + print(" mC : "); + print(mC); + print("\n"); + print(" gC : "); + print(gC); + print("\n"); + print(" fragment_C : "); + print(fragment_C); + print("\n"); + print(" copy_view_C : "); + print(copy_view_C); 
+ print("\n"); + } +#endif + + auto sg_per_wg_x = wg_tile_n / sg_tile_n; + const int m_coord = BlockIdxX() * wg_tile_m + + (get_sub_group_id() / sg_per_wg_x) * sg_tile_m; + const int n_coord = BlockIdxY() * wg_tile_n + + (get_sub_group_id() % sg_per_wg_x) * sg_tile_n; + const int l_coord = BlockIdxZ(); + + auto k_tile_max = size<2>(gA); + for (int k_tile = 0; k_tile < k_tile_max; ++k_tile) { + Tensor blk_tgA = copy_a.get_pvc_tensor(make_coord(m_coord, k_tile * sg_tile_k, l_coord), + copy_view_A.shape()); + Tensor blk_tgB = copy_b.get_pvc_tensor(make_coord(n_coord, k_tile * sg_tile_k, l_coord), + copy_view_B.shape()); + +#if CUTLASS_ENABLE_DEBUG_PRINTS + if (thread(LOG_THREAD, LOG_GROUP) && k_tile == 1) { + print("blk_tgA : "); + print(blk_tgA); + print("\n"); + print("blk_tgB : "); + print(blk_tgB); + print("\n"); + } +#endif + + // Copy gmem to rmem for k_tile+1 with tA|tB thread-partitioned tensors + copy(copy_a, blk_tgA, copy_view_A); + copy(copy_b, blk_tgB, copy_view_B); + + // Compute gemm on mma-partitioned smem + cute::gemm(mma, mma_A, mma_B, fragment_C); + } + + Tensor blk_tgC = + copy_c.get_pvc_tensor(make_coord(m_coord, n_coord, l_coord), fragment_C.shape()); + + copy(copy_c, fragment_C, blk_tgC); + } +}; diff --git a/test/unit/cute/intel_xe/gemm_data_type.cpp b/test/unit/cute/intel_xe/gemm_data_type.cpp new file mode 100755 index 000000000..ea9955ce8 --- /dev/null +++ b/test/unit/cute/intel_xe/gemm_data_type.cpp @@ -0,0 +1,85 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#include "gemm_common.hpp" +#include "utils.hpp" + +TEST(PVC_CuTe_Xe, gemm_float_bf16_bf16_float) { + run>(32, 128, 64); +} + +TEST(PVC_CuTe_Xe, gemm_int32_int8_int8_int32) { + run>(16, 256, 64); +} + +TEST(PVC_CuTe_Xe, gemm_int32_uint8_uint8_int32) { + run>(32, 128, 64); +} + +TEST(PVC_CuTe_Xe, gemm_float_fp16_fp16_float) { + run>(16, 256, 64); +} + +// TODO: don't know how to enable this +#if 0 +TEST(PVC_CuTe_Xe, gemm_float_tf32_tf32_float_XE_1x16x8_F32TF32TF32F32_TT) { + run>(256, 512, 1024); +} +#endif + +TEST(PVC_CuTe_Xe, gemm_float_tf32_tf32_float_XE_2x16x8_F32TF32TF32F32_TT) { + run>(256, 512, 1024); +} + +TEST(PVC_CuTe_Xe, gemm_float_tf32_tf32_float_XE_4x16x8_F32TF32TF32F32_TT) { + run>(256, 512, 1024); +} + +TEST(PVC_CuTe_Xe, gemm_float_tf32_tf32_float_XE_8x16x8_F32TF32TF32F32_TT) { + run>(256, 512, 1024); +} diff --git a/test/unit/cute/intel_xe/gemm_layout.cpp b/test/unit/cute/intel_xe/gemm_layout.cpp new file mode 100755 index 000000000..ef83e6d4c --- /dev/null +++ b/test/unit/cute/intel_xe/gemm_layout.cpp @@ -0,0 +1,69 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#include "gemm_common.hpp" +#include "utils.hpp" + +TEST(PVC_CuTe_Xe, gemm_RowMajor_RowMajor) { + run>( + 512, 256, 1024); +} + +TEST(PVC_CuTe_Xe, gemm_RowMajor_ColumnMajor) { + run>( + 128, 256, 512); +} + +TEST(PVC_CuTe_Xe, gemm_ColumnMajor_RowMajor) { + run>( + 256, 512, 1024); +} + +TEST(PVC_CuTe_Xe, gemm_ColumnMajor_ColumnMajor) { + run>( + 256, 512, 1024); +} diff --git a/test/unit/cute/intel_xe/gemm_partition_fragment_abc.cpp b/test/unit/cute/intel_xe/gemm_partition_fragment_abc.cpp index 6224e4a6e..ad6fb0055 100755 --- a/test/unit/cute/intel_xe/gemm_partition_fragment_abc.cpp +++ b/test/unit/cute/intel_xe/gemm_partition_fragment_abc.cpp @@ -29,159 +29,8 @@ * **************************************************************************************************/ -#include "gemm_utils.hpp" - -template -struct gemm_device_partition_fragment_abc { - using TA = dtype_a; - using TB = dtype_b; - using TC = dtype_c; - - static constexpr bool is_a_row_major = true; - static constexpr bool is_b_row_major = true; - - static constexpr uint32_t wg_tile_m = wg_m; - static constexpr uint32_t wg_tile_n = wg_n; - static constexpr uint32_t sg_tile_m = sg_m; - static constexpr uint32_t sg_tile_n = sg_n; - static constexpr uint32_t sg_tile_k = sg_k; - - static void func(TA const *A, TB const *B, TC *C, uint32_t m, uint32_t n, - uint32_t k) { - using namespace cute; - - // Represent the full tensors - Tensor mA = make_tensor(make_gmem_ptr(A), - make_layout(make_shape(m, k), make_stride(k, 1))); - Tensor mB = make_tensor(make_gmem_ptr(B), - make_layout(make_shape(k, n), make_stride(n, 1))); - Tensor mC = make_tensor(make_gmem_ptr(C), - make_layout(make_shape(m, n), make_stride(n, 1))); - - // Get the appropriate blocks for this thread block - auto cta_coord = make_coord(BlockIdxX(), - BlockIdxY(), _); // (m,n,k) - - auto cta_tiler = - make_shape(Int{}, Int{}, Int{}); - Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X, _1>{}); - Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step{}); - Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1, _1, X>{}); - - using traits_load_A = Copy_Traits; - using atom_load_A = Copy_Atom; - TiledCopy copy_a = make_tiled_copy( - atom_load_A{}.with(A, k, m, k), Layout>>{}, - make_layout(make_shape(get<0>(typename traits_load_A::Shape_MN{}), - get<1>(typename traits_load_A::Shape_MN{}) / Int{}))); - - using traits_load_B = Copy_Traits; - using atom_load_B = Copy_Atom; - TiledCopy copy_b = make_tiled_copy( - atom_load_B{}.with(B, n, k, n), Layout>>{}, - make_layout(make_shape(get<0>(typename traits_load_B::Shape_MN{}), - get<1>(typename traits_load_B::Shape_MN{}) / Int{}))); - - using traits_store_C = Copy_Traits; - using atom_store_C = Copy_Atom; - TiledCopy copy_c = make_tiled_copy( - atom_store_C{}.with(C, n, m, n), - Layout>>{}, - make_layout(make_shape(get<0>(typename traits_store_C::Shape_MN{}), - get<1>(typename traits_store_C::Shape_MN{}) / Int{}))); - - auto thread_idx = ThreadIdxX(); - - TiledMMA mma = make_tiled_mma( - MMA_Atom{}, - Layout, - Int>>{}); - auto thrd_mma = mma.get_thread_slice(thread_idx); - - Tensor fragment_A = thrd_mma.partition_fragment_A(gA(_, _, 0)); - Tensor fragment_temp = thrd_mma.partition_fragment_B(gB(_, _, 0)); - Tensor fragment_B = make_tensor( - static_cast(fragment_temp).data(), - make_shape(size<0>(fragment_temp.shape()), - size<2>(fragment_temp.shape()), - size<1>(fragment_temp.shape()))); - Tensor fragment_C = 
thrd_mma.partition_fragment_C(gC); - - ThrCopy thr_copy_a = copy_a.get_slice(thread_idx); - auto copy_view_A = thr_copy_a.retile_D(fragment_A); - - ThrCopy thr_copy_b = copy_b.get_slice(thread_idx); - auto copy_view_B = thr_copy_b.retile_D(fragment_B); - - ThrCopy thr_copy_c = copy_c.get_slice(thread_idx); - auto copy_view_C = thr_copy_c.retile_D(fragment_C); - - clear(fragment_C); - -#if CUTLASS_ENABLE_DEBUG_PRINTS - if (thread(LOG_THREAD, LOG_GROUP)) { - print("===================== A :\n"); - print(" mA : "); print(mA); print("\n"); - print(" gA : "); print(gA); print("\n"); - print(" fragment_A : "); print(fragment_A); print("\n"); - print(" copy_view_A : "); print(copy_view_A); print("\n"); - - print("===================== B :\n"); - print(" mB : "); print(mB); print("\n"); - print(" gB : "); print(gB); print("\n"); - print(" fragment_B : "); print(fragment_B); print("\n"); - print(" copy_view_B : "); print(copy_view_B); print("\n"); - - print("===================== C :\n"); - print(" mC : "); print(mC); print("\n"); - print(" gC : "); print(gC); print("\n"); - print(" fragment_C : "); print(fragment_C); print("\n"); - print(" copy_view_C : "); print(copy_view_C); print("\n"); - } -#endif - - auto sg_per_wg_x = wg_tile_n / sg_tile_n; - const int m_coord = BlockIdxX() * wg_tile_m + - (get_sub_group_id() / sg_per_wg_x) * sg_tile_m; - const int n_coord = BlockIdxY() * wg_tile_n + - (get_sub_group_id() % sg_per_wg_x) * sg_tile_n; - const int l_coord = BlockIdxZ(); - - auto k_tile_max = size<2>(gA); - Tensor blk_tgA = copy_a.get_pvc_tensor( - make_coord(m_coord, 0, l_coord), append<4>(copy_view_A.shape(), k_tile_max), - append<3>(typename traits_load_A::Shape_MN{}, sg_tile_k), seq<0, 1, 1>{}); - Tensor blk_tgB = copy_b.get_pvc_tensor( - make_coord(0, n_coord, l_coord), append<4>(copy_view_B.shape(), k_tile_max), - append<3>(typename traits_load_B::Shape_MN{}, sg_tile_k), seq<0, 1, 0>{}); - for (int k_tile = 0; k_tile < k_tile_max; ++k_tile) { -#if CUTLASS_ENABLE_DEBUG_PRINTS - if (thread(LOG_THREAD, LOG_GROUP) && k_tile == 1) { - print("blk_tgA : "); print(blk_tgA); print("\n"); - print("blk_tgB : "); print(blk_tgB); print("\n"); - } -#endif - - // Copy gmem to rmem for k_tile+1 with tA|tB thread-partitioned tensors - copy(copy_a, blk_tgA(_, _, _, k_tile), copy_view_A); - copy(copy_b, blk_tgB(_, _, _, k_tile), copy_view_B); - - // Compute gemm on mma-partitioned smem - for (int i = 0; i < sg_tile_k / SUBGROUP_SIZE; i++) { - gemm(mma, fragment_A(_, _, i), fragment_B(_, i, _), fragment_C); - } - } - - Tensor blk_tgC = copy_c.get_pvc_tensor( - make_coord(m_coord, n_coord, l_coord), fragment_C.shape(), - typename traits_store_C::Shape_MN{}); - - copy(copy_c, fragment_C, blk_tgC); - } -}; +#include "gemm_common.hpp" +#include "utils.hpp" TEST(PVC_CuTe_Xe, gemm_partition_fragment_abc_bf16_bf16_float_32x128x64) { run{}, Int{}, Int{}); @@ -82,26 +81,21 @@ struct gemm_device_partition_sd { local_tile(gC, make_shape(Int{}, Int{}), make_coord(sg_id / sg_per_wg_x, sg_id % sg_per_wg_x)); - using traits_load_A = Copy_Traits; + using traits_load_A = Copy_Traits; using atom_load_A = Copy_Atom; - TiledCopy copy_a = make_tiled_copy( - atom_load_A{}.with(A, k, m, k), Layout>>{}, - make_layout(make_shape(get<0>(typename traits_load_A::Shape_MN{}), - get<1>(typename traits_load_A::Shape_MN{}) / Int{}))); + TiledCopy copy_a = make_xe_2d_copy( + atom_load_A{}.with(mA), Layout>>{}); - using traits_load_B = Copy_Traits; + using traits_load_B = Copy_Traits; using atom_load_B = Copy_Atom; - TiledCopy copy_b = 
make_tiled_copy( - atom_load_B{}.with(B, n, k, n), Layout>>{}, - make_layout(make_shape(get<0>(typename traits_load_B::Shape_MN{}), - get<1>(typename traits_load_B::Shape_MN{}) / Int{}))); - using traits_store_C = Copy_Traits; + TiledCopy copy_b = make_xe_2d_copy( + atom_load_B{}.with(mB), Layout>>{}); + + using traits_store_C = Copy_Traits; using atom_store_C = Copy_Atom; - TiledCopy copy_c = make_tiled_copy( - atom_store_C{}.with(C, n, m, n), - Layout>>{}, - make_layout(make_shape(get<0>(typename traits_store_C::Shape_MN{}), - get<1>(typename traits_store_C::Shape_MN{}) / Int{}))); + TiledCopy copy_c = make_xe_2d_copy( + atom_store_C{}.with(mC), Layout>>{}); + TiledMMA mma = make_tiled_mma( MMA_Atom{}, Layout, @@ -128,42 +122,68 @@ struct gemm_device_partition_sd { clear(fragment_C); #if CUTLASS_ENABLE_DEBUG_PRINTS - if (thread(LOG_THREAD, LOG_GROUP)) { - print("===================== A :\n"); - print(" mA : "); print(mA); print("\n"); - print(" gA : "); print(gA); print("\n"); - print("tgA : "); print(tgA); print("\n"); - print("fragment_A : "); print(fragment_A); print("\n"); - - print("===================== B :\n"); - print(" mB : "); print(mB); print("\n"); - print(" gB : "); print(gB); print("\n"); - print("tgB : "); print(tgB); print("\n"); - print("fragment_B : "); print(fragment_B); print("\n"); - - print("===================== C :\n"); - print(" mC : "); print(mC); print("\n"); - print(" gC : "); print(gC); print("\n"); - print("tgC : "); print(tgC); print("\n"); - print("fragment_C : "); print(fragment_C); print("\n"); - } + if (thread(LOG_THREAD, LOG_GROUP)) { + print("===================== A :\n"); + print(" mA : "); + print(mA); + print("\n"); + print(" gA : "); + print(gA); + print("\n"); + print("tgA : "); + print(tgA); + print("\n"); + print("fragment_A : "); + print(fragment_A); + print("\n"); + + print("===================== B :\n"); + print(" mB : "); + print(mB); + print("\n"); + print(" gB : "); + print(gB); + print("\n"); + print("tgB : "); + print(tgB); + print("\n"); + print("fragment_B : "); + print(fragment_B); + print("\n"); + + print("===================== C :\n"); + print(" mC : "); + print(mC); + print("\n"); + print(" gC : "); + print(gC); + print("\n"); + print("tgC : "); + print(tgC); + print("\n"); + print("fragment_C : "); + print(fragment_C); + print("\n"); + } #endif auto k_tile_max = size<3>(tgA); for (int k_tile = 0; k_tile < k_tile_max; ++k_tile) { - Tensor blk_tgA = copy_a.get_pvc_tensor( - make_coord(m_coord, k_tile * sg_tile_k, l_coord), fragment_A.shape(), - typename traits_load_A::Shape_MN{}); - Tensor blk_tgB = copy_b.get_pvc_tensor( - make_coord(k_tile * sg_tile_k, n_coord, l_coord), fragment_B.shape(), - typename traits_load_B::Shape_MN{}, seq<1,0>{}); + Tensor blk_tgA = copy_a.get_pvc_tensor(make_coord(m_coord, k_tile * sg_tile_k, l_coord), + fragment_A.shape()); + Tensor blk_tgB = copy_b.get_pvc_tensor(make_coord(n_coord, k_tile * sg_tile_k, l_coord), + fragment_B.shape()); #if CUTLASS_ENABLE_DEBUG_PRINTS - if (thread(LOG_THREAD, LOG_GROUP) && k_tile == 1) { - print("blk_tgA : "); print(blk_tgA); print("\n"); - print("blk_tgB : "); print(blk_tgB); print("\n"); - } + if (thread(LOG_THREAD, LOG_GROUP) && k_tile == 1) { + print("blk_tgA : "); + print(blk_tgA); + print("\n"); + print("blk_tgB : "); + print(blk_tgB); + print("\n"); + } #endif // Copy gmem to rmem for k_tile+1 with tA|tB thread-partitioned tensors @@ -172,13 +192,12 @@ struct gemm_device_partition_sd { // Compute gemm on mma-partitioned smem for (int i = 0; i < sg_tile_k / 
SUBGROUP_SIZE; i++) { - gemm(mma, fragment_A(_, _, i), fragment_B(_, _, i), fragment_C); + cute::gemm(mma, fragment_A(_, _, i), fragment_B(_, _, i), fragment_C); } } - Tensor blk_tgC = copy_c.get_pvc_tensor( - make_coord(m_coord, n_coord, l_coord), fragment_C.shape(), - typename traits_store_C::Shape_MN{}); + Tensor blk_tgC = + copy_c.get_pvc_tensor(make_coord(m_coord, n_coord, l_coord), fragment_C.shape()); copy(copy_c, fragment_C, blk_tgC); } diff --git a/test/unit/cute/intel_xe/gemm_row_col.cpp b/test/unit/cute/intel_xe/gemm_row_col.cpp deleted file mode 100755 index 57bd808d4..000000000 --- a/test/unit/cute/intel_xe/gemm_row_col.cpp +++ /dev/null @@ -1,238 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- * - **************************************************************************************************/ - -#include "gemm_utils.hpp" - -template -struct gemm_device_row_col { - using TA = dtype_a; - using TB = dtype_b; - using TC = dtype_c; - - static constexpr bool is_a_row_major = true; - static constexpr bool is_b_row_major = false; - - static constexpr uint32_t wg_tile_m = wg_m; - static constexpr uint32_t wg_tile_n = wg_n; - static constexpr uint32_t sg_tile_m = sg_m; - static constexpr uint32_t sg_tile_n = sg_n; - static constexpr uint32_t sg_tile_k = sg_k; - - static void func(TA const *A, TB const *B, TC *C, uint32_t m, uint32_t n, - uint32_t k) { - - // Represent the full tensors - Tensor mA = make_tensor(make_gmem_ptr(A), - make_layout(make_shape(m, k), make_stride(k, 1))); - Tensor mB = make_tensor(make_gmem_ptr(B), - make_layout(make_shape(k, n), make_stride(1, k))); - Tensor mC = make_tensor(make_gmem_ptr(C), - make_layout(make_shape(m, n), make_stride(n, 1))); - - // Get the appropriate blocks for this thread block - auto cta_coord = make_coord(BlockIdxX(), - BlockIdxY(), _); - - auto cta_tiler = - make_shape(Int{}, Int{}, Int{}); - Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X, _1>{}); - Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step{}); - Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1, _1, X>{}); - - using traits_load_A = Copy_Traits; - using atom_load_A = Copy_Atom; - TiledCopy copy_a = make_tiled_copy( - atom_load_A{}.with(A, k, m, k), Layout>>{}, - make_layout(make_shape(get<0>(typename traits_load_A::Shape_MN{}), - get<1>(typename traits_load_A::Shape_MN{}) / Int{}))); - - using traits_load_B = Copy_Traits; - using atom_load_B = Copy_Atom; - TiledCopy copy_b = make_tiled_copy( - atom_load_B{}.with(B, k, n, k), Layout, _1>>{}, - make_layout(make_shape(get<1>(typename traits_load_B::Shape_MN{})/ Int{}, - get<0>(typename traits_load_B::Shape_MN{})))); - - using traits_store_C = Copy_Traits; - using atom_store_C = Copy_Atom; - TiledCopy copy_c = make_tiled_copy( - atom_store_C{}.with(C, n, m, n), - Layout>>{}, - make_layout(make_shape(get<0>(typename traits_store_C::Shape_MN{}), - get<1>(typename traits_store_C::Shape_MN{}) / Int{}))); - - auto thread_idx = ThreadIdxX(); - auto mma = make_tiled_mma( - MMA_Atom{}, - Layout, - Int, _1>>{}); - auto thr_mma = mma.get_thread_slice(thread_idx); - auto tCrA = thr_mma.partition_fragment_A(gA(_, _, 0)); - auto tCrB = thr_mma.partition_fragment_B(gB(_, _, 0)); - auto tCrC = thr_mma.partition_fragment_C(gC); - - auto tiled_copy_A = make_tiled_copy_A(copy_a, mma); - auto thr_copy_A = tiled_copy_A.get_thread_slice(thread_idx); - auto tCrA_copy_view = thr_copy_A.retile_D(tCrA); - - auto tiled_copy_B = make_tiled_copy_B(copy_b, mma); - auto thr_copy_B = tiled_copy_B.get_thread_slice(thread_idx); - auto tCrB_copy_view = thr_copy_B.retile_D(tCrB); - - auto tiled_copy_C = make_tiled_copy_C(copy_c, mma); - auto thr_copy_C = tiled_copy_C.get_thread_slice(thread_idx); - auto tCrC_copy_view = thr_copy_C.retile_D(tCrC); - - clear(tCrC); - -#if CUTLASS_ENABLE_DEBUG_PRINTS - if (thread(LOG_THREAD, LOG_GROUP)) { - print("===================== A :\n"); - print(" mA : "); print(mA); print("\n"); - print(" gA : "); print(gA); print("\n"); - print("tCrA_copy_view : "); print(tCrA_copy_view); print("\n"); - print(" tCrA : "); print(tCrA); print("\n"); - - print("===================== B :\n"); - print(" mB : "); print(mB); print("\n"); - print(" gB : "); print(gB); print("\n"); - print("tCrB_copy_view : "); 
print(tCrB_copy_view); print("\n"); - print(" tCrB : "); print(tCrB); print("\n"); - - print("===================== C :\n"); - print(" mC : "); print(mC); print("\n"); - print(" gC : "); print(gC); print("\n"); - print("tCrC_copy_view : "); print(tCrC_copy_view); print("\n"); - print(" tCrC : "); print(tCrC); print("\n"); - } -#endif - - auto sg_per_wg_x = wg_tile_n / sg_tile_n; - const int m_coord = BlockIdxX() * wg_tile_m + - (get_sub_group_id() / sg_per_wg_x) * sg_tile_m; - const int n_coord = BlockIdxY() * wg_tile_n + - (get_sub_group_id() % sg_per_wg_x) * sg_tile_n; - const int l_coord = BlockIdxZ(); - - auto k_tile_max = size<2>(gA); - for (int k_tile = 0; k_tile < k_tile_max; ++k_tile) { - Tensor blk_tgA = tiled_copy_A.get_pvc_tensor( - make_coord(m_coord, k_tile * sg_tile_k, l_coord), - tCrA_copy_view.shape(), - typename traits_load_A::Shape_MN{}); - Tensor blk_tgB = tiled_copy_B.get_pvc_tensor( - make_coord(n_coord, k_tile * sg_tile_k, l_coord), - tCrB_copy_view.shape(), - typename traits_load_B::Shape_MN{}); - - copy(tiled_copy_A, blk_tgA, tCrA_copy_view); - copy(tiled_copy_B, blk_tgB, tCrB_copy_view); - - // Compute gemm on mma-partitioned smem - for (int i = 0; i < sg_tile_k / SUBGROUP_SIZE; i++) { - gemm(mma, tCrA(_, _, i), tCrB(_, _, i), tCrC); - } - } - - Tensor blk_tgC = tiled_copy_C.get_pvc_tensor( - make_coord(m_coord, n_coord, l_coord), tCrC_copy_view.shape(), - typename traits_store_C::Shape_MN{}); - copy(copy_c, tCrC_copy_view, blk_tgC); - } -}; - -TEST(PVC_CuTe_Xe, gemm_row_col_bf16_bf16_float_32x128x64) { - run>(32, 128, 64); -} - -TEST(PVC_CuTe_Xe, gemm_row_col_bf16_bf16_float_16x256x64) { - run>(16, 256, 64); -} - -TEST(PVC_CuTe_Xe, gemm_row_col_bf16_bf16_float_64x1024x64) { - run>(64, 1024, 64); -} - -TEST(PVC_CuTe_Xe, gemm_row_col_bf16_bf16_float_128x128x64) { - run>(128, 128, 64); -} -TEST(PVC_CuTe_Xe, gemm_row_col_bf16_bf16_float_32x1024x1024) { - run>(32, 1024, 1024); -} - -TEST(PVC_CuTe_Xe, gemm_row_col_bf16_bf16_float_4096x4096x256) { - run>(4096, 4096, 256); -} - -TEST(PVC_CuTe_Xe, gemm_row_col_bf16_bf16_float_1024x2048x512) { - run>(1024, 2048, 512); -} - -TEST(PVC_CuTe_Xe, gemm_row_col_bf16_bf16_float_1026x2048x512) { - run>(1026, 2048, 512); -} - -TEST(PVC_CuTe_Xe, gemm_row_col_bf16_bf16_float_1024x2050x512) { - run>(1024, 2050, 512); -} - -TEST(PVC_CuTe_Xe, gemm_row_col_bf16_bf16_float_1026x2050x256) { - run>(1026, 2050, 256); -} - -TEST(PVC_CuTe_Xe, gemm_row_col_bf16_bf16_float_512x1024x512) { - run>(512, 1024, 512); -} diff --git a/test/unit/cute/intel_xe/gemm_tiled_copy_abc.cpp b/test/unit/cute/intel_xe/gemm_tiled_copy_abc.cpp index a979ce3d2..2215d679e 100755 --- a/test/unit/cute/intel_xe/gemm_tiled_copy_abc.cpp +++ b/test/unit/cute/intel_xe/gemm_tiled_copy_abc.cpp @@ -29,7 +29,7 @@ * **************************************************************************************************/ -#include "gemm_utils.hpp" +#include "utils.hpp" template {}, Int{}, Int{}); @@ -72,25 +71,19 @@ struct gemm_device_tiled_copy_abc { using traits_load_A = Copy_Traits; using atom_load_A = Copy_Atom; - TiledCopy copy_a = make_tiled_copy( - atom_load_A{}.with(A, k, m, k), Layout>>{}, - make_layout(make_shape(get<0>(typename traits_load_A::Shape_MN{}), - get<1>(typename traits_load_A::Shape_MN{}) / Int{}))); + TiledCopy copy_a = make_xe_2d_copy( + atom_load_A{}.with(A, m, k), Layout>>{}); - using traits_load_B = Copy_Traits; + using traits_load_B = Copy_Traits>; using atom_load_B = Copy_Atom; - TiledCopy copy_b = make_tiled_copy( - atom_load_B{}.with(B, n, k, n), 
Layout>>{}, - make_layout(make_shape(get<0>(typename traits_load_B::Shape_MN{}), - get<1>(typename traits_load_B::Shape_MN{}) / Int{}))); + TiledCopy copy_b = make_xe_2d_copy( + atom_load_B{}.with(B, n, k), Layout>>{}); using traits_store_C = Copy_Traits; using atom_store_C = Copy_Atom; - TiledCopy copy_c = make_tiled_copy( - atom_store_C{}.with(C, n, m, n), - Layout>>{}, - make_layout(make_shape(get<0>(typename traits_store_C::Shape_MN{}), - get<1>(typename traits_store_C::Shape_MN{}) / Int{}))); + TiledCopy copy_c = make_xe_2d_copy( + atom_store_C{}.with(C, m, n, n), Layout>>{}); + auto thread_idx = ThreadIdxX(); auto mma = make_tiled_mma( MMA_Atom{}, @@ -116,25 +109,49 @@ struct gemm_device_tiled_copy_abc { clear(tCrC); #if CUTLASS_ENABLE_DEBUG_PRINTS - if (thread(LOG_THREAD, LOG_GROUP)) { - print("===================== A :\n"); - print(" mA : "); print(mA); print("\n"); - print(" gA : "); print(gA); print("\n"); - print("tCrA_copy_view : "); print(tCrA_copy_view); print("\n"); - print(" tCrA : "); print(tCrA); print("\n"); - - print("===================== B :\n"); - print(" mB : "); print(mB); print("\n"); - print(" gB : "); print(gB); print("\n"); - print("tCrB_copy_view : "); print(tCrB_copy_view); print("\n"); - print(" tCrB : "); print(tCrB); print("\n"); - - print("===================== C :\n"); - print(" mC : "); print(mC); print("\n"); - print(" gC : "); print(gC); print("\n"); - print("tCrC_copy_view : "); print(tCrC_copy_view); print("\n"); - print(" tCrC : "); print(tCrC); print("\n"); - } + if (thread(LOG_THREAD, LOG_GROUP)) { + print("===================== A :\n"); + print(" mA : "); + print(mA); + print("\n"); + print(" gA : "); + print(gA); + print("\n"); + print("tCrA_copy_view : "); + print(tCrA_copy_view); + print("\n"); + print(" tCrA : "); + print(tCrA); + print("\n"); + + print("===================== B :\n"); + print(" mB : "); + print(mB); + print("\n"); + print(" gB : "); + print(gB); + print("\n"); + print("tCrB_copy_view : "); + print(tCrB_copy_view); + print("\n"); + print(" tCrB : "); + print(tCrB); + print("\n"); + + print("===================== C :\n"); + print(" mC : "); + print(mC); + print("\n"); + print(" gC : "); + print(gC); + print("\n"); + print("tCrC_copy_view : "); + print(tCrC_copy_view); + print("\n"); + print(" tCrC : "); + print(tCrC); + print("\n"); + } #endif auto sg_per_wg_x = wg_tile_n / sg_tile_n; @@ -145,28 +162,22 @@ struct gemm_device_tiled_copy_abc { const int l_coord = BlockIdxZ(); auto k_tile_max = size<2>(gA); + for (int k_tile = 0; k_tile < k_tile_max; ++k_tile) { Tensor blk_tgA = tiled_copy_A.get_pvc_tensor( - make_coord(m_coord, k_tile * sg_tile_k, l_coord), - tCrA_copy_view.shape(), - typename traits_load_A::Shape_MN{}); + make_coord(m_coord, k_tile * sg_tile_k, l_coord), tCrA_copy_view.shape()); Tensor blk_tgB = tiled_copy_B.get_pvc_tensor( - make_coord(k_tile * sg_tile_k, n_coord, l_coord), - tCrB_copy_view.shape(), - typename traits_load_B::Shape_MN{}, seq<1,0>{}); + make_coord(n_coord, k_tile * sg_tile_k, l_coord), tCrB_copy_view.shape()); copy(tiled_copy_A, blk_tgA, tCrA_copy_view); copy(tiled_copy_B, blk_tgB, tCrB_copy_view); // Compute gemm on mma-partitioned smem - for (int i = 0; i < sg_tile_k / SUBGROUP_SIZE; i++) { - gemm(mma, tCrA(_, _, i), tCrB(_, _, i), tCrC); - } + cute::gemm(mma, tCrA, tCrB, tCrC); } - Tensor blk_tgC = tiled_copy_C.get_pvc_tensor( - make_coord(m_coord, n_coord, l_coord), tCrC_copy_view.shape(), - typename traits_store_C::Shape_MN{}); + Tensor blk_tgC = 
tiled_copy_C.get_pvc_tensor(make_coord(m_coord, n_coord, l_coord), + tCrC_copy_view.shape()); copy(copy_c, tCrC_copy_view, blk_tgC); } }; diff --git a/test/unit/cute/intel_xe/mma.cpp b/test/unit/cute/intel_xe/mma.cpp index 5d80d8613..1c0e3d8a6 100755 --- a/test/unit/cute/intel_xe/mma.cpp +++ b/test/unit/cute/intel_xe/mma.cpp @@ -29,11 +29,14 @@ * **************************************************************************************************/ +#include "cutlass/detail/layout.hpp" + #include #include #include #include "cutlass_unit_test.h" +#include "utils.hpp" using namespace cute; using namespace cutlass; @@ -56,8 +59,7 @@ void gemm_device(TA const *A, TB const *B, TC *C, uint32_t m, uint32_t n, make_layout(make_shape(m, n), make_stride(n, 1))); // Get the appropriate blocks for this thread block - auto cta_coord = make_coord(BlockIdxX(), - BlockIdxY(), _); // (m,n,k) + auto cta_coord = make_coord(BlockIdxX(), BlockIdxY(), _); // (m,n,k) auto cta_tiler = make_shape(Int{}, Int{}, Int{}); @@ -93,10 +95,18 @@ void gemm_device(TA const *A, TB const *B, TC *C, uint32_t m, uint32_t n, if (thread(LOG_THREAD)) { print("===================== A :\n"); - print(" mA : "); print(mA); print("\n"); - print(" gA : "); print(gA); print("\n"); - print("tgA : "); print(tgA); print("\n"); - print("fragment_A : "); print(fragment_A); print("\n\n"); + print(" mA : "); + print(mA); + print("\n"); + print(" gA : "); + print(gA); + print("\n"); + print("tgA : "); + print(tgA); + print("\n"); + print("fragment_A : "); + print(fragment_A); + print("\n\n"); } #endif @@ -104,10 +114,18 @@ void gemm_device(TA const *A, TB const *B, TC *C, uint32_t m, uint32_t n, if (thread(LOG_THREAD)) { print("===================== B :\n"); - print(" mB : "); print(mB); print("\n"); - print(" gB : "); print(gB); print("\n"); - print("tgB : "); print(tgB); print("\n"); - print("fragment_B : "); print(fragment_B); print("\n\n"); + print(" mB : "); + print(mB); + print("\n"); + print(" gB : "); + print(gB); + print("\n"); + print("tgB : "); + print(tgB); + print("\n"); + print("fragment_B : "); + print(fragment_B); + print("\n\n"); } #endif @@ -115,10 +133,18 @@ void gemm_device(TA const *A, TB const *B, TC *C, uint32_t m, uint32_t n, if (thread(LOG_THREAD)) { print("===================== C :\n"); - print(" mC : "); print(mC); print("\n"); - print(" gC : "); print(gC); print("\n"); - print("tgC : "); print(tgC); print("\n"); - print("fragment_C : "); print(fragment_C); print("\n\n"); + print(" mC : "); + print(mC); + print("\n"); + print(" gC : "); + print(gC); + print("\n"); + print("tgC : "); + print(tgC); + print("\n"); + print("fragment_C : "); + print(fragment_C); + print("\n\n"); } #endif @@ -131,7 +157,7 @@ void gemm_device(TA const *A, TB const *B, TC *C, uint32_t m, uint32_t n, copy(kB, fragment_B); // Compute gemm on mma-partitioned smem - gemm(mma, fragment_A, fragment_B, fragment_C); + cute::gemm(mma, fragment_A, fragment_B, fragment_C); } copy(fragment_C, tgC); @@ -155,63 +181,26 @@ void gemm(int m, int n, int k, TA *A, TB *B, TC *C) { A, B, C, m, n, k); } -template -void verify(uint32_t m, uint32_t n, uint32_t k, atype *A, btype *B, ctype *C, - ctype *D) { - std::vector h_D(m * n); - - syclcompat::memcpy(h_D.data(), D, m * n); - - int cnt = 0; - for (int i = 0; i < m; i++) { - for (int j = 0; j < n; j++) { - for (int z = 0; z < k; z++) { - C[i * n + j] += A[i * k + z] * B[z * n + j]; - } - - auto error = abs((C[i * n + j] - h_D.data()[i * n + j]) / - (float)h_D.data()[i * n + j]); - if (error > 0.01f) { - cnt++; - } - } 
- } - - EXPECT_EQ(cnt, 0); -} - -template static void fill_matrix(std::vector &M) { - std::random_device dev; - std::mt19937 rng(dev()); - std::uniform_real_distribution dist((T)0.0, (T)4.0); - std::generate(std::begin(M), std::end(M), - [&] { return static_cast(dist(rng)); }); -} - template void MMA_Test(int m, int n, int k) { - std::vector h_A(m * k); - std::vector h_B(n * k); - std::vector h_C(m * n); - h_C.clear(); + cutlass::host_vector h_A(m * k); + cutlass::host_vector h_B(n * k); + cutlass::host_vector h_C(m * n); fill_matrix(h_A); fill_matrix(h_B); - auto d_A = syclcompat::malloc(m * k); - auto d_B = syclcompat::malloc(k * n); - auto d_C = syclcompat::malloc(m * n); - - syclcompat::memcpy(d_A, h_A.data(), m * k); - syclcompat::memcpy(d_B, h_B.data(), k * n); - syclcompat::memcpy(d_C, h_C.data(), m * n); + cutlass::device_vector d_A = h_A; + cutlass::device_vector d_B = h_B; + cutlass::device_vector d_C = h_C; - gemm(m, n, k, d_A, - d_B, d_C); + ::gemm( + m, n, k, d_A.data(), d_B.data(), d_C.data()); syclcompat::wait(); - verify(m, n, k, h_A.data(), h_B.data(), h_C.data(), d_C); + h_C = d_C; + verify(m, n, k, h_A.data(), h_B.data(), h_C.data()); } TEST(PVC_CuTe_Xe, MMA_XE_8x16x32_S32S8S8S32_TT) { @@ -298,3 +287,23 @@ TEST(PVC_CuTe_Xe, FMA_XE_UniversalFMA_F32F32F32F32) { MMA_Test, 64, 64, 8, 16, 16, float, float, float>(512, 512, 256); } + +TEST(PVC_CuTe_Xe, MMA_XE_1x16x8_F32TF32TF32F32_TT) { + MMA_Test(512, 512, 256); +} + +TEST(PVC_CuTe_Xe, MMA_XE_2x16x8_F32TF32TF32F32_TT) { + MMA_Test(512, 512, 256); +} + +TEST(PVC_CuTe_Xe, MMA_XE_4x16x8_F32TF32TF32F32_TT) { + MMA_Test(512, 512, 256); +} + +TEST(PVC_CuTe_Xe, MMA_XE_8x16x8_F32TF32TF32F32_TT) { + MMA_Test(512, 512, 256); +} diff --git a/test/unit/cute/intel_xe/gemm_utils.hpp b/test/unit/cute/intel_xe/utils.hpp old mode 100644 new mode 100755 similarity index 78% rename from test/unit/cute/intel_xe/gemm_utils.hpp rename to test/unit/cute/intel_xe/utils.hpp index 3742d457a..ec2f1132c --- a/test/unit/cute/intel_xe/gemm_utils.hpp +++ b/test/unit/cute/intel_xe/utils.hpp @@ -29,6 +29,10 @@ * **************************************************************************************************/ +#pragma once + +#include "cutlass/detail/layout.hpp" + #include #include #include @@ -36,7 +40,12 @@ #include "cutlass_unit_test.h" using namespace cute; +using namespace cute::detail; + using namespace cutlass; +using namespace cutlass::layout; +using namespace cutlass::detail; + using namespace syclcompat::experimental; #define SUBGROUP_SIZE (16) @@ -47,7 +56,7 @@ using namespace syclcompat::experimental; template void verify(uint32_t m, uint32_t n, uint32_t k, atype *A, btype *B, ctype *C, - bool row_a, bool row_b) { + bool row_a = true, bool row_b = true) { int cnt = 0; bool is_normal = true; @@ -68,7 +77,8 @@ void verify(uint32_t m, uint32_t n, uint32_t k, atype *A, btype *B, ctype *C, cnt++; } } else { - is_normal = false; + // TODO(codeplay): Assert that at least some values are non-zero. 
+ if (!(expect == 0 && val == 0)) is_normal = false; } } } @@ -80,8 +90,26 @@ template static void fill_matrix(cutlass::host_vector &M) { std::random_device dev; std::mt19937 rng(dev()); - std::uniform_real_distribution dist((T)0.0, (T)1.0); - for (int i = 0; i < M.size(); i++) M[i] = static_cast(dist(rng)); + + T start, end; + + if constexpr (std::is_same_v || std::is_same_v + || std::is_same_v || std::is_same_v) { + start = (T)0.0; + end = (T)1.0; + } else if constexpr (std::is_same_v) { + start = (T)(-5); + end = (T)5; + } else if constexpr (std::is_same_v) { + start = (T)0; + end = (T)5; + } else { + CUTE_STATIC_ASSERT(false, "you must set correct start/end values to initialize data"); + } + + std::uniform_real_distribution dist((T)start, (T)end); + for (int i = 0; i < M.size(); i++) + M[i] = static_cast(dist(rng)); } template void run(uint32_t m, uint32_t n, uint32_t k) { @@ -113,7 +141,8 @@ template void run(uint32_t m, uint32_t n, uint32_t k) { d_A.data(), d_B.data(), d_C.data(), m, n, k); syclcompat::wait(); + h_C = d_C; - verify(m, n, k, h_A.data(), h_B.data(), h_C.data(), - kernel::is_a_row_major, kernel::is_b_row_major); + verify(m, n, k, h_A.data(), h_B.data(), h_C.data(), kernel::is_a_row_major, + kernel::is_b_row_major); }