Coord refactor #186
base: sycl-develop

Conversation
Nice work @t4c1 - a few small things I spotted.
auto
get_pvc_tensor(GShape const& g_shape) const {
  static_assert(rank(GShape{}) == 3, "mismatch rank");
  return make_counting_tensor(make_layout(g_shape, make_stride(E<0>(), E<1>(), E<2>())));
get_tma_tensor uses g_stride_ for the 2nd arg to make_layout here. Is there any loss of generality with this simpler approach?
constexpr int dtype_size = sizeof(dtype);
constexpr int bits_in_byte = 8;
Cutlass provides cutlass::sizeof_bits<dtype> for this.
static_assert(is_rmem<TS>::value);
static_assert(size(SLayout{}) * dtype_size * bits_in_byte == size<1>(typename Traits_ST_t::SrcLayout{}),
              "Src tensor size does not match copy atom size");
static_assert(size(DLayout{}) * dtype_size * bits_in_byte == size<1>(typename Traits_ST_t::DstLayout{}),
As above, use cutlass::sizeof_bits<dtype>, I think.
@@ -137,12 +137,31 @@ struct CollectiveMma<
using traits_load_B = Copy_Traits<GmemTiledCopyB, StrideB>;
using atom_load_B = Copy_Atom<traits_load_B, ElementB>;
I think the changes from this file need to be copied over to xe_mma_mixed_input.hpp. I am getting local failures of ninja test_unit_gemm_device.
using TensorMKL = decltype(make_tensor(make_gmem_ptr(static_cast<ElementA const*>(nullptr)), make_shape(0,0,0), StrideA{})); // (m, k)
using TensorNKL = decltype(make_tensor(make_gmem_ptr(static_cast<ElementB const*>(nullptr)), make_shape(0,0,0), StrideB{})); // (n, k)
Unused
// Instantiate the MMA object and get thread slice
TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_slice(thread_idx);
// To make all threads in a warp have the same global tensors pass in the index of thread 0 in each warp
Can we have a TODO(Codeplay): here to fix this later?
Tensor tArA = thr_copy_A2.retile_D(tCrA);
Tensor tBrB = thr_copy_B2.retile_D(tCrB);

// Retile global tile for copies
Tensor tAgA = thr_copy_A2.retile_S(tCgA);
Tensor tBgB = thr_copy_B2.retile_S(tCgB);
retile_D and retile_S do the same thing, by the way. Not sure if that affects what's going on here, but I don't think I've seen both used anywhere before.
Tensor g_cta_D_mnl = local_tile(mD_mnl, CtaTileMNK{}, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l)

// Slice to get the tile this CTA is responsible for
Tensor g_cta_D = g_cta_D_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N)
I am wondering whether it should be possible to avoid this and have something like:
Tensor g_cta_D_mnl = local_tile(mD_mnl, CtaTileMNK{}, make_coord(m_coord,n_coord,l_coord), Step<_1,_1, X>{});
Tensor gD_mnl = local_tile(g_cta_D, SubgroupTileShape{}, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l)

// Slice to get the tile this warp is responsible for
Tensor gD = gD_mnl(_,_,m_sg,n_sg); // (BLK_M,BLK_N)
Same here
// Instantiate the MMA object and get thread slice
TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_slice(thread_idx);
// To make all threads in a warp have the same global tensors pass in the index of thread 0 in each warp
Suggested change:
- // To make all threads in a warp have the same global tensors pass in the index of thread 0 in each warp
+ // To make all work items in a subgroup have the same global tensors pass in the index of work item 0 in each subgroup
using SrcLayout = Layout<Shape <_16, Shape <_16,   _2,  _32>>,
                         Stride< _0, Stride< _1, _256, _512>>>;
The formatting is a bit messy here.
@@ -310,12 +317,27 @@ class CollectiveEpilogue<
auto sg_coord = make_coord(sg_m_coord, sg_n_coord, k_coord, l_coord);

bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed();

// Represent the full output tensor
Tensor mD_mnl = params.xe_store_d.get_pvc_tensor(make_shape(M,N,L));
This should be a counting tensor of D, I believe, so maybe cD would be more appropriate? (I'm not really sure.)
@@ -238,6 +244,15 @@ struct XE_2D_LD_Unpack {
make_layout(t_shape, t_stride));
}

// Generate the PVC coord tensor
template <class GShape>
This seems unrelated to the class it's in. Maybe it shouldn't be part of the copy traits?
// Tile the output tensor per CTA
Tensor g_cta_D_mnl = local_tile(mD_mnl, CtaTileMNK{}, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l)

// Slice to get the tile this CTA is responsible for
Tensor g_cta_D = g_cta_D_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N)
Suggested change:
- // Tile the output tensor per CTA
- Tensor g_cta_D_mnl = local_tile(mD_mnl, CtaTileMNK{}, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l)
- // Slice to get the tile this CTA is responsible for
- Tensor g_cta_D = g_cta_D_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N)
+ // Tile the output tensor per CTA
+ Tensor g_cta_D = local_tile(mD_mnl, take<0,2>(CtaTileMNK{}), make_coord(m_coord,n_coord,l_coord)); // (BLK_M,BLK_N)

