From e58aa74809922d3d52ae9365acb4628951ffbd9c Mon Sep 17 00:00:00 2001
From: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com>
Date: Thu, 31 Oct 2024 18:41:57 +0800
Subject: [PATCH] add first_token_constraint_mode (#599)

Co-authored-by: wangzaijun
---
 lightllm/server/api_server.py                 |  8 ++-
 lightllm/server/router/manager.py             |  1 +
 .../model_infer/mode_backend/__init__.py      |  1 +
 .../impl_for_first_token_constraint_mode.py   | 72 +++++++++++++++++++
 .../server/router/model_infer/model_rpc.py    |  4 ++
 lightllm/server/router/req_queue/__init__.py  |  2 +
 6 files changed, 87 insertions(+), 1 deletion(-)
 create mode 100644 lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_first_token_constraint_mode.py

diff --git a/lightllm/server/api_server.py b/lightllm/server/api_server.py
index 00bfd047..f72ef1dc 100755
--- a/lightllm/server/api_server.py
+++ b/lightllm/server/api_server.py
@@ -435,7 +435,12 @@ def make_argument_parser() -> argparse.ArgumentParser:
     parser.add_argument("--diverse_mode", action="store_true", help="diversity generation mode")
     parser.add_argument("--token_healing_mode", action="store_true", help="code model infer mode")
     parser.add_argument("--simple_constraint_mode", action="store_true", help="output constraint mode")
-
+    parser.add_argument(
+        "--first_token_constraint_mode",
+        action="store_true",
+        help="""constrain the allowed range of the first generated token;
+        use the env var FIRST_ALLOWED_TOKENS to set the range, e.g. FIRST_ALLOWED_TOKENS=1,2""",
+    )
     parser.add_argument(
         "--enable_multimodal", action="store_true", help="Whether or not to allow to load additional multimodal models."
     )
@@ -546,6 +551,7 @@ def main():
         args.token_healing_mode,
         args.use_reward_model,
         args.return_all_prompt_logprobs,
+        args.first_token_constraint_mode,
     ].count(True) <= 1
     # Some of these modes cannot run together with dynamic_prompt_cache yet; to do.
     if args.use_dynamic_prompt_cache:
diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py
index 58546b27..c7ef6dea 100644
--- a/lightllm/server/router/manager.py
+++ b/lightllm/server/router/manager.py
@@ -91,6 +91,7 @@ async def wait_to_model_ready(self):
                 "max_req_num": self.args.running_max_req_size + 8,
                 "max_seq_length": self.args.max_req_total_len + 8,  # leave a little headroom
                 "nccl_port": self.args.nccl_port,
+                "is_first_token_constraint_mode": self.args.first_token_constraint_mode,
                 "is_splitfuse_mode": self.is_splitfuse_mode,
                 "splitfuse_block_size": self.splitfuse_block_size,
                 "is_token_healing": self.args.token_healing_mode,
diff --git a/lightllm/server/router/model_infer/mode_backend/__init__.py b/lightllm/server/router/model_infer/mode_backend/__init__.py
index f54a333c..d35c459e 100644
--- a/lightllm/server/router/model_infer/mode_backend/__init__.py
+++ b/lightllm/server/router/model_infer/mode_backend/__init__.py
@@ -6,3 +6,4 @@
 from .diverse_backend.impl import DiversehBackend
 from .continues_batch.impl_for_token_healing import TokenHealingBackend
 from .continues_batch.impl_for_simple_constraint_mode import SimpleConstraintBackend
+from .continues_batch.impl_for_first_token_constraint_mode import FirstTokenConstraintBackend
diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_first_token_constraint_mode.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_first_token_constraint_mode.py
new file mode 100644
index 00000000..8de33453
--- /dev/null
+++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_first_token_constraint_mode.py
@@ -0,0 +1,72 @@
+import os
+import shutil
+import torch
+from .impl import ContinuesBatchBackend
+from lightllm.server.io_struct import FinishStatus
+from lightllm.server.router.model_infer.infer_batch import InferBatch, InferReq, InferSamplingParams
+from .pre_process import prepare_prefill_inputs, prepare_decode_inputs
+from .post_process import sample
+from typing import List
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
+
+
+class FirstTokenConstraintBackend(ContinuesBatchBackend):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def init_custom(self):
+        first_allowed_tokens_strs: str = os.environ.get("FIRST_ALLOWED_TOKENS", None)
+        logger.info(f"first_allowed_tokens_strs : {first_allowed_tokens_strs}")
+        # This mode requires the FIRST_ALLOWED_TOKENS env var: a comma-separated string of token ids, e.g. "1,2" or "1,2,3".
+        assert first_allowed_tokens_strs is not None
+        first_allowed_tokens_strs.split(",")
+        self.first_allowed_tokens = [int(e.strip()) for e in first_allowed_tokens_strs.split(",") if len(e.strip()) > 0]
+        logger.info(f"first_allowed_tokens : {self.first_allowed_tokens}")
+        # check token_id < vocab_size
+        assert all(e < self.model.vocab_size for e in self.first_allowed_tokens)
+        return
+
+    def forward(self, batch_id, is_prefill):
+        output_dict = {}
+        batch: InferBatch = self.cache.pop(batch_id)
+        if is_prefill:
+            kwargs, run_reqs = prepare_prefill_inputs(batch, self.radix_cache, self.is_multimodal)
+        else:
+            kwargs, run_reqs = prepare_decode_inputs(batch, self.radix_cache)
+
+        logits = self.model.forward(**kwargs)
+        # first token constraint
+        if is_prefill:
+            mask = torch.ones_like(logits, dtype=torch.bool)
+            mask[:, self.first_allowed_tokens] = False
+            logits[mask] = -1000000.0
+
+        next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
+        next_token_ids = next_token_ids.detach().cpu().numpy()
+        next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
+
+        for req_obj, next_token_id, next_token_logprob in zip(run_reqs, next_token_ids, next_token_logprobs):
+            # prefill and decode are handled the same way
+            req_obj: InferReq = req_obj
+            req_obj.cur_kv_len = len(req_obj.input_token_ids)
+            req_obj.input_token_ids.append(next_token_id)
+            req_obj.out_token_id_count[next_token_id] += 1
+            req_obj.update_finish_status(self.eos_id)
+
+            metadata = {
+                "id": int(next_token_id),
+                "logprob": float(next_token_logprob),
+            }
+            output_dict[req_obj.r_id] = (
+                req_obj.req_status,
+                req_obj.cur_kv_len,
+                req_obj.get_output_len(),
+                [(int(next_token_id), metadata)],
+                req_obj.finish_status.value,  # converted to an int to avoid sending a large object
+                None,
+            )  # request status, current kv length used, current output length, list of (output token id, metadata), finish status, reserved extra field
+
+        self.cache[batch.batch_id] = batch
+        return output_dict
diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py
index 2e56b931..2f6cc55c 100644
--- a/lightllm/server/router/model_infer/model_rpc.py
+++ b/lightllm/server/router/model_infer/model_rpc.py
@@ -12,6 +12,7 @@
     RewardModelBackend,
     TokenHealingBackend,
     SimpleConstraintBackend,
+    FirstTokenConstraintBackend,
 )

 from lightllm.utils.log_utils import init_logger
@@ -31,6 +32,7 @@ def exposed_init_model(self, kvargs):
         beam_mode = kvargs.get("beam_mode", False)
         diverse_mode = kvargs.get("diverse_mode", False)
         is_token_healing = kvargs.get("is_token_healing", False)
+        is_first_token_constraint_mode = kvargs.get("is_first_token_constraint_mode", False)
         if kvargs.get("args", None) is not None:
             is_simple_constraint_mode = kvargs.get("args", None).simple_constraint_mode
         else:
@@ -51,6 +53,8 @@
             self.backend = TokenHealingBackend()
         elif is_simple_constraint_mode:
             self.backend = SimpleConstraintBackend()
+        elif is_first_token_constraint_mode:
+            self.backend = FirstTokenConstraintBackend()
         else:
             self.backend = ContinuesBatchBackend()

diff --git a/lightllm/server/router/req_queue/__init__.py b/lightllm/server/router/req_queue/__init__.py
index 87b301da..dd2e1e3a 100644
--- a/lightllm/server/router/req_queue/__init__.py
+++ b/lightllm/server/router/req_queue/__init__.py
@@ -14,5 +14,7 @@ def build_req_queue(args, router):
         return ContinuesBatchQueue(args, router)
     if args.simple_constraint_mode:
         return ContinuesBatchQueue(args, router)
+    if args.first_token_constraint_mode:
+        return ContinuesBatchQueue(args, router)
     return ContinuesBatchQueue(args, router)
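
Note (not part of the patch): the core of FirstTokenConstraintBackend.forward is the prefill-time logits mask. Below is a minimal, self-contained sketch of that technique with hypothetical tensor sizes; it only illustrates the masking idea and does not touch any lightllm internals:

    import torch

    # Hypothetical setup: 2 prefill requests, a vocabulary of 8 tokens,
    # and the allowed set as it would be parsed from FIRST_ALLOWED_TOKENS="1,2".
    logits = torch.randn(2, 8)
    first_allowed_tokens = [1, 2]

    # Build a boolean mask that is True for every disallowed token id, then
    # push those logits far below any realistic score so sampling (greedy or
    # stochastic) can only pick tokens from the allowed set.
    mask = torch.ones_like(logits, dtype=torch.bool)
    mask[:, first_allowed_tokens] = False
    logits[mask] = -1000000.0

    next_token_ids = torch.argmax(logits, dim=-1)
    assert all(int(t) in first_allowed_tokens for t in next_token_ids)

Because the mask is applied only when is_prefill is True, only the first generated token of each request is constrained; subsequent decode steps sample from the full vocabulary.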
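Usage note (inferred from the flag and env var added above, not text from the patch): the mode is enabled at server start with --first_token_constraint_mode, with the allowed token ids supplied through the FIRST_ALLOWED_TOKENS environment variable (for example FIRST_ALLOWED_TOKENS=1,2). init_custom asserts that the variable is set and that every id is smaller than the model's vocab_size, so startup fails fast if the list is missing or out of range. The flag also joins the count(True) <= 1 check in api_server.py, so it cannot be combined with the other special modes listed there.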