This is a Julia implementation of the Flash Attention algorithm. Example usage:
```julia
using FlashAttention, CUDA

# Random Float16 inputs on the GPU.
Q = CUDA.randn(Float16, 64, 1024, 48, 3);
K = CUDA.randn(Float16, 64, 1024, 48, 3);
V = CUDA.randn(Float16, 64, 1024, 48, 3);

flash_attention(Q, K, V)
```
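
The kernel computes standard multi-head scaled dot-product attention. The sketch below is a minimal CPU baseline for comparison, assuming a `(head_dim, seqlen, nheads, batch)` layout for the arrays above and the conventional `1/sqrt(d)` softmax scaling (both are assumptions, not guarantees of the package); `naive_attention` is a hypothetical helper, not part of FlashAttention.jl.

```julia
# Minimal CPU reference for scaled dot-product attention (sanity check only).
# Assumes inputs are laid out as (head_dim, seqlen, nheads, batch).
using NNlib: softmax

function naive_attention(Q, K, V)
    d = size(Q, 1)
    O = similar(Q)
    for b in axes(Q, 4), h in axes(Q, 3)
        q = Float32.(Q[:, :, h, b])                          # d × N queries for one head
        k = Float32.(K[:, :, h, b])
        v = Float32.(V[:, :, h, b])
        P = softmax((k' * q) ./ sqrt(Float32(d)); dims = 1)  # N × N attention weights
        O[:, :, h, b] = v * P                                # d × N output
    end
    return O
end

# e.g. compare against the GPU result (assuming the kernel applies the same scaling;
# a loose tolerance may be needed for Float16):
# Array(flash_attention(Q, K, V)) ≈ naive_attention(Array(Q), Array(K), Array(V))
```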
For profiling results, please refer to the file `flash_attention.ncu-rep` (an Nsight Compute report). This is not the fastest possible implementation, because:
- we do not use tensor cores as the C++ implementation does,
- CUDA.jl does not yet support asynchronous copies from global memory to shared memory, and
- this kernel's theoretical occupancy (12.5%) is limited by the amount of shared memory it requires; a rough estimate of this effect is sketched after the list.
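
To illustrate the last point, here is a back-of-the-envelope occupancy estimate. The per-block block size and shared-memory footprint below are hypothetical placeholders, not the kernel's measured values (those are in the Nsight Compute report); only the hardware limits are queried from the device.

```julia
using CUDA

dev = CUDA.device()
# Hardware limits queried from the current device.
smem_per_sm    = CUDA.attribute(dev, CUDA.DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR)
threads_per_sm = CUDA.attribute(dev, CUDA.DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR)

# Hypothetical per-block figures (placeholders, not the measured values).
smem_per_block    = 48 * 1024   # e.g. shared-memory tiles of Q, K, V in Float16
threads_per_block = 256

# If shared memory is the limiting resource, only this many blocks fit per SM,
# which bounds the fraction of the SM's threads that can be resident.
blocks_per_sm = smem_per_sm ÷ smem_per_block
occupancy    = blocks_per_sm * threads_per_block / threads_per_sm
println("theoretical occupancy ≈ ", round(100 * occupancy; digits = 1), "%")
```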
In the future, I plan to reimplement it using MoYe.jl to achieve competitive performance.