diff --git a/softmax/softmax.cu b/softmax/softmax.cu
index a1096aab..eadf7c6b 100644
--- a/softmax/softmax.cu
+++ b/softmax/softmax.cu
@@ -16,8 +16,35 @@
 #define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
 #define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])
 #define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])
+// (m, d) running max/denominator pair required for Online Softmax
+struct __align__(8) MD
+{
+  float m; // running max
+  float d; // running denominator (sum of exponentials rescaled to m)
+};
 
 // -------------------------------------- FP32 --------------------------------------
+// Warp Reduce for Online Softmax
+template<const int kWarpSize = WARP_SIZE>
+__device__ __forceinline__ MD warp_reduce_md_op(MD value) {
+  unsigned int mask = 0xffffffff;
+  #pragma unroll
+  for (int stride = kWarpSize >> 1; stride >= 1; stride >>= 1) {
+    MD other;
+    other.m = __shfl_xor_sync(mask, value.m, stride);
+    other.d = __shfl_xor_sync(mask, value.d, stride);
+
+    bool value_bigger = (value.m > other.m);
+    MD bigger_m = value_bigger ? value : other;
+    MD smaller_m = value_bigger ? other : value;
+
+    // keep the larger max; rescale the smaller partial's denominator onto it
+    value.d = bigger_m.d + smaller_m.d * __expf(smaller_m.m - bigger_m.m);
+    value.m = bigger_m.m;
+  }
+  return value;
+}
+
 // Warp Reduce Sum
 template<const int kWarpSize = WARP_SIZE>
 __device__ __forceinline__ float warp_reduce_sum_f32(float val) {
@@ -289,6 +316,40 @@ __global__ void safe_softmax_f16x8_pack_f32_per_token_kernel(half* x, half* y, i
   // TODO: support non 8-multiple K here
 }
 
+template<const int NUM_THREADS = 256>
+__global__ void online_softmax_f32_per_token_kernel(const float* x, float* y, int N) {
+
+  int local_tid = threadIdx.x;
+  int global_tid = blockIdx.x * NUM_THREADS + threadIdx.x;
+  const int WARP_NUM = NUM_THREADS / WARP_SIZE;
+  int warp_id = local_tid / WARP_SIZE;
+  int lane_id = local_tid % WARP_SIZE;
+  MD val;
+  val.m = global_tid < N ? x[global_tid] : -FLT_MAX; // per-thread partial max
+  val.d = global_tid < N ? 1.0f : 0.0f;              // per-thread partial denominator
+
+  __shared__ MD shared[WARP_NUM];
+  MD res = warp_reduce_md_op<WARP_SIZE>(val);
+
+  if (lane_id == 0) shared[warp_id] = res;
+  __syncthreads();
+
+  if (local_tid < WARP_SIZE) {
+    MD block_res = (local_tid < WARP_NUM) ? shared[local_tid] : MD{-FLT_MAX, 0.0f}; // only WARP_NUM partials are valid
+    block_res = warp_reduce_md_op<WARP_NUM>(block_res);
+    if (local_tid == 0) {
+      shared[0] = block_res;
+    }
+  }
+  __syncthreads();
+
+  MD final_res = shared[0];
+  float d_total_inverse = __fdividef(1.0f, final_res.d);
+  if (global_tid < N) {
+    y[global_tid] = __expf(x[global_tid] - final_res.m) * d_total_inverse;
+  }
+}
+
 // --------------------- PyTorch bindings for custom kernel -----------------------
 #define STRINGFY(str) #str
 #define TORCH_BINDING_COMMON_EXTENSION(func) \
@@ -440,6 +501,41 @@ safe_softmax_f32_per_token_kernel<(H)><<<grid, block>>>( \
       break;                                               \
   }
 
+// online softmax per token
+#define LANUCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(H)        \
+online_softmax_f32_per_token_kernel<(H)><<<grid, block>>>(   \
+      reinterpret_cast<float*>(x.data_ptr()),                \
+      reinterpret_cast<float*>(y.data_ptr()),                \
+      N);
+
+#define DISPATCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(S, H)   \
+  dim3 block((H));                                           \
+  dim3 grid((S));                                            \
+  switch ((H))                                               \
+  {                                                          \
+  case 32:                                                   \
+    LANUCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(32)           \
+    break;                                                   \
+  case 64:                                                   \
+    LANUCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(64)           \
+    break;                                                   \
+  case 128:                                                  \
+    LANUCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(128)          \
+    break;                                                   \
+  case 256:                                                  \
+    LANUCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(256)          \
+    break;                                                   \
+  case 512:                                                  \
+    LANUCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(512)          \
+    break;                                                   \
+  case 1024:                                                 \
+    LANUCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(1024)         \
+    break;                                                   \
+  default:                                                   \
+    throw std::runtime_error(                                \
+      "only support H: 32/64/128/256/512/1024");             \
+    break;                                                   \
+  }
 #define LANUCH_SAFE_SOFTMAX_F32x4_PER_TOKEN_KERNEL(H)        \
 safe_softmax_f32x4_per_token_kernel<(H)/4><<<                \
       grid, block>>>(                                        \
@@ -674,6 +770,16 @@ void safe_softmax_f16x8_pack_f32_per_token(torch::Tensor x, torch::Tensor y) {
   DISPATCH_SATE_SOFTMAX_F16x8_PACK_F32_PER_TOKEN_KERNEL(S, H)
 }
 
+void online_softmax_f32_per_token(torch::Tensor x, torch::Tensor y) {
+  CHECK_TORCH_TENSOR_DTYPE(x, torch::kFloat32)
+  CHECK_TORCH_TENSOR_DTYPE(y, torch::kFloat32)
+  CHECK_TORCH_TENSOR_SHAPE(x, y)
+  const int S = x.size(0); // seqlens
+  const int H = x.size(1); // head size/kv_len
+  const int N = S * H;
+  DISPATCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(S, H)
+}
+
 // grid memory fence fp32
 TORCH_BINDING_SOFTMAX(f32, torch::kFloat32, float, 1)
 TORCH_BINDING_SOFTMAX(f32x4, torch::kFloat32, float, 4)
@@ -688,4 +794,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   TORCH_BINDING_COMMON_EXTENSION(safe_softmax_f16_f32_per_token)
   TORCH_BINDING_COMMON_EXTENSION(safe_softmax_f16x2_f32_per_token)
   TORCH_BINDING_COMMON_EXTENSION(safe_softmax_f16x8_pack_f32_per_token)
+  TORCH_BINDING_COMMON_EXTENSION(online_softmax_f32_per_token)
 }
diff --git a/softmax/softmax.py b/softmax/softmax.py
index b2769161..de744a5c 100644
--- a/softmax/softmax.py
+++ b/softmax/softmax.py
@@ -77,6 +77,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor,
 run_benchmark(lib.softmax_f32x4_per_token, x, "f32x4(per)", out)
 run_benchmark(lib.safe_softmax_f32_per_token, x, "f32(safe)", out)
 run_benchmark(lib.safe_softmax_f32x4_per_token, x, "f32x4(safe)", out)
+run_benchmark(lib.online_softmax_f32_per_token, x, "f32(online)", out)
 run_benchmark(partial(torch.softmax, dim=1, out=out), x, "f32_th(per)")
 print("-" * 100)
 
@@ -99,6 +100,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor,
 run_benchmark(lib.softmax_f32x4_per_token, x, "f32x4(per)", out)
 run_benchmark(lib.safe_softmax_f32_per_token, x, "f32(safe)", out)
 run_benchmark(lib.safe_softmax_f32x4_per_token, x, "f32x4(safe)", out)
+run_benchmark(lib.online_softmax_f32_per_token, x, "f32(online)", out)
 run_benchmark(partial(torch.softmax, dim=1, out=out), x, "f32_th(per)")
 print("-" * 100)
 
@@ -121,6 +123,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor,
 run_benchmark(lib.softmax_f32x4_per_token, x, "f32x4(per)", out)
 run_benchmark(lib.safe_softmax_f32_per_token, x, "f32(safe)", out)
 run_benchmark(lib.safe_softmax_f32x4_per_token, x, "f32x4(safe)", out)
+run_benchmark(lib.online_softmax_f32_per_token, x, "f32(online)", out)
 run_benchmark(partial(torch.softmax, dim=1, out=out), x, "f32_th(per)")
 print("-" * 100)
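
Reviewer note: below is a minimal host-side sketch (not part of the diff) of the (m, d) merge that `warp_reduce_md_op` performs in parallel, checked against `torch.softmax`. The names `md_merge` and `online_softmax_ref` are hypothetical and exist only for this illustration.

```python
import math
from functools import reduce

import torch


def md_merge(a, b):
    # Hypothetical helper: combine two (max, denominator) partials with the
    # same rule as warp_reduce_md_op -- keep the larger max and rescale the
    # other partial's denominator by exp(smaller_max - larger_max).
    (m1, d1), (m2, d2) = a, b
    if m1 >= m2:
        return m1, d1 + d2 * math.exp(m2 - m1)
    return m2, d2 + d1 * math.exp(m1 - m2)


def online_softmax_ref(x: torch.Tensor) -> torch.Tensor:
    # Per-row softmax built by folding (x_i, 1.0) pairs with md_merge,
    # mirroring how each thread starts from (x[global_tid], 1.0f).
    y = torch.empty_like(x)
    for s in range(x.size(0)):
        m, d = reduce(md_merge, [(v, 1.0) for v in x[s].tolist()])
        y[s] = torch.exp(x[s] - m) / d
    return y


if __name__ == "__main__":
    x = torch.randn(4, 256)
    torch.testing.assert_close(online_softmax_ref(x), torch.softmax(x, dim=1),
                               rtol=1e-4, atol=1e-5)
```

Since the merge is associative and commutative (up to floating-point rounding), the warp/block shuffle reduction in the kernel produces the same (m, d) as this sequential fold, which is what lets the kernel compute softmax in a single pass over the row.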