Commit

ppl w8a8 support splitfuse mode. (#437)
Use w8a8 + splitfuse mode to start a LLaMA2 7B server.
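
For context, the server for this test was presumably started along the following lines. This is only a sketch, not a command taken from this commit: the model path is a placeholder, and the exact --mode value and splitfuse flag name are assumptions that should be checked against python -m lightllm.server.api_server --help for your version.

python -m lightllm.server.api_server --model_dir /path/to/llama2-7b --mode ppl_w8a8 --splitfuse_mode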

input:
data = {
    'inputs': 'San Francisco is a',
    'parameters': {
        'do_sample': False,
    }
}
output:
{'generated_text': ['city in California, United States. Unterscheidung between the two is not always clear-'], 'count_output_tokens': 16, 'finish_reason': 'length'}

The test function works as expected.
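
For reference, the request above can be reproduced with a short Python client. This is a minimal sketch assuming the lightllm HTTP server is reachable on localhost at port 8000 and serves its /generate endpoint; adjust host, port, and path to your deployment.

import requests

data = {
    'inputs': 'San Francisco is a',
    'parameters': {
        'do_sample': False,
    }
}

# POST the payload to the generate endpoint; the JSON response matches the output shown above.
response = requests.post('http://localhost:8000/generate', json=data)
print(response.json())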

---------

Co-authored-by: wangzaijun <[email protected]>
Co-authored-by: shihaobai <[email protected]>
3 people authored Jun 13, 2024
1 parent 7609618 commit a47942e
Showing 2 changed files with 30 additions and 1 deletion.
@@ -13,6 +13,7 @@
 )
 from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd
 from lightllm.models.llama.infer_struct import LlamaInferStateInfo
+from lightllm.models.llama.splitfuse_infer_struct import LlamaSplitFuseInferStateInfo
 from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
 from lightllm.common.basemodel import TransformerLayerInferActivationWeightQuantTpl
 from lightllm.common.basemodel.cuda_kernel.ppl_awquant import (
@@ -220,6 +221,33 @@ def _token_ffn(self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight):
         input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
         return

+    def _splitfuse_attention(self, input_embding, infer_state: LlamaSplitFuseInferStateInfo, layer_weight):
+        # LlamaSplitFuseInferStateInfo has no is_prefill member, but the downstream matmul
+        # kernel entry functions expect one, so we add it here and default it to True.
+        infer_state.is_prefill = True
+
+        input1, token_scale, skip_out = self._awquant_att_norm(input_embding, infer_state, layer_weight)
+        cache_kv = self._pre_cache_kv(infer_state, layer_weight)
+        q, cache_kv = self._get_qkv(input1, cache_kv, token_scale, infer_state, layer_weight)
+        input1 = None
+        self._post_cache_kv(cache_kv, infer_state, layer_weight)
+        o = self._splitfuse_attention_kernel(q, infer_state, layer_weight)
+        q = None
+        o = self._get_o(o, infer_state, layer_weight)
+        if self.world_size_ > 1:
+            dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False)
+        input_embding.add_(o.view(-1, self.embed_dim_))
+        return
+
+    def _splitfuse_ffn(self, input_embdings, infer_state: LlamaSplitFuseInferStateInfo, layer_weight):
+        input1, token_scale, skip_out = self._awquant_ffn_norm(input_embdings, infer_state, layer_weight)
+        ffn_out = self._ffn(input1, token_scale, infer_state, layer_weight)
+        input1 = None
+        if self.world_size_ > 1:
+            dist.all_reduce(ffn_out, op=dist.ReduceOp.SUM, async_op=False)
+        input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
+        return

     def _awquant_matmul_ppl_int8_quant_dequant(
         self, input, quant_weight_params, is_prefill, token_scale=None, out=None, bias=None, has_act=False
     ):
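For readers unfamiliar with the splitfuse path, the two new methods mirror the existing context/token pairs such as _token_ffn above: attention runs first, then the FFN, and each adds its result back into the hidden states in place. The sketch below is purely illustrative of that calling order and is not lightllm's actual dispatch code; the driver function name is hypothetical.

# Hypothetical driver showing how one transformer layer would use the new hooks
# during a splitfuse step; both hooks mutate input_embdings in place.
def splitfuse_layer_forward(layer, input_embdings, infer_state, layer_weight):
    layer._splitfuse_attention(input_embdings, infer_state, layer_weight)
    layer._splitfuse_ffn(input_embdings, infer_state, layer_weight)
    return input_embdings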
3 changes: 2 additions & 1 deletion lightllm/server/httpserver/manager.py
@@ -183,7 +183,8 @@ async def generate(
         logger.debug(
             f"req_id:{group_request_id},start:{start_time}s,first_token_cost:{first_token_cost_ms}ms\n"
             f"total_cost_time:{total_cost_time_ms}ms,out_token_counter:{out_token_counter}\n"
-            f"mean_per_token_cost_time: {total_cost_time_ms/out_token_counter}ms"
+            f"mean_per_token_cost_time: {total_cost_time_ms/out_token_counter}ms\n"
+            f"prompt_token_num:{prompt_tokens}"
         )
         monitor.histogram_observe("lightllm_request_inference_duration", total_cost_time_ms)
         monitor.histogram_observe(
