[WIP][AMD][Kernel][Quantization] Add fp8 and int8 support for Triton FAv2 kernel #12534

Draft · wants to merge 15 commits into base: main
49 changes: 49 additions & 0 deletions csrc/activation_kernels.cu
@@ -7,6 +7,10 @@
#include "cuda_compat.h"
#include "dispatch_utils.h"

#ifdef USE_ROCM
#include "quantization/fp8/amd/hip_float8.h"
#endif

namespace vllm {

template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
@@ -31,6 +35,24 @@ __global__ void act_and_mul_kernel(
}
}

// Scaled activation and gating kernel template.
#ifdef USE_ROCM
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void scaled_act_and_mul_kernel(
c10::Float8_e4m3fnuz* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d]
const int d, const float scale) {
const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
float r = ACT_FN(x) * y * scale;
out[token_idx * d + idx] = c10::Float8_e4m3fnuz(
hip_fp8(r).data, c10::Float8_e4m3fnuz::from_bits());
}
}
#endif
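
For reference, each output element produced by the kernel above is silu(x) * y, multiplied by the scale and then converted to fp8 via hip_fp8. A minimal host-side sketch of the same per-token math in plain C++ (illustration only; silu_ref mirrors silu_kernel below, and the final Float8_e4m3fnuz cast is left as a comment because it is device-specific):

#include <cmath>
#include <vector>

// x * sigmoid(x), matching silu_kernel below.
static float silu_ref(float x) { return x / (1.0f + std::exp(-x)); }

// One token: input holds [x_0..x_{d-1}, y_0..y_{d-1}]; out holds d values.
std::vector<float> scaled_silu_and_mul_ref(const std::vector<float>& input,
                                           int d, float inv_scale) {
  std::vector<float> out(d);
  for (int i = 0; i < d; ++i) {
    // The device kernel additionally casts this float down to Float8_e4m3fnuz.
    out[i] = silu_ref(input[i]) * input[d + i] * inv_scale;
  }
  return out;
}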

template <typename T>
__device__ __forceinline__ T silu_kernel(const T& x) {
// x * sigmoid(x)
@@ -79,6 +101,25 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
input.data_ptr<scalar_t>(), d); \
});

// Launch scaled activation and gating kernel (fp8 output, ROCm only).
#ifdef USE_ROCM
#define LAUNCH_SCALED_ACTIVATION_GATE_KERNEL(KERNEL) \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "scaled_act_and_mul_kernel", [&] { \
vllm::scaled_act_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
<<<grid, block, 0, stream>>>( \
out.data_ptr<c10::Float8_e4m3fnuz>(), \
input.data_ptr<scalar_t>(), d, \
1.0 / (*scale.data_ptr<float>())); \
});
#endif

void silu_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
@@ -93,6 +134,14 @@ void mul_and_silu(torch::Tensor& out, // [..., d]
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, false);
}

void scaled_silu_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input, // [..., 2 * d]
torch::Tensor& scale) {
#ifdef USE_ROCM
LAUNCH_SCALED_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
#endif
}
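
Note that the launch macro above passes 1.0 / scale to the kernel, so the scale argument is the fp8 dequantization scale and the kernel multiplies by its reciprocal before the down-cast; on non-ROCm builds the function is a no-op. A rough caller-side sketch, with the assumptions spelled out in comments:

#include <torch/torch.h>
#include "ops.h"  // declares scaled_silu_and_mul (see csrc/ops.h below)

// Illustration only. Assumes input is a ROCm tensor of shape
// [num_tokens, 2 * d] and scale is a one-element float tensor whose value the
// host can read (the launch macro dereferences scale.data_ptr<float>()).
torch::Tensor run_scaled_silu_and_mul(torch::Tensor& input,
                                      torch::Tensor& scale) {
  auto out_sizes = input.sizes().vec();
  out_sizes.back() /= 2;  // keep only d of the 2 * d input columns
  torch::Tensor out = torch::empty(
      out_sizes, input.options().dtype(at::ScalarType::Float8_e4m3fnuz));
  scaled_silu_and_mul(out, input, scale);
  return out;
}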

void gelu_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
2 changes: 1 addition & 1 deletion csrc/attention/paged_attention_v1.cu
@@ -193,4 +193,4 @@ void paged_attention_v1(
#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
#undef DIVIDE_ROUND_UP
2 changes: 1 addition & 1 deletion csrc/attention/paged_attention_v2.cu
@@ -203,4 +203,4 @@ void paged_attention_v2(
#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
#undef DIVIDE_ROUND_UP
3 changes: 3 additions & 0 deletions csrc/ops.h
@@ -90,6 +90,9 @@ void silu_and_mul(torch::Tensor& out, torch::Tensor& input);

void mul_and_silu(torch::Tensor& out, torch::Tensor& input);

void scaled_silu_and_mul(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& scale);

void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);

void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
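
The new declaration also has to be exposed through the PyTorch dispatcher; the PR's binding changes are not part of this excerpt. A hypothetical registration sketch in the style used for the existing activation ops (the _C namespace and the schema string are assumptions, not taken from the PR):

#include <torch/library.h>
#include "ops.h"

// Hypothetical fragment; the real schema and namespace live in the project's
// torch_bindings.cpp. torch::kCUDA also covers ROCm/HIP builds.
TORCH_LIBRARY_FRAGMENT(_C, ops) {
  ops.def("scaled_silu_and_mul(Tensor! out, Tensor input, Tensor scale) -> ()");
  ops.impl("scaled_silu_and_mul", torch::kCUDA, &scaled_silu_and_mul);
}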