Fix verification for larger sizes with Flash Attention (#192)
Fixes the result verification for larger problem sizes in the Flash Attention example by computing the reference output one batch at a time, which avoids running out of device memory. Also removes FusionCallbacks from the code, since this implementation doesn't use any fusion::operators in the epilogue.
muhammad-tanvir-1211 authored Jan 22, 2025
1 parent 27d2b97 commit d6ab5aa
Showing 3 changed files with 97 additions and 110 deletions.
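
Note on the bug being fixed: the old verification materialized the intermediate score matrix S = Q x K^T for all batch * num_heads slices in one device allocation, which scales quadratically with sequence length. A rough footprint sketch with illustrative sizes (the values below are hypothetical, and ElementOutput is assumed to be a 4-byte float):

    #include <cstdio>

    int main() {
      long long batch_size = 32 * 16;  // batch * num_heads (example values)
      long long seq_len    = 4096;
      long long elem_bytes = 4;        // sizeof(ElementOutput), assumed float

      long long before = batch_size * seq_len * seq_len * elem_bytes;  // one block_S for everything
      long long after  = seq_len * seq_len * elem_bytes;               // per-batch block_S

      std::printf("block_S before: %.1f GiB\n", before / (1024.0 * 1024 * 1024));  // 32.0 GiB
      std::printf("block_S after:  %.1f MiB per iteration\n", after / (1024.0 * 1024));  // 64.0 MiB
      return 0;
    }

Batching the reference computation, as the first hunk below does, caps the working set at a single seq_len x seq_len slice regardless of batch * num_heads.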
174 changes: 86 additions & 88 deletions examples/sycl/pvc/flash_attention_v2/pvc_flash_attn.cpp
@@ -165,123 +165,128 @@ struct ExampleRunner {

    int batch_size = batch * num_heads;

-    cutlass::DeviceAllocation<ElementOutput> block_S;
-    block_S.reset(batch_size * seq_len * seq_len);
-
-    cutlass::TensorRef ref_Q(block_Q.get(), LayoutQ::packed({seq_len, head_size}));
-    cutlass::TensorRef ref_K(block_K.get(), LayoutK::packed({head_size, seq_len}));
-    cutlass::TensorRef ref_V(block_V.get(), LayoutV::packed({seq_len, head_size}));
-    cutlass::TensorRef ref_S(block_S.get(), LayoutQ::packed({seq_len, seq_len}));
-    cutlass::TensorRef ref_O(block_ref_O.get(), LayoutO::packed({seq_len, head_size}));
-
-    cutlass::reference::device::GemmComplex(
-        {seq_len, seq_len, head_size},
-        1.f,
-        ref_Q,
-        cutlass::ComplexTransform::kNone,
-        ref_K,
-        cutlass::ComplexTransform::kNone,
-        0.f,
-        ref_S,
-        ref_S,
-        ElementAccumulator(0),
-        batch_size,          // batch_count
-        seq_len * head_size, // batch_stride_Q
-        seq_len * head_size, // batch_stride_K
-        seq_len * seq_len,   // batch_stride_S
-        seq_len * seq_len    // batch_stride_S
-    );
+    // loop over the batch dimension to compute the output
+    // to avoid the risk of running out of device memory
+    for(int b = 0, offset = 0; b < batch_size; b++, offset += seq_len * head_size) {
+
+      cutlass::DeviceAllocation<ElementOutput> block_S;
+      block_S.reset(seq_len * seq_len);
+
+      cutlass::TensorRef ref_Q(block_Q.get() + offset, LayoutQ::packed({seq_len, head_size}));
+      cutlass::TensorRef ref_K(block_K.get() + offset, LayoutK::packed({head_size, seq_len}));
+      cutlass::TensorRef ref_V(block_V.get() + offset, LayoutV::packed({seq_len, head_size}));
+      cutlass::TensorRef ref_S(block_S.get(), LayoutQ::packed({seq_len, seq_len}));
+      cutlass::TensorRef ref_O(block_ref_O.get() + offset, LayoutO::packed({seq_len, head_size}));
+
+      cutlass::reference::device::GemmComplex(
+          {seq_len, seq_len, head_size},
+          1.f,
+          ref_Q,
+          cutlass::ComplexTransform::kNone,
+          ref_K,
+          cutlass::ComplexTransform::kNone,
+          0.f,
+          ref_S,
+          ref_S,
+          ElementAccumulator(0),
+          1,                   // batch_count
+          seq_len * head_size, // batch_stride_Q
+          seq_len * head_size, // batch_stride_K
+          seq_len * seq_len,   // batch_stride_S
+          seq_len * seq_len    // batch_stride_S
+      );

-    syclcompat::wait();
+      syclcompat::wait();

-    std::vector<ElementOutput> host_S(batch_size * seq_len * seq_len);
-    syclcompat::memcpy<ElementOutput>(host_S.data(), block_S.get(), host_S.size());
-    syclcompat::wait();
+      std::vector<ElementOutput> host_S(seq_len * seq_len);
+      syclcompat::memcpy<ElementOutput>(host_S.data(), block_S.get(), host_S.size());
+      syclcompat::wait();
+
+      // delete this memory as it is no longer needed
+      block_S.reset();

-    if(is_causal) {
-      // apply mask to S
-      for (int b = 0; b < batch_size; b++) {
+      if(is_causal) {
+        // apply mask to S
        for (int row = 0; row < seq_len; row++) {
          for (int col = 0; col < seq_len; col++) {
            if (col > row)
-              host_S[col + (b * seq_len + row) * seq_len] = -INFINITY;
+              host_S[col + row * seq_len] = -INFINITY;
          }
        }
      }
-    }

-    // compute max element per row of S
-    std::vector<ElementOutput> max_vec(batch_size * seq_len, -INFINITY);
-    for (int b = 0; b < batch_size; b++) {
+      // compute max element per row of S
+      std::vector<ElementOutput> max_vec(seq_len, -INFINITY);
      for (int row = 0; row < seq_len; row++) {
-        int idx = (b * seq_len + row) * seq_len;
-        int max_idx = b * seq_len + row;
+        int idx = row * seq_len;
+        int max_idx = row;
        max_vec[max_idx] = host_S[idx++];
        for (int col = 1; col < seq_len; col++, idx++) {
          if (max_vec[max_idx] < host_S[idx])
            max_vec[max_idx] = host_S[idx];
        }
      }
-    }

-    // compute exp of S
-    for (int b = 0; b < batch_size; b++) {
+      // compute exp of S
      for (int row = 0; row < seq_len; row++) {
-        int idx = (b * seq_len + row) * seq_len;
-        int max_idx = b * seq_len + row;
+        int idx = row * seq_len;
+        int max_idx = row;
        for (int col = 0; col < seq_len; col++, idx++) {
          host_S[idx] = expf((host_S[idx] - max_vec[max_idx]) / sqrt(static_cast<ElementOutput>((head_size))));
        }
      }
-    }

-    // compute sum per row of S
-    std::vector<ElementOutput> sum_vec(batch_size * seq_len, ElementOutput{0});
-    for (int b = 0; b < batch_size; b++) {
+      // compute sum per row of S
+      std::vector<ElementOutput> sum_vec(seq_len, ElementOutput{0});
      for (int row = 0; row < seq_len; row++) {
-        int idx = (b * seq_len + row) * seq_len;
-        int sum_idx = b * seq_len + row;
+        int idx = row * seq_len;
+        int sum_idx = row;
        for (int col = 0; col < seq_len; col++, idx++) {
          sum_vec[sum_idx] += host_S[idx];
        }

        //scale each row with the sum to compute softmax
-        idx = (b * seq_len + row) * seq_len;
-        sum_idx = b * seq_len + row;
+        idx = row * seq_len;
+        sum_idx = row;
        for (int col = 0; col < seq_len; col++, idx++) {
          host_S[idx] /= sum_vec[sum_idx];
        }
      }
-    }

-    std::vector<ElementV> host_P(host_S.size());
-    for(int p = 0; p < host_P.size(); p++) host_P[p] = static_cast<ElementV>(host_S[p]);
+      std::vector<ElementV> host_P(host_S.size());
+      for(int p = 0; p < host_P.size(); p++) host_P[p] = static_cast<ElementV>(host_S[p]);

-    cutlass::DeviceAllocation<ElementV> block_P;
-    block_P.reset(host_P.size());
+      cutlass::DeviceAllocation<ElementV> block_P;
+      block_P.reset(host_P.size());

-    syclcompat::memcpy<ElementV>(block_P.get(), host_P.data(), host_P.size());
-    syclcompat::wait();
+      syclcompat::memcpy<ElementV>(block_P.get(), host_P.data(), host_P.size());
+      syclcompat::wait();

-    cutlass::TensorRef ref_P(block_P.get(), LayoutQ::packed({seq_len, seq_len}));
+      cutlass::TensorRef ref_P(block_P.get(), LayoutQ::packed({seq_len, seq_len}));

-    cutlass::reference::device::GemmComplex(
-        {seq_len, head_size, seq_len},
-        1.f,
-        ref_P,
-        cutlass::ComplexTransform::kNone,
-        ref_V,
-        cutlass::ComplexTransform::kNone,
-        0.f,
-        ref_O,
-        ref_O,
-        ElementAccumulator(0),
-        batch_size,          // batch_count
-        seq_len * seq_len,   // batch_stride_P
-        seq_len * head_size, // batch_stride_V
-        seq_len * head_size, // batch_stride_O
-        seq_len * head_size  // batch_stride_O
-    );
+      cutlass::reference::device::GemmComplex(
+          {seq_len, head_size, seq_len},
+          1.f,
+          ref_P,
+          cutlass::ComplexTransform::kNone,
+          ref_V,
+          cutlass::ComplexTransform::kNone,
+          0.f,
+          ref_O,
+          ref_O,
+          ElementAccumulator(0),
+          1,                   // batch_count
+          seq_len * seq_len,   // batch_stride_P
+          seq_len * head_size, // batch_stride_V
+          seq_len * head_size, // batch_stride_O
+          seq_len * head_size  // batch_stride_O
+      );
+      syclcompat::wait();
+      // delete this memory as it is no longer needed
+      block_P.reset();
+
+    }

    syclcompat::wait();
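
For readers skimming the hunk above: the host-side passes compute a numerically stable softmax of S / sqrt(head_size) row by row, subtracting the row maximum before exponentiating so that expf never overflows; the result P then multiplies V to form the reference O. A condensed, self-contained sketch of the same per-row computation (names are illustrative, not from the file; exp and sum are fused into one pass here):

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Stable per-row softmax over a seq_len x seq_len score matrix, mirroring
    // the max / exp / sum / divide passes of the reference verification.
    void softmax_rows(std::vector<float>& S, int seq_len, int head_size) {
      const float inv_sqrt_d = 1.0f / std::sqrt(static_cast<float>(head_size));
      for (int row = 0; row < seq_len; row++) {
        float* r = S.data() + static_cast<std::size_t>(row) * seq_len;
        const float row_max = *std::max_element(r, r + seq_len);  // pass 1: row max
        float sum = 0.0f;
        for (int col = 0; col < seq_len; col++) {                 // pass 2: exp of scaled, shifted score
          r[col] = std::exp((r[col] - row_max) * inv_sqrt_d);
          sum += r[col];
        }
        for (int col = 0; col < seq_len; col++) r[col] /= sum;    // pass 3: normalize
      }
    }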

@@ -347,7 +352,7 @@ struct ExampleRunner {
problem_size,
{block_Q.get(), stride_Q, block_K.get(), stride_K, block_V.get(), stride_V},
{options.softmax_scale},
-      {{1}, block_O.get(), stride_O, block_lse.get(), stride_LSE},
+      {block_O.get(), stride_O, block_lse.get(), stride_LSE},
hw_info
};
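
The dropped leading {1} above was the brace-initializer for the fusion op's arguments (LinearCombination with alpha = 1). A minimal sketch of the post-commit Arguments aggregate that this initializer now fills, using hypothetical stand-in types; the trailing StrideLSE member is an assumption, since the epilogue diff excerpt below cuts off after ptr_LSE:

    #include <cstdint>

    // Hypothetical stand-ins, just to make the shape of the aggregate
    // visible; the real types come from the kernel instantiation.
    using ElementO   = float;
    using ElementLSE = float;
    struct StrideO   { std::int64_t s0, s1, s2; };
    struct StrideLSE { std::int64_t s0, s1, s2; };

    // Mirror of the post-commit epilogue Arguments: a plain aggregate of
    // output pointers and strides, with no fusion arguments in front.
    struct Arguments {
      ElementO const* ptr_O;
      StrideO dO;
      ElementLSE* ptr_LSE;
      StrideLSE dLSE;  // assumed member, elided in the diff excerpt
    };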

@@ -455,20 +460,13 @@ int main(int argc, const char** argv)
using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelPVC<PipelineStages>;
using EpilogueDispatchPolicy = cutlass::epilogue::IntelPVCEpilogue;

-  using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementOutput, ElementComputeEpilogue,
-          ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;
-
-  using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<EpilogueDispatchPolicy, EpilogueOp, TileShape,
-          decltype(tile_shape(TiledMma()))>;
-
using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogueAttention<
EpilogueDispatchPolicy,
TileShape,
ElementAccumulator,
cutlass::gemm::TagToStrideC_t<LayoutO>,
ElementOutput,
cutlass::gemm::TagToStrideC_t<LayoutLSE>,
-          FusionCallBacks,
XE_2D_U32x8x16_ST_N>;

if(options.is_causal) {
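
For context on what was removed: CUTLASS's LinearCombination epilogue op evaluates D = alpha * accumulator + beta * C per element. A scalar model of it follows; that the effective configuration here was alpha = 1, beta = 0 is an inference from the {1} argument the runner used to pass, not something the commit states:

    // Scalar model of cutlass::epilogue::fusion::LinearCombination.
    // With alpha = 1 and beta = 0 it reduces to a pass-through of the
    // accumulator, which is why the epilogue can store accumulators
    // directly and the callbacks machinery could be deleted.
    inline float linear_combination(float accumulator, float source,
                                    float alpha = 1.0f, float beta = 0.0f) {
      return alpha * accumulator + beta * source;
    }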
23 changes: 6 additions & 17 deletions examples/sycl/pvc/flash_attention_v2/pvc_flash_attn_epilogue.hpp
@@ -65,7 +65,6 @@ template <
class StrideO_,
class ElementLSE_,
class StrideLSE_,
-  class FusionCallbacks_,
class CopyOpO_
>
class CollectiveEpilogueAttention<
@@ -75,7 +74,6 @@ class CollectiveEpilogueAttention<
StrideO_,
ElementLSE_,
StrideLSE_,
-  FusionCallbacks_,
CopyOpO_
> {
public:
@@ -84,23 +82,21 @@
//
using DispatchPolicy = IntelPVCEpilogue;
using CtaTileMNK = CtaTileMNK_;
-  using FusionCallbacks = FusionCallbacks_;
using ElementO = ElementO_;
using ElementAccumulator = ElementO_;
using StrideO = StrideO_;
using ElementLSE = ElementLSE_;
using StrideLSE = StrideLSE_;
using CopyOpO = CopyOpO_;

-  using ThreadEpilogueOp = typename fusion::FusionCallbacksTraits<FusionCallbacks>::Operation;
using GmemTiledCopyO = CopyOpO;
-  using ElementOutput = typename FusionCallbacks::ElementOutput;
-  using ElementCompute = typename FusionCallbacks::ElementCompute;
+  using ElementOutput = ElementO_;
+  using ElementCompute = ElementO_;

static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize;

static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]");
-  static_assert(cute::rank(StrideO{}) == 3, "StrideO must be rank-4: [batch, num_heads, seq_len, head_size]");
+  static_assert(cute::rank(StrideO{}) == 3, "StrideO must be rank-3: [seq_len, head_size, batch * num_heads]");
static_assert(cute::rank(StrideLSE{}) == 3, "StrideLSE must be rank-3: [batch, num_heads, seq_len]");

using Trait_O = Copy_Traits<GmemTiledCopyO>;
@@ -114,10 +110,7 @@

using EmptyType = cute::tuple<>;

-  struct TensorStorageImpl: cute::tuple<EmptyType, EmptyType> {
-    using FusionStorage = typename FusionCallbacks::SharedStorage;
-    FusionStorage thread;
-  };
+  struct TensorStorageImpl: cute::tuple<EmptyType, EmptyType> {};

struct SharedStorage {
using TensorStorage = TensorStorageImpl;
@@ -128,7 +121,6 @@

// Host side epilogue arguments
struct Arguments {
-    typename FusionCallbacks::Arguments thread{};
ElementO const* ptr_O;
StrideO dO;
ElementLSE* ptr_LSE;
@@ -137,7 +129,6 @@

// Device side epilogue params
struct Params {
-    typename FusionCallbacks::Params thread{};
XE_Copy_O xe_store_o;
ElementLSE* ptr_LSE;
};
@@ -160,7 +151,6 @@
Layout<Shape<_1, Int<SubgroupSize>>>{});

return {
-      FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace),
xe_store_o,
args.ptr_LSE
};
@@ -188,8 +178,8 @@
}

CUTLASS_HOST_DEVICE
-  CollectiveEpilogueAttention(Params const& params_, TensorStorage const& shared_storage_)
-      : params(params_), fusion_callbacks(params_.thread, shared_storage_.thread) {}
+  CollectiveEpilogueAttention(Params const& params_, TensorStorage const&)
+      : params(params_) {}

template <
class ProblemShape,
@@ -282,7 +272,6 @@

private:
Params const& params;
-  FusionCallbacks fusion_callbacks;
};


10 changes: 5 additions & 5 deletions (third changed file; path not preserved in this capture)
@@ -45,8 +45,8 @@
inline x { assert(false); }
#endif

-SYCL_DEVICE_SPV_SPLIT_BARRIER(void __spirv_ControlBarrierArriveINTEL(int, int, int));
-SYCL_DEVICE_SPV_SPLIT_BARRIER(void __spirv_ControlBarrierWaitINTEL(int, int, int));
+SYCL_DEVICE_SPV_SPLIT_BARRIER(void __spirv_ControlBarrierArriveINTEL(int execution_scope, int memory_scope, int memory_semantics));
+SYCL_DEVICE_SPV_SPLIT_BARRIER(void __spirv_ControlBarrierWaitINTEL(int execution_scope, int memory_scope, int memory_semantics));

#undef SYCL_DEVICE_SPV_SPLIT_BARRIER
namespace cutlass::gemm::kernel {
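
The renamed parameters document the split-barrier contract: arrive signals that this work-item has reached the synchronization point, wait blocks until all work-items in the scope have arrived, and independent work can be overlapped in between. A usage sketch, assuming device-side compilation where the declarations above are visible, and using the SPIR-V enum values Scope::Workgroup = 2 and MemorySemantics::None = 0:

    inline void split_barrier_sketch() {
      constexpr int workgroup_scope = 2;  // SPIR-V Scope::Workgroup
      constexpr int semantics_none  = 0;  // SPIR-V MemorySemantics::None
      __spirv_ControlBarrierArriveINTEL(workgroup_scope, workgroup_scope, semantics_none);
      // ... work that does not read other work-items' pre-barrier writes ...
      __spirv_ControlBarrierWaitINTEL(workgroup_scope, workgroup_scope, semantics_none);
    }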
@@ -233,9 +233,9 @@ class GemmUniversalAttention
auto seq_len = get<2>(params.problem_shape);
auto head_size = get<3>(params.problem_shape);
// Preconditions
-    static_assert(cute::rank(StrideQ{}) == 3, "StrideQ must be rank-4: [batch, num_heads, seq_len, head_size].");
-    static_assert(cute::rank(StrideK{}) == 3, "StrideK must be rank-4: [batch, num_heads, seq_len, head_size].");
-    static_assert(cute::rank(StrideV{}) == 3, "StrideV must be rank-4: [batch, num_heads, seq_len, head_size].");
+    static_assert(cute::rank(StrideQ{}) == 3, "StrideQ must be rank-3: [seq_len, head_size, batch * num_heads].");
+    static_assert(cute::rank(StrideK{}) == 3, "StrideK must be rank-3: [head_size, seq_len, batch * num_heads].");
+    static_assert(cute::rank(StrideV{}) == 3, "StrideV must be rank-3: [seq_len, head_size, batch * num_heads].");

int thread_idx = int(ThreadIdxX());
int sub_group_id = thread_idx / SubgroupSize;
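
The corrected messages describe rank-3 strides over (rows, columns, batch * num_heads slices). As a concreteness check, a packed Q stride consistent with the per-slice offset of seq_len * head_size used by the verification loop might look as follows (a sketch only; the example actually derives its strides through helpers such as cutlass::gemm::TagToStrideC_t, and the include path is assumed):

    #include <cstdint>
    #include <cute/stride.hpp>  // cute::make_stride (assumed include path)

    // Packed rank-3 stride for Q of shape (seq_len, head_size, batch * num_heads),
    // row-major within one head, heads laid out back to back.
    inline auto make_packed_q_stride(std::int64_t seq_len, std::int64_t head_size) {
      return cute::make_stride(head_size,             // next row of the slice
                               std::int64_t{1},       // next element within a row
                               seq_len * head_size);  // next (batch, head) slice
    }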
