
Commit 82eb5a9

fix
hiworldwzj authored Oct 26, 2024
1 parent b145584 commit 82eb5a9
Showing 2 changed files with 5 additions and 14 deletions.
lightllm/models/deepseek2/layer_infer/fused_moe.py (13 changes: 2 additions & 11 deletions)
@@ -29,15 +29,11 @@
 from lightllm.utils.log_utils import init_logger
 from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd
 
-USE_VLLM = True
-try:
-    from lightllm_vllm_kernel import moe_align_block_size as moe_align_block_size_kernel
-except ImportError:
-    from lightllm.models.deepseek2.layer_infer._custom_ops import (
-        moe_align_block_size as moe_align_block_size_kernel_custom,
-    )
+from lightllm.models.deepseek2.layer_infer._custom_ops import moe_align_block_size as moe_align_block_size_kernel
 
-USE_VLLM = False
 
 logger = init_logger(__name__)

@@ -226,12 +222,7 @@ def moe_align_block_size(
     max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
     expert_ids = alloc_tensor_func((max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device)
     num_tokens_post_pad = alloc_tensor_func((1), dtype=torch.int32, device=topk_ids.device)
-    if USE_VLLM:
-        moe_align_block_size_kernel(topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad)
-    else:
-        moe_align_block_size_kernel_custom(
-            topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad, alloc_tensor_func
-        )
+    moe_align_block_size_kernel(topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad)
     return sorted_ids, expert_ids, num_tokens_post_pad


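Net effect in fused_moe.py: the optional lightllm_vllm_kernel dependency and the USE_VLLM flag are gone, and moe_align_block_size always dispatches to the in-repo kernel with the same six-argument call the vLLM kernel used. A minimal sketch of the resulting call path, assuming alloc_tensor_func is torch.empty-compatible; the sorted_ids setup and the worst-case padding bound sit above the visible hunk and are assumptions here, not part of this diff:

import torch
import triton

from lightllm.models.deepseek2.layer_infer._custom_ops import moe_align_block_size as moe_align_block_size_kernel

def moe_align_block_size(topk_ids, num_experts, block_size, alloc_tensor_func=torch.empty):
    # Assumed worst-case bound: every expert's token list may be padded up to a
    # block boundary (this setup is above the hunk shown in the diff).
    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    sorted_ids = alloc_tensor_func((max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device)
    max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
    expert_ids = alloc_tensor_func((max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device)
    # The diff's context line passes (1), i.e. the int 1; torch.empty accepts
    # that too, though (1,) is the unambiguous one-element shape.
    num_tokens_post_pad = alloc_tensor_func((1,), dtype=torch.int32, device=topk_ids.device)
    # Single unconditional dispatch; the USE_VLLM branch is removed.
    moe_align_block_size_kernel(topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad)
    return sorted_ids, expert_ids, num_tokens_post_pad

One consequence worth noting: the in-repo kernel is now called without the trailing alloc_tensor_func argument it previously received, so the _custom_ops implementation presumably accepts the six-argument vLLM-style signature (or defaults that parameter).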
lightllm/models/mixtral/layer_infer/_custom_ops.py (6 changes: 3 additions & 3 deletions)
@@ -35,9 +35,9 @@ def fused_topk(

     M, _ = hidden_states.shape
 
-    topk_weights = alloc_tensor_func(M, topk, dtype=torch.float32, device=hidden_states.device)
-    topk_ids = alloc_tensor_func(M, topk, dtype=torch.int32, device=hidden_states.device)
-    token_expert_indicies = alloc_tensor_func(M, topk, dtype=torch.int32, device=hidden_states.device)
+    topk_weights = alloc_tensor_func((M, topk), dtype=torch.float32, device=hidden_states.device)
+    topk_ids = alloc_tensor_func((M, topk), dtype=torch.int32, device=hidden_states.device)
+    token_expert_indicies = alloc_tensor_func((M, topk), dtype=torch.int32, device=hidden_states.device)
     topk_weights, topk_ids = topk_softmax(topk_weights, topk_ids, token_expert_indicies, gating_output.float(), topk)
     del token_expert_indicies # Not used. Will be used in the future.

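The mixtral change is a shape-tuple fix: if alloc_tensor_func follows a cache-allocator convention with signature (shape, dtype=..., device=...), the old positional form alloc_tensor_func(M, topk, ...) would bind topk to the allocator's second parameter instead of making it part of the shape. A hedged sketch of the corrected allocations, with fused_topk's surrounding body abbreviated into a hypothetical helper and alloc_tensor_func assumed torch.empty-compatible:

import torch

def alloc_topk_buffers(hidden_states, topk, alloc_tensor_func=torch.empty):
    # Hypothetical helper illustrating the fixed calls: each buffer takes one
    # (M, topk) shape tuple, matching a torch.empty-style signature.
    M, _ = hidden_states.shape
    topk_weights = alloc_tensor_func((M, topk), dtype=torch.float32, device=hidden_states.device)
    topk_ids = alloc_tensor_func((M, topk), dtype=torch.int32, device=hidden_states.device)
    token_expert_indicies = alloc_tensor_func((M, topk), dtype=torch.int32, device=hidden_states.device)
    return topk_weights, topk_ids, token_expert_indicies

With plain torch.empty the old form happened to work, since torch.empty(M, topk, ...) is also valid; the bug would only surface once a tuple-only allocator is passed in, which is likely what this fix addresses.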
