Commit

delete test
sufubao committed Sep 25, 2024
1 parent 3854d32 commit 1019098
Showing 2 changed files with 1 addition and 492 deletions.
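
The deleted lines amount to an ad-hoc benchmarking and correctness harness left in context_flashattention_nopad.py: each call first ran the old kernel into a scratch tensor, timed the new kernel between torch.cuda.synchronize() barriers, accumulated wall-clock totals in module-level globals, and printed a running average plus the maximum deviation from the old output. A minimal standalone sketch of that pattern, reconstructed from the deleted lines, is below; the wrapper name timed_check and its generic argument handling are illustrative and not part of the commit.

import time

import torch

# Running totals, mirroring the module-level globals the commit removes.
sum_cost_time = 0.0
call_cnt = 0

def timed_check(kernel, ref_kernel, q, k, v, o, *extra_args):
    """Time `kernel` and compare its output against `ref_kernel` (the old kernel).

    Both kernels are expected to write their result into the 4th tensor argument,
    as context_attention_fwd and context_attention_fwd_old do in the diff below.
    """
    global sum_cost_time, call_cnt

    # Reference result from the old implementation.
    old_out = torch.zeros_like(o)
    ref_kernel(q, k, v, old_out, *extra_args)

    torch.cuda.synchronize()          # drain pending GPU work so the timing is honest
    sta_time = time.time()
    kernel(q, k, v, o, *extra_args)   # kernel under test writes into `o`
    torch.cuda.synchronize()          # wait for the kernel itself to finish
    ed_time = time.time()

    call_cnt += 1
    if call_cnt != 1:                 # skip the first call (Triton autotune / warm-up)
        sum_cost_time += ed_time - sta_time
        print(f"sum_cost_time: {sum_cost_time * 1000}, cnt: {call_cnt}, avg: {sum_cost_time * 1000 / call_cnt}")

    # Largest (signed) element-wise deviation from the old kernel's output.
    print(f"max: {torch.max(old_out - o)}")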
@@ -4,17 +4,9 @@
import triton.language as tl
import math
import torch.nn.functional as F
from lightllm.models.llama.triton_kernel.context_flashattention_nopad_old import context_attention_fwd as context_attention_fwd_old
from lightllm.models.llama.triton_kernel.context_flashattention_nopad_old import context_attention_fwd_no_prompt_cache as context_attention_fwd_no_prompt_cache_old

TESLA = "Tesla" in torch.cuda.get_device_name(0)

kernels = {}

sum_cost_time = 0
call_cnt = 0
import time

@triton.jit
def _fwd_kernel(
Q,
@@ -164,11 +156,6 @@ def _fwd_kernel(
def context_attention_fwd(
q, k, v, o, b_req_idx, b_start_loc, b_seq_len, b_prompt_cache_len, max_input_len, req_to_token_indexs
):
old_out = torch.zeros_like(o)
context_attention_fwd_old(q, k, v, old_out, b_req_idx, b_start_loc, b_seq_len, b_prompt_cache_len, max_input_len, req_to_token_indexs)
global sum_cost_time, call_cnt
torch.cuda.synchronize()
sta_time = time.time()

BLOCK_M = 128 if not TESLA else 64
# shape constraints
@@ -219,16 +206,6 @@ def context_attention_fwd(
num_warps=num_warps,
num_stages=num_stages
)

torch.cuda.synchronize()
ed_time = time.time()

call_cnt += 1
if call_cnt != 1:
sum_cost_time += ed_time - sta_time
print(f"[DY]sum_cost_time: {sum_cost_time*1000}, cnt: {call_cnt}, avg:{sum_cost_time*1000/call_cnt}")

print(f"max: {torch.max(old_out - o)}")


@triton.jit
@@ -370,13 +347,6 @@ def _fwd_kernel_no_prompt_cache(
@torch.no_grad()
def context_attention_fwd_no_prompt_cache(q, k, v, o, b_start_loc, b_seq_len, max_input_len):

old_out = torch.zeros_like(o)
context_attention_fwd_no_prompt_cache_old(q, k, v, old_out, b_start_loc, b_seq_len, max_input_len)
global sum_cost_time, call_cnt

torch.cuda.synchronize()
sta_time = time.time()

BLOCK_M = 128 if not TESLA else 64
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
@@ -419,13 +389,4 @@ def context_attention_fwd_no_prompt_cache(q, k, v, o, b_start_loc, b_seq_len, ma
BLOCK_N=BLOCK_N,
num_warps=num_warps,
num_stages=num_stages,
)

torch.cuda.synchronize()
ed_time = time.time()

call_cnt += 1
sum_cost_time += ed_time - sta_time
print(f"sum_cost_time: {sum_cost_time*1000}, cnt: {call_cnt}, avg:{sum_cost_time*1000/call_cnt}")

print(f"max: {torch.max(old_out - o)}")
)
