add first_token_constraint_mode (#599)
Co-authored-by: wangzaijun <[email protected]>
hiworldwzj and wangzaijun authored Oct 31, 2024
1 parent e7184fc commit e58aa74
Showing 6 changed files with 87 additions and 1 deletion.
8 changes: 7 additions & 1 deletion lightllm/server/api_server.py
@@ -435,7 +435,12 @@ def make_argument_parser() -> argparse.ArgumentParser:
parser.add_argument("--diverse_mode", action="store_true", help="diversity generation mode")
parser.add_argument("--token_healing_mode", action="store_true", help="code model infer mode")
parser.add_argument("--simple_constraint_mode", action="store_true", help="output constraint mode")

parser.add_argument(
"--first_token_constraint_mode",
action="store_true",
help="""constraint the first token allowed range,
use env FIRST_ALLOWED_TOKENS to set the range, like FIRST_ALLOWED_TOKENS=1,2 ..""",
)
parser.add_argument(
"--enable_multimodal", action="store_true", help="Whether or not to allow to load additional multimodal models."
)
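For orientation, the new flag is a plain store_true switch that is meant to be combined with the FIRST_ALLOWED_TOKENS environment variable named in its help text. A minimal sketch (the flag name comes from this diff; the token ids and model dir are placeholders):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--first_token_constraint_mode", action="store_true")
    args = parser.parse_args(["--first_token_constraint_mode"])
    assert args.first_token_constraint_mode is True

    # typical launch shape (hypothetical values):
    #   FIRST_ALLOWED_TOKENS=1,2 python -m lightllm.server.api_server --first_token_constraint_mode --model_dir /path/to/model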
@@ -546,6 +551,7 @@ def main():
            args.token_healing_mode,
            args.use_reward_model,
            args.return_all_prompt_logprobs,
            args.first_token_constraint_mode,
        ].count(True) <= 1
    # Some modes cannot yet run together with dynamic_prompt_cache; TODO.
    if args.use_dynamic_prompt_cache:
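The assert just above enforces mutual exclusion: collecting the mode flags in a list and counting the True entries guarantees at most one special mode is enabled. The idiom in isolation, with made-up flag values:

    flags = [False, True, False, False]  # e.g. only token_healing_mode enabled
    assert flags.count(True) <= 1        # raises AssertionError once two modes are requested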
1 change: 1 addition & 0 deletions lightllm/server/router/manager.py
@@ -91,6 +91,7 @@ async def wait_to_model_ready(self):
"max_req_num": self.args.running_max_req_size + 8,
"max_seq_length": self.args.max_req_total_len + 8, # 留一点余量
"nccl_port": self.args.nccl_port,
"is_first_token_constraint_mode": self.args.first_token_constraint_mode,
"is_splitfuse_mode": self.is_splitfuse_mode,
"splitfuse_block_size": self.splitfuse_block_size,
"is_token_healing": self.args.token_healing_mode,
1 change: 1 addition & 0 deletions lightllm/server/router/model_infer/mode_backend/__init__.py
@@ -6,3 +6,4 @@
from .diverse_backend.impl import DiversehBackend
from .continues_batch.impl_for_token_healing import TokenHealingBackend
from .continues_batch.impl_for_simple_constraint_mode import SimpleConstraintBackend
from .continues_batch.impl_for_first_token_constraint_mode import FirstTokenConstraintBackend
72 changes: 72 additions & 0 deletions lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_first_token_constraint_mode.py
@@ -0,0 +1,72 @@
import os
import shutil
import torch
from .impl import ContinuesBatchBackend
from lightllm.server.io_struct import FinishStatus
from lightllm.server.router.model_infer.infer_batch import InferBatch, InferReq, InferSamplingParams
from .pre_process import prepare_prefill_inputs, prepare_decode_inputs
from .post_process import sample
from typing import List
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)


class FirstTokenConstraintBackend(ContinuesBatchBackend):
    def __init__(self) -> None:
        super().__init__()

    def init_custom(self):
        first_allowed_tokens_strs: str = os.environ.get("FIRST_ALLOWED_TOKENS", None)
        logger.info(f"first_allowed_tokens_strs : {first_allowed_tokens_strs}")
        # This mode requires the FIRST_ALLOWED_TOKENS env var: a comma-separated
        # string of token ids, e.g. "1,2" or "1,2,3".
        assert first_allowed_tokens_strs is not None
        self.first_allowed_tokens = [int(e.strip()) for e in first_allowed_tokens_strs.split(",") if len(e.strip()) > 0]
        logger.info(f"first_allowed_tokens : {self.first_allowed_tokens}")
        # every allowed token id must lie inside the vocabulary
        assert all(e < self.model.vocab_size for e in self.first_allowed_tokens)
        return

    def forward(self, batch_id, is_prefill):
        output_dict = {}
        batch: InferBatch = self.cache.pop(batch_id)
        if is_prefill:
            kwargs, run_reqs = prepare_prefill_inputs(batch, self.radix_cache, self.is_multimodal)
        else:
            kwargs, run_reqs = prepare_decode_inputs(batch, self.radix_cache)

        logits = self.model.forward(**kwargs)
        # first token constraint: only on the prefill step, mask every vocab
        # position outside the allowed set so the first sampled token must be allowed
        if is_prefill:
            mask = torch.ones_like(logits, dtype=torch.bool)
            mask[:, self.first_allowed_tokens] = False
            logits[mask] = -1000000.0
        next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
        next_token_ids = next_token_ids.detach().cpu().numpy()
        next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()

        for req_obj, next_token_id, next_token_logprob in zip(run_reqs, next_token_ids, next_token_logprobs):
            # prefill and decode are handled the same way here
            req_obj: InferReq = req_obj
            req_obj.cur_kv_len = len(req_obj.input_token_ids)
            req_obj.input_token_ids.append(next_token_id)
            req_obj.out_token_id_count[next_token_id] += 1
            req_obj.update_finish_status(self.eos_id)

            metadata = {
                "id": int(next_token_id),
                "logprob": float(next_token_logprob),
            }
            output_dict[req_obj.r_id] = (
                req_obj.req_status,
                req_obj.cur_kv_len,
                req_obj.get_output_len(),
                [(int(next_token_id), metadata)],
                req_obj.finish_status.value,  # converted to an int to avoid shipping a large object
                None,
            )  # request status, current kv length, output token count, [(token id, metadata)], finish status, reserved field

        self.cache[batch.batch_id] = batch
        return output_dict
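Putting the two pieces of the backend together: init_custom parses FIRST_ALLOWED_TOKENS into a list of token ids, and the prefill branch of forward pushes every other vocabulary position to a large negative logit so sampling can only pick an allowed first token. A self-contained sketch of the same technique, with made-up ids and shapes:

    import os
    import torch

    # parse the env var the way init_custom does (hypothetical ids)
    os.environ["FIRST_ALLOWED_TOKENS"] = "1, 2 ,3,"
    raw = os.environ.get("FIRST_ALLOWED_TOKENS")
    first_allowed_tokens = [int(e.strip()) for e in raw.split(",") if len(e.strip()) > 0]
    assert first_allowed_tokens == [1, 2, 3]  # stray spaces and a trailing comma are tolerated

    # apply the prefill-only mask to fake logits for a batch of two requests
    vocab_size = 16
    logits = torch.randn(2, vocab_size)
    mask = torch.ones_like(logits, dtype=torch.bool)
    mask[:, first_allowed_tokens] = False  # keep only the allowed columns
    logits[mask] = -1000000.0              # crush every other token
    assert all(t in first_allowed_tokens for t in logits.argmax(dim=-1).tolist())

Because the mask is applied only when is_prefill is true, every later decode step samples from the full vocabulary again.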
4 changes: 4 additions & 0 deletions lightllm/server/router/model_infer/model_rpc.py
@@ -12,6 +12,7 @@
    RewardModelBackend,
    TokenHealingBackend,
    SimpleConstraintBackend,
    FirstTokenConstraintBackend,
)
from lightllm.utils.log_utils import init_logger

@@ -31,6 +32,7 @@ def exposed_init_model(self, kvargs):
        beam_mode = kvargs.get("beam_mode", False)
        diverse_mode = kvargs.get("diverse_mode", False)
        is_token_healing = kvargs.get("is_token_healing", False)
        is_first_token_constraint_mode = kvargs.get("is_first_token_constraint_mode", False)
        if kvargs.get("args", None) is not None:
            is_simple_constraint_mode = kvargs.get("args", None).simple_constraint_mode
        else:
@@ -51,6 +53,8 @@ def exposed_init_model(self, kvargs):
            self.backend = TokenHealingBackend()
        elif is_simple_constraint_mode:
            self.backend = SimpleConstraintBackend()
        elif is_first_token_constraint_mode:
            self.backend = FirstTokenConstraintBackend()
        else:
            self.backend = ContinuesBatchBackend()
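The elif chain resolves exactly one backend per worker. A condensed, hypothetical sketch of that dispatch, with stub classes standing in for the real backends and only the new branch kept:

    class ContinuesBatchBackend: ...
    class FirstTokenConstraintBackend(ContinuesBatchBackend): ...

    def pick_backend(kvargs: dict) -> ContinuesBatchBackend:
        # mirrors the dispatch above, reduced to the new flag and the default
        if kvargs.get("is_first_token_constraint_mode", False):
            return FirstTokenConstraintBackend()
        return ContinuesBatchBackend()

    assert isinstance(pick_backend({"is_first_token_constraint_mode": True}), FirstTokenConstraintBackend)
    assert type(pick_backend({})) is ContinuesBatchBackend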

2 changes: 2 additions & 0 deletions lightllm/server/router/req_queue/__init__.py
@@ -14,5 +14,7 @@ def build_req_queue(args, router):
        return ContinuesBatchQueue(args, router)
    if args.simple_constraint_mode:
        return ContinuesBatchQueue(args, router)
    if args.first_token_constraint_mode:
        return ContinuesBatchQueue(args, router)

    return ContinuesBatchQueue(args, router)
