diff --git a/lightllm/models/deepseek2/layer_infer/fused_moe.py b/lightllm/models/deepseek2/layer_infer/fused_moe.py
index 96b8edbe..e60d0402 100644
--- a/lightllm/models/deepseek2/layer_infer/fused_moe.py
+++ b/lightllm/models/deepseek2/layer_infer/fused_moe.py
@@ -29,15 +29,11 @@ from lightllm.utils.log_utils import init_logger
 from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd
 
-USE_VLLM = True
 try:
     from lightllm_vllm_kernel import moe_align_block_size as moe_align_block_size_kernel
 except ImportError:
-    from lightllm.models.deepseek2.layer_infer._custom_ops import (
-        moe_align_block_size as moe_align_block_size_kernel_custom,
-    )
+    from lightllm.models.deepseek2.layer_infer._custom_ops import moe_align_block_size as moe_align_block_size_kernel
 
-    USE_VLLM = False
 
 logger = init_logger(__name__)
@@ -226,12 +222,7 @@ def moe_align_block_size(
     max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
     expert_ids = alloc_tensor_func((max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device)
     num_tokens_post_pad = alloc_tensor_func((1), dtype=torch.int32, device=topk_ids.device)
-    if USE_VLLM:
-        moe_align_block_size_kernel(topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad)
-    else:
-        moe_align_block_size_kernel_custom(
-            topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad, alloc_tensor_func
-        )
+    moe_align_block_size_kernel(topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad)
 
     return sorted_ids, expert_ids, num_tokens_post_pad
diff --git a/lightllm/models/mixtral/layer_infer/_custom_ops.py b/lightllm/models/mixtral/layer_infer/_custom_ops.py
index 532ead7f..b0e27ac1 100644
--- a/lightllm/models/mixtral/layer_infer/_custom_ops.py
+++ b/lightllm/models/mixtral/layer_infer/_custom_ops.py
@@ -35,9 +35,9 @@ def fused_topk(
     M, _ = hidden_states.shape
 
-    topk_weights = alloc_tensor_func(M, topk, dtype=torch.float32, device=hidden_states.device)
-    topk_ids = alloc_tensor_func(M, topk, dtype=torch.int32, device=hidden_states.device)
-    token_expert_indicies = alloc_tensor_func(M, topk, dtype=torch.int32, device=hidden_states.device)
+    topk_weights = alloc_tensor_func((M, topk), dtype=torch.float32, device=hidden_states.device)
+    topk_ids = alloc_tensor_func((M, topk), dtype=torch.int32, device=hidden_states.device)
+    token_expert_indicies = alloc_tensor_func((M, topk), dtype=torch.int32, device=hidden_states.device)
     topk_weights, topk_ids = topk_softmax(topk_weights, topk_ids, token_expert_indicies, gating_output.float(), topk)
     del token_expert_indicies  # Not used. Will be used in the future.
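
Notes on the change. The first two hunks drop the USE_VLLM branching: the custom Triton fallback previously took an extra alloc_tensor_func argument, and the unified call site implies its signature now matches the vLLM kernel's six arguments, so both implementations can be imported under the same name. The mixtral hunk brings the fused_topk call sites in line with the allocator convention already visible in the fused_moe.py context lines, where the first positional argument is a single shape tuple rather than unpacked dimensions.

For reviewers, a rough pure-PyTorch sketch of the semantics that six-argument contract implies. This is not the lightllm or vLLM code; the function name, the numel() padding sentinel, and the assumption that sorted_ids entries index into the flattened topk_ids follow vLLM's documented moe_align_block_size behavior and should be treated as assumptions.

import torch


def moe_align_block_size_ref(topk_ids, num_experts, block_size,
                             sorted_ids, expert_ids, num_tokens_post_pad):
    # Flatten the (num_tokens, topk) routing table; entries written into
    # sorted_ids index into this flattened view.
    flat = topk_ids.flatten()
    pad = flat.numel()  # out-of-range sentinel assumed as the padding value
    sorted_ids.fill_(pad)
    out_pos = 0
    for e in range(num_experts):
        # Gather the positions of all token-expert pairs routed to expert e.
        idx = (flat == e).nonzero(as_tuple=True)[0]
        n = int(idx.numel())
        n_padded = -(-n // block_size) * block_size  # round up to block_size
        sorted_ids[out_pos:out_pos + n] = idx.to(sorted_ids.dtype)
        # Every block in this expert's padded range is owned by expert e.
        expert_ids[out_pos // block_size:(out_pos + n_padded) // block_size] = e
        out_pos += n_padded
    num_tokens_post_pad.fill_(out_pos)

With this grouping, each block of block_size entries in sorted_ids belongs to exactly one expert (expert_ids gives the owner), which is what lets the fused MoE GEMM process one expert's weights per block without per-token branching.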