Coord refactor #186

Open · wants to merge 11 commits into base: sycl-develop
Conversation

@t4c1 (Collaborator) commented Jan 17, 2025

Refactor coordinates for PVC copies to be consistent with how copies for all CUDA GPUs are called.

@t4c1 marked this pull request as ready for review on January 29, 2025, 09:38
@joeatodd (Collaborator) left a comment
Nice work @t4c1 - a few small things I spotted.

auto
get_pvc_tensor(GShape const& g_shape) const {
  static_assert(rank(GShape{}) == 3, "mismatch rank");
  return make_counting_tensor(make_layout(g_shape, make_stride(E<0>(), E<1>(), E<2>())));
}
get_tma_tensor uses g_stride_ for the 2nd arg to make_layout here. Is there any loss of generality with this simpler approach?
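For comparison, the TMA version in upstream CuTe is roughly the following (paraphrased from memory, so treat the exact signature as an assumption):

    template <class GShape>
    CUTE_HOST_DEVICE constexpr auto
    get_tma_tensor(GShape const& g_shape) const {
      // g_stride_ is a member of the traits, so permuted or padded
      // layouts carry through into the counting tensor
      return make_counting_tensor(make_layout(g_shape, g_stride_));
    }

With the fixed make_stride(E<0>(), E<1>(), E<2>()) every mode gets a unit basis stride, which is where generality could be lost if a non-identity stride were ever needed.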

Comment on lines +170 to +171
constexpr int dtype_size = sizeof(dtype);
constexpr int bits_in_byte = 8;
Cutlass provides cutlass::sizeof_bits<dtype> for this
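A minimal sketch of that replacement, assuming dtype still names the element type in scope:

    #include "cutlass/numeric_types.h"  // provides cutlass::sizeof_bits<T>

    // One constant instead of dtype_size * bits_in_byte; this also stays
    // correct for sub-byte types such as cutlass::int4b_t.
    constexpr int dtype_bits = cutlass::sizeof_bits<dtype>::value;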

Comment on lines +320 to +323
static_assert(is_rmem<TS>::value);
static_assert(size(SLayout{}) * dtype_size * bits_in_byte == size<1>(typename Traits_ST_t::SrcLayout{}),
"Src tensor size does not match copy atom size");
static_assert(size(DLayout{}) * dtype_size * bits_in_byte == size<1>(typename Traits_ST_t::DstLayout{}),
As above, use cutlass::sizeof_bits<dtype> I think.
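Sketched on the asserts above (names are taken from the diff context; the second message string is cut off in the excerpt, so it is assumed symmetric to the first):

    static_assert(size(SLayout{}) * cutlass::sizeof_bits<dtype>::value
                      == size<1>(typename Traits_ST_t::SrcLayout{}),
                  "Src tensor size does not match copy atom size");
    static_assert(size(DLayout{}) * cutlass::sizeof_bits<dtype>::value
                      == size<1>(typename Traits_ST_t::DstLayout{}),
                  "Dst tensor size does not match copy atom size");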

@@ -137,12 +137,31 @@ struct CollectiveMma<
using traits_load_B = Copy_Traits<GmemTiledCopyB, StrideB>;
using atom_load_B = Copy_Atom<traits_load_B, ElementB>;
I think the changes from this file need to be copied over to xe_mma_mixed_input.hpp. I am getting local failures from ninja test_unit_gemm_device.

Comment on lines 151 to 152
using TensorMKL = decltype(make_tensor(make_gmem_ptr(static_cast<ElementA const*>(nullptr)), make_shape(0,0,0), StrideA{})); //(m, k)
using TensorNKL = decltype(make_tensor(make_gmem_ptr(static_cast<ElementB const*>(nullptr)), make_shape(0,0,0), StrideB{})); //(n, k)
Unused


// Instantiate the MMA object and get thread slice
TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_slice(thread_idx);
// To make all threads in a warp have the same global tensors pass in the index of thread 0 in each warp
Can we have a TODO(Codeplay): here to fix this later?

Comment on lines +256 to +261
Tensor tArA = thr_copy_A2.retile_D(tCrA);
Tensor tBrB = thr_copy_B2.retile_D(tCrB);

// Retile global tile for copies
Tensor tAgA = thr_copy_A2.retile_S(tCgA);
Tensor tBgB = thr_copy_B2.retile_S(tCgB);
retile_D and retile_S do the same thing, by the way. Not sure if that affects what's going on here, but I don't think I've seen both used anywhere before.
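For reference, both helpers in cute/atom/copy_atom.hpp appear to forward to the same TiledCopy::retile (paraphrased, so treat as an approximation):

    template <class STensor>
    CUTE_HOST_DEVICE auto retile_S(STensor&& s) {
      return make_tensor(static_cast<STensor&&>(s).data(), TiledCopy::retile(s.layout()));
    }
    template <class DTensor>
    CUTE_HOST_DEVICE auto retile_D(DTensor&& d) {
      return make_tensor(static_cast<DTensor&&>(d).data(), TiledCopy::retile(d.layout()));
    }

so the S/D split documents intent rather than a behavioral difference.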

Comment on lines +325 to +328
Tensor g_cta_D_mnl = local_tile(mD_mnl, CtaTileMNK{}, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l)

// Slice to get the tile this CTA is responsible for // (BLK_M,BLK_N)
Tensor g_cta_D = g_cta_D_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N)
I am wondering here if it should be possible to avoid this and have something like:

 Tensor g_cta_D_mnl  = local_tile(mD_mnl, CtaTileMNK{}, make_coord(m_coord,n_coord,l_coord), Step<_1,_1, X>{}); 

Comment on lines +331 to +334
Tensor gD_mnl = local_tile(g_cta_D, SubgroupTileShape{}, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l)

// Slice to get the tile this warp is responsible for
Tensor gD = gD_mnl(_,_,m_sg,n_sg); // (BLK_M,BLK_N)
Same here


// Instantiate the MMA object and get thread slice
TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_slice(thread_idx);
// To make all threads in a warp have the same global tensors pass in the index of thread 0 in each warp
@mehdi-goli (Collaborator) commented Jan 29, 2025
Suggested change:
- // To make all threads in a warp have the same global tensors pass in the index of thread 0 in each warp
+ // To make all work items in a subgroup have the same global tensors pass in the index of work item 0 in each subgroup

Comment on lines +1027 to +1028
using SrcLayout = Layout<Shape <_16,Shape <_16, _2, _32>>,
Stride<_0,Stride< _1,_256,_512>>>;
the formatting is a bit messy here

@@ -310,12 +317,27 @@ class CollectiveEpilogue<
auto sg_coord = make_coord(sg_m_coord, sg_n_coord, k_coord, l_coord);

bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed();

// Represent the full output tensor
Tensor mD_mnl = params.xe_store_d.get_pvc_tensor(make_shape(M,N,L));
This should be a counting tensor of D, I believe, so maybe cD would be more appropriate? (I'm not really sure.)
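A sketch of that rename, assuming the CuTe convention that a c prefix marks coordinate (counting) tensors rather than data tensors:

    // get_pvc_tensor builds a counting tensor, so cD_mnl would mirror
    // the cD naming used in the CUDA epilogues:
    Tensor cD_mnl = params.xe_store_d.get_pvc_tensor(make_shape(M,N,L)); // (m,n,l)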

@@ -238,6 +244,15 @@ struct XE_2D_LD_Unpack {
make_layout(t_shape, t_stride));
}

// Generate the PVC coord tensor
template <class GShape>
This seems unrelated to the class it's in. Maybe it shouldn't be part of the copy traits?
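One option, sketched as a hypothetical free function make_pvc_coord_tensor (name invented here) with the same body, just hoisted out of the traits class:

    template <class GShape>
    CUTE_HOST_DEVICE constexpr auto
    make_pvc_coord_tensor(GShape const& g_shape) {
      static_assert(rank(GShape{}) == 3, "mismatch rank");
      return make_counting_tensor(
          make_layout(g_shape, make_stride(E<0>(), E<1>(), E<2>())));
    }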

Comment on lines +324 to +328
// Tile the output tensor per CTA
Tensor g_cta_D_mnl = local_tile(mD_mnl, CtaTileMNK{}, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l)

// Slice to get the tile this CTA is responsible for // (BLK_M,BLK_N)
Tensor g_cta_D = g_cta_D_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N)
Suggested change:
- // Tile the output tensor per CTA
- Tensor g_cta_D_mnl = local_tile(mD_mnl, CtaTileMNK{}, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l)
- // Slice to get the tile this CTA is responsible for // (BLK_M,BLK_N)
- Tensor g_cta_D = g_cta_D_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N)
+ // Tile the output tensor per CTA
+ Tensor g_cta_D = local_tile(mD_mnl, take<0,2>(CtaTileMNK{}), make_coord(m_coord,n_coord,l_coord)); // (BLK_M,BLK_N)

I think this is simpler.
Maybe it should be cta_cD?

Comment on lines +330 to +334
// Tile the output tensor per warp
Tensor gD_mnl = local_tile(g_cta_D, SubgroupTileShape{}, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l)

// Slice to get the tile this warp is responsible for
Tensor gD = gD_mnl(_,_,m_sg,n_sg); // (BLK_M,BLK_N)
Suggested change:
- // Tile the output tensor per warp
- Tensor gD_mnl = local_tile(g_cta_D, SubgroupTileShape{}, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l)
- // Slice to get the tile this warp is responsible for
- Tensor gD = gD_mnl(_,_,m_sg,n_sg); // (BLK_M,BLK_N)
+ // Tile the output tensor per warp
+ Tensor gD = local_tile(g_cta_D, SubgroupTileShape{}, make_coord(m_sg,n_sg)); // (SG_M, SG_N)

I think this is correct too

Comment on lines +252 to +257
auto gA_mk = local_tile(mA_mk, blk_shape, make_coord(_,_,_), Step<_1, X, _1>{});
auto gB_nk = local_tile(mB_nk, blk_shape, make_coord(_,_,_), Step< X, _1, _1>{});

// Slice with m_coord and n_coord
Tensor gA = gA_mk(_,_,m_coord,_); // (BLK_M,BLK_K,k)
Tensor gB = gB_nk(_,_,n_coord,_); // (BLK_N,BLK_K,k)
Suggested change:
- auto gA_mk = local_tile(mA_mk, blk_shape, make_coord(_,_,_), Step<_1, X, _1>{});
- auto gB_nk = local_tile(mB_nk, blk_shape, make_coord(_,_,_), Step< X, _1, _1>{});
- // Slice with m_coord and n_coord
- Tensor gA = gA_mk(_,_,m_coord,_); // (BLK_M,BLK_K,k)
- Tensor gB = gB_nk(_,_,n_coord,_); // (BLK_N,BLK_K,k)
+ auto gA = local_tile(mA_mk, blk_shape, make_coord(m_coord,_,_), Step<_1, X, _1>{}); // (BLK_M,BLK_K,k)
+ auto gB = local_tile(mB_nk, blk_shape, make_coord(_,n_coord,_), Step< X, _1, _1>{}); // (BLK_N,BLK_K,k)

Comment on lines +252 to +253
Tensor tCrA = make_tensor<ElementA>(tCgA(_,_,_,0).shape());
Tensor tCrB = make_tensor<ElementB>(tCgB(_,_,_,0).shape(), make_stride(_1{}, shape<0>(tCgB) * shape<2>(tCgB), shape<0>(tCgB)));
This line also does not seem to match what you are aiming to do.

Comment on lines +249 to +257
Tensor mA_mk = mA_mkl(_,_,l_coord); // (m,k)
Tensor mB_nk = mB_nkl(_,_,l_coord); // (n,k)

auto gA_mk = local_tile(mA_mk, blk_shape, make_coord(_,_,_), Step<_1, X, _1>{});
auto gB_nk = local_tile(mB_nk, blk_shape, make_coord(_,_,_), Step< X, _1, _1>{});

// Slice with m_coord and n_coord
Tensor gA = gA_mk(_,_,m_coord,_); // (BLK_M,BLK_K,k)
Tensor gB = gB_nk(_,_,n_coord,_); // (BLK_N,BLK_K,k)
@mehdi-goli (Collaborator) commented Jan 29, 2025
Same here. I think it should be possible to say:

Tensor gA = local_tile(mA_mkl, blk_shape, make_coord(m_coord,_,l_coord), Step<_1,  X, _1>{});                                          
Tensor gB = local_tile(mB_nkl, blk_shape, make_coord(n_coord,_,l_coord), Step< X, _1, _1>{});
