From 789637fd4c2966712ed54223d63b596941e6e3a7 Mon Sep 17 00:00:00 2001
From: WANDY666 <1060304770@qq.com>
Date: Thu, 24 Oct 2024 16:53:18 +0800
Subject: [PATCH] add lightllm_vllm_kernel (#576)

---
 lightllm/models/deepseek2/layer_infer/fused_moe.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/lightllm/models/deepseek2/layer_infer/fused_moe.py b/lightllm/models/deepseek2/layer_infer/fused_moe.py
index df9e065e..2f95aa49 100644
--- a/lightllm/models/deepseek2/layer_infer/fused_moe.py
+++ b/lightllm/models/deepseek2/layer_infer/fused_moe.py
@@ -27,9 +27,13 @@
 import triton
 import triton.language as tl
 from lightllm.utils.log_utils import init_logger
-import lightllm.models.deepseek2.layer_infer._custom_ops as ops
 from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd
 
+try:
+    from lightllm_vllm_kernel import moe_align_block_size as moe_align_block_size_kernel
+except ImportError:
+    from lightllm.models.deepseek2.layer_infer._custom_ops import moe_align_block_size as moe_align_block_size_kernel
+
 logger = init_logger(__name__)
 
 
@@ -218,7 +222,7 @@ def moe_align_block_size(
     max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
     expert_ids = torch.empty((max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device)
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
-    ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad)
+    moe_align_block_size_kernel(topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad)
     return sorted_ids, expert_ids, num_tokens_post_pad
 
 
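
Context for the patch above: it makes the moe_align_block_size kernel pluggable. The import prefers the optional compiled lightllm_vllm_kernel package and falls back to the implementation bundled with lightllm, while the call site in fused_moe.py only changes the bound name. Below is a minimal standalone sketch of that fallback pattern; the explanatory comments are additions for illustration and are not part of the patch.

    # Sketch of the fallback import introduced by the patch (illustrative only).
    # Prefer the optional compiled extension lightllm_vllm_kernel; if it is not
    # installed, fall back to the implementation bundled with lightllm. Both
    # imports bind the same name, so callers of moe_align_block_size_kernel in
    # fused_moe.py are unaffected by which backend is actually present.
    try:
        from lightllm_vllm_kernel import moe_align_block_size as moe_align_block_size_kernel
    except ImportError:
        from lightllm.models.deepseek2.layer_infer._custom_ops import (
            moe_align_block_size as moe_align_block_size_kernel,
        )

Whichever backend is selected, moe_align_block_size_kernel(topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad) fills the preallocated output buffers in place, which is why the only change needed at the call site is the function name.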