Implement full feature of copy/gemm for PVC backend (#174)
* Implement full feature of copy/gemm for PVC backend

Implement Feature:
  1. Implement full features of copy/MMA for PVC backend
     Full copy/gemm functionality was not implemented before this commit because the CUTLASS CuTe copy/MMA API
     is not fully compatible with the PVC backend. The register layout loaded by the PVC subgroup intrinsics does
     not satisfy the cute::gemm requirements, which leads to problems including but not limited to:
       (1) GEMM only supports specific combinations of tile sizes and copy traits.
           GEMM results become wrong if the tile-size configuration or copy traits are changed.
           For example, "examples/sycl/pvc/pvc_gemm.cpp" fails if sg_tile_k is changed from 32 to 64.
           So the register data layout must be retiled before cute::gemm; see the mainloop sketch after this section.
       (2) The register layout had to be hardcoded to satisfy the requirements of the CUTLASS CuTe APIs.
           For example, the data produced by "partition_fragment_B" needed to be hardcoded.

  2. Support different GEMM layouts and data types
       (1) Support different combinations of RowMajor and ColumnMajor for matrices A and B.
           Refer to test/unit/cute/intel_xe/gemm_data_type.cpp.
       (2) Add GEMM test cases for int8/uint8/fp16/bf16. Refer to test/unit/cute/intel_xe/gemm_layout.cpp.

  This PR implements the above features without regressing performance.
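
  To make the retiling requirement concrete, here is a minimal sketch of the mainloop pattern described in
  item 1: load the fragments with the Xe tiled copies, retile the registers into the layout cute::gemm expects,
  then call cute::gemm. The helper name retile_MMA comes from the squashed commit titles below; the fragment
  names and the exact signature are illustrative assumptions, not the verbatim API.

    // Sketch only: assumes a CuTe mainloop already set up with Xe tiled copies (copy_a, copy_b),
    // a TiledMMA (tiled_mma) with its thread slice (thr_mma), partitioned global tensors
    // (tAgA, tBgB), register fragments (tArA, tBrB) and an accumulator fragment (tCrC).
    for (int k_tile = 0; k_tile < k_tile_count; ++k_tile) {
      copy(copy_a, tAgA(_, _, _, k_tile), tArA);   // Xe 2D block load of the A tile into registers
      copy(copy_b, tBgB(_, _, _, k_tile), tBrB);   // Xe 2D block load of the B tile into registers
      // Rearrange the loaded registers so their layout matches what cute::gemm expects;
      // without this step only specific tile-size/copy-trait combinations give correct results.
      auto tCrA = retile_MMA(thr_mma, tArA);       // helper name/signature assumed from the commit notes
      auto tCrB = retile_MMA(thr_mma, tBrB);
      cute::gemm(tiled_mma, tCrA, tCrB, tCrC);     // accumulate into tCrC
    }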

Refine Code
  1. Refine the layout convention for GEMM.
     For GEMM C = A x B, let A be (m, k, l), B be (n, k, l), and C be (m, n, l). Backend-specific differences
     are hidden inside the implementation of the PVC copy traits (copy_traits_xe.hpp), so upper-level users can
     write code for Intel Xe GPUs following the usual CUTLASS conventions without hardcoding anything Xe-specific.
     See the layout sketch after this list.
  2. Refine the API "get_pvc_tensor"
     Before this PR, K-slicing and the coordinate tensor were mixed together, which made the interface parameters
     unclear and difficult to understand. K-slicing is an MMA concern, while the coordinate tensor is only for copy;
     they are separate concepts and must stay functionally independent, so a helper function "append_pvc_tensor"
     is added. See the usage sketch after this list.
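
  As a concrete illustration of the layout convention in item 1, the sketch below builds the three global
  tensor views on the host side with CuTe. It is a minimal sketch, not the production collective setup; the
  function name and the choice of RowMajor A / ColumnMajor B / RowMajor C are illustrative assumptions.

    #include <cute/tensor.hpp>

    // Sketch: A is (m, k, l), B is (n, k, l), C is (m, n, l); switching RowMajor/ColumnMajor
    // only changes the strides, never the shapes the collective sees.
    template <class TA, class TB, class TC>
    void make_gemm_views(TA const* ptr_A, TB const* ptr_B, TC* ptr_C, int m, int n, int k, int l) {
      using namespace cute;
      auto A = make_tensor(make_gmem_ptr(ptr_A),   // RowMajor A: k is the contiguous mode
                           make_layout(make_shape(m, k, l), make_stride(k, Int<1>{}, m * k)));
      auto B = make_tensor(make_gmem_ptr(ptr_B),   // ColumnMajor B: k is the contiguous mode
                           make_layout(make_shape(n, k, l), make_stride(k, Int<1>{}, n * k)));
      auto C = make_tensor(make_gmem_ptr(ptr_C),   // RowMajor C: n is the contiguous mode
                           make_layout(make_shape(m, n, l), make_stride(n, Int<1>{}, m * n)));
      // The transposed vs. non-transposed Xe block loads implied by these layouts are selected
      // inside copy_traits_xe.hpp, so user code never hardcodes Xe-specific register layouts.
      (void)A; (void)B; (void)C;
    }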
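
  The usage pattern after this refinement, mirrored from the prefetch setup in this commit (the offset and
  count variables are illustrative placeholders):

    // get_pvc_tensor now builds only the per-tile coordinate tensor (a copy-only concern) ...
    Tensor coord_2d = tiled_copy.get_pvc_tensor(
        make_coord(m_offset, k_offset, l_coord),   // starting coordinate of this subgroup's tile
        make_shape(_1{}, _1{}, _1{}));             // per-copy iteration shape
    // ... and append_pvc_tensor adds the K-iteration mode (a loop/MMA concern) on top of it.
    // The <1> picks the mode the appended iteration advances along, here k_tile_count steps of BLK_K.
    Tensor coord_a = append_pvc_tensor<1>(coord_2d, k_tile_count, BLK_K);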

* misc refine

* Update copy_traits_xe.hpp

* use make_coord

* rename variable to make semantics clear

* enable tf32 gemm and some refactoring

* refine code

* fix some comments

* fix comments

* fix comments

* fix build error

* refine gemm, add retile_MMA API for xe

* fix build error

* fix flash atten build issue

* update

* fix flash attention validation issue

* update

* update

* update benchmark configurations

* fix xe visit bug

* fix prefetch issue

* fix validation issue of uint test

* Update test/unit/cute/intel_xe/utils.hpp

Co-authored-by: Joe Todd <[email protected]>

---------

Co-authored-by: Alejandro Acosta <[email protected]>
Co-authored-by: Joe Todd <[email protected]>
3 people authored Jan 17, 2025
1 parent d4558aa commit 0692435
Showing 32 changed files with 1,469 additions and 1,540 deletions.
36 changes: 36 additions & 0 deletions benchmarks/pvc/benchmarks.hpp
@@ -87,11 +87,44 @@ using PvcGemmBF16BF16FP32_RRR_5 = cutlass::gemm::device::GemmConfiguration<
XE_2D_U16x8x32_LD_N, XE_2D_U16x32x32_LD_V,
Scheduler::Gemm>;

using PvcGemmBF16BF16FP32_RCR_6 = cutlass::gemm::device::GemmConfiguration<
cutlass::arch::IntelPVC,
cutlass::bfloat16_t, cutlass::layout::RowMajor,
cutlass::bfloat16_t, cutlass::layout::ColumnMajor,
float, cutlass::layout::RowMajor,
float, Shape<_256, _256, _32>,
TiledMMA<MMAAtom, Layout<Shape<_8,_4,_1>>>,
XE_2D_U16x8x32_LD_N, XE_2D_U16x16x16_LD_T,
Scheduler::Gemm>;

using PvcGemmBF16BF16FP32_CRR_7 = cutlass::gemm::device::GemmConfiguration<
cutlass::arch::IntelPVC,
cutlass::bfloat16_t, cutlass::layout::ColumnMajor,
cutlass::bfloat16_t, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, Shape<_256, _256, _32>,
TiledMMA<MMAAtom, Layout<Shape<_8,_4,_1>>>,
XE_2D_U16x16x16_LD_T, XE_2D_U16x32x32_LD_V,
Scheduler::Gemm>;

using PvcGemmBF16BF16FP32_CCR_8 = cutlass::gemm::device::GemmConfiguration<
cutlass::arch::IntelPVC,
cutlass::bfloat16_t, cutlass::layout::ColumnMajor,
cutlass::bfloat16_t, cutlass::layout::ColumnMajor,
float, cutlass::layout::RowMajor,
float, Shape<_256, _256, _32>,
TiledMMA<MMAAtom, Layout<Shape<_8,_4,_1>>>,
XE_2D_U16x16x16_LD_T, XE_2D_U16x16x16_LD_T,
Scheduler::Gemm>;

CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_1);
CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_2);
CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_3);
CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_4);
CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_5);
CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RCR_6);
CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_CRR_7);
CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_CCR_8);

using PvcGemmBF16BF16FP32_StreamK_RRR_1 = cutlass::gemm::device::GemmConfiguration<
cutlass::arch::IntelPVC,
@@ -123,6 +156,9 @@ static void register_benchmarks() {
CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_3);
CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_4);
CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_5);
CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RCR_6);
CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_CRR_7);
CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_CCR_8);
CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_StreamK_RRR_1);
CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_SplitK_RRR_1);
}
3 changes: 3 additions & 0 deletions benchmarks/pvc/input.in
@@ -21,6 +21,9 @@ PvcGemmBF16BF16FP32_RRR_5 --bm_name=bf16_bf16_fp32 --l=4096 --m=8 --k=16384 --n=
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=4 --m=32768 --k=128 --n=4096
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=4 --m=32768 --k=4096 --n=128
PvcGemmBF16BF16FP32_RRR_3 --bm_name=bf16_bf16_fp32 --l=32 --m=4096 --k=4096 --n=128
PvcGemmBF16BF16FP32_RCR_6 --bm_name=bf16_bf16_fp32 --l=1 --m=4096 --k=4096 --n=4096
PvcGemmBF16BF16FP32_CRR_7 --bm_name=bf16_bf16_fp32 --l=1 --m=4096 --k=4096 --n=4096
PvcGemmBF16BF16FP32_CCR_8 --bm_name=bf16_bf16_fp32 --l=1 --m=4096 --k=4096 --n=4096

PvcGemmBF16BF16FP32_StreamK_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=8192 --n=8192
PvcGemmBF16BF16FP32_StreamK_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=8192 --n=32768
2 changes: 1 addition & 1 deletion examples/sycl/pvc/flash_attention_v2/pvc_flash_attn.cpp
@@ -299,7 +299,7 @@ struct ExampleRunner {

stride_Q = cutlass::make_cute_packed_stride(StrideQ{}, cute::make_shape(seq_len, head_size, batch * num_heads));
stride_K = cutlass::make_cute_packed_stride(StrideK{}, cute::make_shape(seq_len, head_size, batch * num_heads));
stride_V = cutlass::make_cute_packed_stride(StrideV{}, cute::make_shape(seq_len, head_size, batch * num_heads));
stride_V = cutlass::make_cute_packed_stride(StrideV{}, cute::make_shape(head_size, seq_len, batch * num_heads));
stride_O = cutlass::make_cute_packed_stride(StrideO{}, cute::make_shape(seq_len, head_size, batch * num_heads));
stride_LSE = cutlass::make_cute_packed_stride(StrideLSE{}, cute::make_shape(seq_len, 1, batch * num_heads));

23 changes: 9 additions & 14 deletions examples/sycl/pvc/flash_attention_v2/pvc_flash_attn_epilogue.hpp
@@ -104,11 +104,9 @@ class CollectiveEpilogueAttention<
static_assert(cute::rank(StrideLSE{}) == 3, "StrideLSE must be rank-3: [batch, num_heads, seq_len]");

using Trait_O = Copy_Traits<GmemTiledCopyO>;
using XE_Copy_O = decltype(make_tiled_copy(Copy_Atom<Trait_O, ElementO>{}
.with(static_cast<ElementO const*>(nullptr), int32_t(0), int32_t(0), int32_t(0)),
Layout<Shape<_1, Int<SubgroupSize>>>{},
make_layout(make_shape(get<0>(typename Trait_O::Shape_MN{}),
get<1>(typename Trait_O::Shape_MN{}) / Int<SubgroupSize>{}))));
using XE_Copy_O = decltype(make_xe_2d_copy(Copy_Atom<Copy_Traits<CopyOpO, StrideO>, ElementO>{}.with(
make_tensor(make_gmem_ptr(static_cast<ElementO const*>(nullptr)), make_layout(make_shape(0, 0, 0), StrideO{}))),
Layout<Shape<_1, Int<SubgroupSize>>>{}));
private:
constexpr static bool is_destination_supported = not cute::is_void_v<ElementO>;

@@ -157,11 +155,9 @@
auto [batch, num_heads, seq_len, head_size] = problem_shape;

XE_Copy_O xe_store_o = {};
xe_store_o = make_tiled_copy(Copy_Atom<Copy_Traits<CopyOpO>, ElementO>{}.with(
args.ptr_O, head_size, seq_len, head_size),
Layout<Shape<_1, Int<SubgroupSize>>>{},
make_layout(make_shape(get<0>(typename Trait_O::Shape_MN{}),
get<1>(typename Trait_O::Shape_MN{}) / Int<SubgroupSize>{})));
xe_store_o = make_xe_2d_copy(Copy_Atom<Copy_Traits<CopyOpO, StrideO>, ElementO>{}.with(
make_tensor(make_gmem_ptr(static_cast<ElementO const*>(args.ptr_O)), make_layout(make_shape(seq_len, head_size, batch * num_heads), args.dO))),
Layout<Shape<_1, Int<SubgroupSize>>>{});

return {
FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace),
@@ -263,11 +259,10 @@
auto [batch, num_heads, seq_len, head_size] = problem_shape;

Tensor tOi = params.xe_store_o.get_pvc_tensor(
make_coord(m_offset, n_offset, 0),
make_shape(_, Int<FragsM>{}, Int<FragsN>{}, batch * num_heads),
make_stride(Int<get<0>(MmaAtomShape{})>{}, Int<get<1>(MmaAtomShape{})>{}, _1{}));
make_coord(m_offset, n_offset, l_coord),
make_shape(_, Int<FragsM>{}, Int<FragsN>{}));

copy(params.xe_store_o, out, tOi(_,_,_,l_coord));
copy(params.xe_store_o, out, tOi);

const int lse_offset = m_offset + l_coord * seq_len;

@@ -294,37 +294,45 @@ class GemmUniversalAttention
const int item_id = thread_idx % SubgroupSize;
const int k_tile_count= head_size / get<1>(subgroup_shape);
//m, k
Tensor prefetch_iter_a = params.mainloop.gmem_prefetch_q.get_pvc_tensor(
Tensor prefetch_iter_2d_a = params.mainloop.gmem_prefetch_q.get_pvc_tensor(
make_coord(seq_coord + (((sub_group_id % ATOM_N) / get<1>(PrefetchQThrShape{}))* get<0>(PrefetchQTileSize{})), // iteration 0/M/Hight/vertical
(((sub_group_id % ATOM_N) % get<1>(PrefetchQThrShape{})) * get<1>(PrefetchQTileSize{})), // Iteration 1/K/Width/Horisontal
blk_l_coord),
append<4>(make_shape(_1{}, _1{}, _1{}), k_tile_count),
append<3>(make_shape(_, _), BLK_K), seq<0, 0, 1>{});
((sub_group_id % ATOM_N) % get<1>(PrefetchQThrShape{})) * get<1>(PrefetchQTileSize{}), // Iteration 1/K/Width/Horisontal
blk_l_coord),
make_shape(_1{}, _1{}, _1{}));
Tensor prefetch_iter_a = append_pvc_tensor<1>(prefetch_iter_2d_a, k_tile_count, BLK_K);
// append<4>(make_shape(_1{}, _1{}, _1{}), k_tile_count),
// append<3>(make_shape(_, _), BLK_K), seq<0, 0, 1>{});
// The Key point is 1 is horisontal and zero is vertical
// the iteration over K dimention of B matrix (head_size) should be :
auto iter_over_head_count = head_size / BLK_N;
// k, n
Tensor prefetch_iter_b = params.mainloop.gmem_prefetch_k.get_pvc_tensor(
Tensor prefetch_iter_2d_b = params.mainloop.gmem_prefetch_k.get_pvc_tensor(
make_coord(sub_group_id * get<0>(PrefetchKTileSize{}), // iteration 0/K/Hight/vertical
(sub_group_id % ATOM_N) * get<1>(PrefetchKTileSize{}), // iteration 1/N/W/Horisontal
blk_l_coord), // batch
// ?, ?, k, N swap k and n here to match cutlass
append<4>(make_shape(_1{}, _1{}, nblock_limit/*This is N*/), iter_over_head_count/* This is K*/), //(frag, iter_m, iter_n, iter_k)
make_shape(_1{}, _1{}, nblock_limit/*This is N*/));
// iter_over_head_count/* This is K*/), //(frag, iter_m, iter_n, iter_k)
// K, ?, N (The N should move along the N as get<0>(PrefetchKThrShape) load 32 each and we want 128 of N )
// The K should move along the dimmension of Block load as we lay 8x32 using the 8x1 shape for subgroups
// leading to load 64x32 of (K,N) per each prefetch (BLOCK_N SHows K DIM)
append<3>(make_shape(_, SG_N), BLK_N), seq<0, 1, 0>{}); // so 64 * iteration 0 (SG_N that is K which is vertical) and 32 * iteration 1 (N which is horisontal)
// append<3>(make_shape(_, SG_N), BLK_N), seq<0, 1, 0>{}); // so 64 * iteration 0 (SG_N that is K which is vertical) and 32 * iteration 1 (N which is horisontal)
// V is a transposed matrix, So here the Sequense length is consumed, it is transposed so the consumed dimension looks like B matrix
// Hence, the Head size is the fast moving dimention and horisontal and sequence length is vertical.
// The prefetch only move along the sequence lenth. Here we call sequence length K since it get consumed and head size N since it stay
Tensor prefetch_iter_v = params.mainloop.gmem_prefetch_v.get_pvc_tensor(

Tensor prefetch_iter_b = append_pvc_tensor<0>(prefetch_iter_2d_b, iter_over_head_count, BLK_N);

Tensor prefetch_iter_2d_v = params.mainloop.gmem_prefetch_v.get_pvc_tensor(
make_coord((sub_group_id / ATOM_N) * get<0>(PrefetchVTileSize{}), // iteration 0/K/Hight/vertical/ sequence lengh
head_size_coord, // iteration 1/N/W/Horisontal / Head size
head_size_coord, // iteration 1/N/W/Horisontal / Head size
blk_l_coord),
// We loop over the consuming dimension which is the iteration 0(N) here
append<4>(make_shape(_1{}, _1{}, _1{}), nblock_limit),
make_shape(_1{}, _1{}, _1{}));
// , nblock_limit),
// first one is to use the intrinsic along the vertical , Second one is N/M and third one is K
append<3>(make_shape(_, _), BLK_K), seq<0, 1, 0>{});
// append<3>(make_shape(_, _), BLK_K), seq<0, 1, 0>{});
Tensor prefetch_iter_v = append_pvc_tensor<0>(prefetch_iter_2d_v, nblock_limit, BLK_K);

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < k_tile_count; i++) {