I think this is simpler. Maybe it should be cta_cD?
// Tile the output tensor per warp
Tensor gD_mnl = local_tile(g_cta_D, SubgroupTileShape{}, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l)

// Slice to get the tile this warp is responsible for
Tensor gD = gD_mnl(_,_,m_sg,n_sg); // (BLK_M,BLK_N)
Suggested change:
- // Tile the output tensor per warp
- Tensor gD_mnl = local_tile(g_cta_D, SubgroupTileShape{}, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l)
- // Slice to get the tile this warp is responsible for
- Tensor gD = gD_mnl(_,_,m_sg,n_sg); // (BLK_M,BLK_N)
+ // Tile the output tensor per warp
+ Tensor gD = local_tile(g_cta_D, SubgroupTileShape{}, make_coord(m_sg,n_sg)); // (SG_M, SG_N)

I think this is correct too.
auto gA_mk = local_tile(mA_mk, blk_shape, make_coord(_,_,_), Step<_1, X, _1>{});
auto gB_nk = local_tile(mB_nk, blk_shape, make_coord(_,_,_), Step< X, _1, _1>{});

// Slice with m_coord and n_coord
Tensor gA = gA_mk(_,_,m_coord,_); // (BLK_M,BLK_K,k)
Tensor gB = gB_nk(_,_,n_coord,_); // (BLK_N,BLK_K,k)
Suggested change:
- auto gA_mk = local_tile(mA_mk, blk_shape, make_coord(_,_,_), Step<_1, X, _1>{});
- auto gB_nk = local_tile(mB_nk, blk_shape, make_coord(_,_,_), Step< X, _1, _1>{});
- // Slice with m_coord and n_coord
- Tensor gA = gA_mk(_,_,m_coord,_); // (BLK_M,BLK_K,k)
- Tensor gB = gB_nk(_,_,n_coord,_); // (BLK_N,BLK_K,k)
+ auto gA = local_tile(mA_mk, blk_shape, make_coord(m_coord,_,_), Step<_1, X, _1>{}); // (BLK_M,BLK_K,k)
+ auto gB = local_tile(mB_nk, blk_shape, make_coord(_,n_coord,_), Step< X, _1, _1>{}); // (BLK_N,BLK_K,k)
Tensor tCrA = make_tensor<ElementA>(tCgA(_,_,_,0).shape());
Tensor tCrB = make_tensor<ElementB>(tCgB(_,_,_,0).shape(), make_stride(_1{}, shape<0>(tCgB) * shape<2>(tCgB), shape<0>(tCgB)));
This line, too, does not seem to match what you are aiming to do.
Tensor mA_mk = mA_mkl(_,_,l_coord); // (m,k)
Tensor mB_nk = mB_nkl(_,_,l_coord); // (n,k)

auto gA_mk = local_tile(mA_mk, blk_shape, make_coord(_,_,_), Step<_1, X, _1>{});
auto gB_nk = local_tile(mB_nk, blk_shape, make_coord(_,_,_), Step< X, _1, _1>{});

// Slice with m_coord and n_coord
Tensor gA = gA_mk(_,_,m_coord,_); // (BLK_M,BLK_K,k)
Tensor gB = gB_nk(_,_,n_coord,_); // (BLK_N,BLK_K,k)
Same here; I think it should be possible to say:
Tensor gA = local_tile(mA_mkl, blk_shape, make_coord(m_coord,_,l_coord), Step<_1, X, _1>{});
Tensor gB = local_tile(mB_nkl, blk_shape, make_coord(n_coord,_,l_coord), Step< X, _1, _1>{});
Refactor coordinates for PVC copies to be consistent with how copies for all CUDA GPUs are called.