Skip to content

Commit

Permalink
[Softmax] Add online softmax according to Nvidia Paper (#60)
Browse files Browse the repository at this point in the history
  • Loading branch information
bear-zd authored Oct 2, 2024
1 parent 3f5ace3 commit 5ae3c08
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 0 deletions.
107 changes: 107 additions & 0 deletions softmax/softmax.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,35 @@
#define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])
#define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])
// DS required for Online Softmax
struct __align__(8) MD
{
float m;
float d;
};

// -------------------------------------- FP32 --------------------------------------
// Warp Reduce for Online Softmax

template<const int kWarpSize = WARP_SIZE >
__device__ __forceinline__ MD warp_reduce_md_op(MD value) {
unsigned int mask = 0xffffffff;
#pragma unroll
for(int stride = kWarpSize >> 1; stride >= 1; stride >>= 1) {
MD other;
other.m = __shfl_xor_sync(mask, value.m, stride);
other.d = __shfl_xor_sync(mask, value.d, stride);

bool value_bigger = (value.m > other.m);
MD bigger_m = value_bigger ? value : other;
MD smaller_m = value_bigger ? other : value;

value.d = bigger_m.d + smaller_m.d * __expf(smaller_m.m - bigger_m.m);
value.m = bigger_m.m;
}
return value;
}

// Warp Reduce Sum
template<const int kWarpSize = WARP_SIZE>
__device__ __forceinline__ float warp_reduce_sum_f32(float val) {
Expand Down Expand Up @@ -289,6 +316,40 @@ __global__ void safe_softmax_f16x8_pack_f32_per_token_kernel(half* x, half* y, i
// TODO: support non 8-multiple K here
}

template<const int NUM_THREADS = 256 >
__global__ void online_softmax_f32_per_token_kernel(const float* x, float* y, int N) {

int local_tid = threadIdx.x;
int global_tid = blockIdx.x * NUM_THREADS + threadIdx.x;
const int WAPR_NUM = NUM_THREADS / WARP_SIZE;
int warp_id = local_tid / WARP_SIZE;
int lane_id = local_tid % WARP_SIZE;
MD val;
val.m = global_tid < N ? x[global_tid] : -FLT_MAX;
val.d = global_tid < N ? 1.0f : 0.0f;

__shared__ MD shared[ WAPR_NUM ];
MD res = warp_reduce_md_op<WARP_SIZE>(val);

if (lane_id == 0) shared[warp_id] = res;
__syncthreads();

if (local_tid < WARP_SIZE) {
MD block_res = shared[local_tid];
block_res = warp_reduce_md_op<WAPR_NUM>(block_res);
if (local_tid == 0) {
shared[0] = block_res;
}
}
__syncthreads();

MD final_res = shared[0];
float d_total_inverse = __fdividef(1.0f, final_res.d);
if (global_tid < N) {
y[global_tid] = __expf(x[global_tid] - final_res.m) * d_total_inverse;
}
}

// --------------------- PyTorch bindings for custom kernel -----------------------
#define STRINGFY(str) #str
#define TORCH_BINDING_COMMON_EXTENSION(func) \
Expand Down Expand Up @@ -440,6 +501,41 @@ safe_softmax_f32_per_token_kernel<(H)><<<grid, block>>>( \
break; \
}

// online softmax per token
#define LANUCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(H) \
online_softmax_f32_per_token_kernel<(H)><<<grid, block>>>( \
reinterpret_cast<float*>(x.data_ptr()), \
reinterpret_cast<float*>(y.data_ptr()), \
N);

#define DISPATCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(S, H) \
dim3 block((H)); \
dim3 grid((S)); \
switch ((H)) \
{ \
case 32: \
LANUCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(32) \
break; \
case 64: \
LANUCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(64) \
break; \
case 128: \
LANUCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(128) \
break; \
case 256: \
LANUCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(256) \
break; \
case 512: \
LANUCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(512) \
break; \
case 1024: \
LANUCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(1024) \
break; \
default: \
throw std::runtime_error( \
"only support H: 64/128/256/512/1024"); \
break; \
}
#define LANUCH_SAFE_SOFTMAX_F32x4_PER_TOKEN_KERNEL(H) \
safe_softmax_f32x4_per_token_kernel<(H)/4><<< \
grid, block>>>( \
Expand Down Expand Up @@ -674,6 +770,16 @@ void safe_softmax_f16x8_pack_f32_per_token(torch::Tensor x, torch::Tensor y) {
DISPATCH_SATE_SOFTMAX_F16x8_PACK_F32_PER_TOKEN_KERNEL(S, H)
}

void online_softmax_f32_per_token(torch::Tensor x, torch::Tensor y) {
CHECK_TORCH_TENSOR_DTYPE(x, torch::kFloat32)
CHECK_TORCH_TENSOR_DTYPE(y, torch::kFloat32)
CHECK_TORCH_TENSOR_SHAPE(x, y)
const int S = x.size(0); // seqlens
const int H = x.size(1); // head size/kv_len
const int N = S * H;
DISPATCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(S, H)
}

// grid memory fence fp32
TORCH_BINDING_SOFTMAX(f32, torch::kFloat32, float, 1)
TORCH_BINDING_SOFTMAX(f32x4, torch::kFloat32, float, 4)
Expand All @@ -688,4 +794,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
TORCH_BINDING_COMMON_EXTENSION(safe_softmax_f16_f32_per_token)
TORCH_BINDING_COMMON_EXTENSION(safe_softmax_f16x2_f32_per_token)
TORCH_BINDING_COMMON_EXTENSION(safe_softmax_f16x8_pack_f32_per_token)
TORCH_BINDING_COMMON_EXTENSION(online_softmax_f32_per_token)
}
3 changes: 3 additions & 0 deletions softmax/softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor,
run_benchmark(lib.softmax_f32x4_per_token, x, "f32x4(per)", out)
run_benchmark(lib.safe_softmax_f32_per_token, x, "f32(safe)", out)
run_benchmark(lib.safe_softmax_f32x4_per_token, x, "f32x4(safe)", out)
run_benchmark(lib.online_softmax_f32_per_token, x, "f32(online)", out)
run_benchmark(partial(torch.softmax, dim=1, out=out), x, "f32_th(per)")

print("-" * 100)
Expand All @@ -99,6 +100,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor,
run_benchmark(lib.softmax_f32x4_per_token, x, "f32x4(per)", out)
run_benchmark(lib.safe_softmax_f32_per_token, x, "f32(safe)", out)
run_benchmark(lib.safe_softmax_f32x4_per_token, x, "f32x4(safe)", out)
run_benchmark(lib.online_softmax_f32_per_token, x, "f32(online)", out)
run_benchmark(partial(torch.softmax, dim=1, out=out), x, "f32_th(per)")

print("-" * 100)
Expand All @@ -121,6 +123,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor,
run_benchmark(lib.softmax_f32x4_per_token, x, "f32x4(per)", out)
run_benchmark(lib.safe_softmax_f32_per_token, x, "f32(safe)", out)
run_benchmark(lib.safe_softmax_f32x4_per_token, x, "f32x4(safe)", out)
run_benchmark(lib.online_softmax_f32_per_token, x, "f32(online)", out)
run_benchmark(partial(torch.softmax, dim=1, out=out), x, "f32_th(per)")

print("-" * 100)
Expand Down

0 comments on commit 5ae3c08

Please sign in to comment.