Skip to content

Commit

Permalink
fix fp8
Browse files Browse the repository at this point in the history
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
manman-ren committed Dec 5, 2024
1 parent 2867e2f commit 937b451
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion submodules/generative-recommenders
Submodule generative-recommenders updated 27 files
+10 −13 README.md
+0 −4 generative_recommenders/data/eval.py
+58 −2 generative_recommenders/indexing/candidate_index.py
+3 −2 generative_recommenders/indexing/mips_top_k.py
+2 −8 generative_recommenders/indexing/utils.py
+569 −0 generative_recommenders/modeling/similarity/mol.py
+89 −59 generative_recommenders/modeling/similarity_utils.py
+0 −8 generative_recommenders/ops/cpp/common.h
+0 −44 generative_recommenders/ops/cpp/complete_cumsum.cpp
+0 −49 generative_recommenders/ops/cpp/complete_cumsum.cu
+0 −49 generative_recommenders/ops/cpp/cpp_ops.cpp
+0 −40 generative_recommenders/ops/cpp/sort_kv_pairs_cuda.cpp
+0 −82 generative_recommenders/ops/cpp/sort_kv_pairs_cuda_kernels_template.cu
+0 −12 generative_recommenders/ops/cpp/sort_kv_pairs_cuda_kernels_template.h
+0 −196 generative_recommenders/ops/pytorch/jagged.py
+52 −152 generative_recommenders/ops/triton/triton_jagged.py
+0 −25 generative_recommenders/ops/triton/triton_layer_norm.py
+0 −41 generative_recommenders/ops/triton/triton_position.py
+28 −97 generative_recommenders/ops/triton/triton_ragged_hstu_attention.py
+0 −41 generative_recommenders/rails/indexing/candidate_index.py
+0 −133 generative_recommenders/rails/indexing/mol_top_k.py
+0 −80 generative_recommenders/rails/similarities/layers.py
+0 −50 generative_recommenders/rails/similarities/mol/embeddings_fn.py
+0 −96 generative_recommenders/rails/similarities/mol/item_embeddings_fn.py
+0 −161 generative_recommenders/rails/similarities/mol/query_embeddings_fn.py
+0 −383 generative_recommenders/rails/similarities/mol/similarity_fn.py
+4 −20 generative_recommenders/trainer/train.py
8 changes: 4 additions & 4 deletions tritonbench/kernels/triton_fused_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,10 +458,10 @@ def _attn_fwd_inner_ws(
num_warps=w,
)
)
- for BM in [128] # 64, 128]
- for BN in [128] # 64, 128]
- for s in [3] # 3, 4, 7]
- for w in [8] # 4, 8]
+ for BM in [64, 128]
+ for BN in [64, 128]
+ for s in [3, 4, 7]
+ for w in [4, 8]
]
# TMA, WS, and CompPipe
configsTmaWS = [
Expand Down
2 changes: 1 addition & 1 deletion tritonbench/operators/fp8_attention/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def triton_flash_v2(
triton_q, triton_k, triton_v = self.triton_preprocess(q, k, v)
# full fp8 will be enabled if type of q,k,v is fp8
return lambda: triton_attention(
- triton_q, triton_k, triton_v, False, self.sm_scale, "base"
+ triton_q, triton_k, triton_v, False, self.sm_scale, "base", "base"
)

def get_x_val(self, _example_inputs) -> Tuple[int, int, int, int]:
Expand Down

0 comments on commit 937b451

Please sign in to comment.