-
Notifications
You must be signed in to change notification settings - Fork 205
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add first_token_constraint_mode (#599)
Co-authored-by: wangzaijun <[email protected]>
- Loading branch information
1 parent
e7184fc
commit e58aa74
Showing
6 changed files
with
87 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
72 changes: 72 additions & 0 deletions
72
...r/router/model_infer/mode_backend/continues_batch/impl_for_first_token_constraint_mode.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import os | ||
import shutil | ||
import torch | ||
from .impl import ContinuesBatchBackend | ||
from lightllm.server.io_struct import FinishStatus | ||
from lightllm.server.router.model_infer.infer_batch import InferBatch, InferReq, InferSamplingParams | ||
from .pre_process import prepare_prefill_inputs, prepare_decode_inputs | ||
from .post_process import sample | ||
from typing import List | ||
from lightllm.utils.log_utils import init_logger | ||
|
||
logger = init_logger(__name__) | ||
|
||
|
||
class FirstTokenConstraintBackend(ContinuesBatchBackend): | ||
def __init__(self) -> None: | ||
super().__init__() | ||
|
||
def init_custom(self): | ||
first_allowed_tokens_strs: str = os.environ.get("FIRST_ALLOWED_TOKENS", None) | ||
logger.info(f"first_allowed_tokens_strs : {first_allowed_tokens_strs}") | ||
# 使用该模式需要设置FIRST_ALLOWED_TOKENS 环境变量,格式为 "1,2" 或 "1,2,3" 等数字字符串 | ||
assert first_allowed_tokens_strs is not None | ||
first_allowed_tokens_strs.split(",") | ||
self.first_allowed_tokens = [int(e.strip()) for e in first_allowed_tokens_strs.split(",") if len(e.strip()) > 0] | ||
logger.info(f"first_allowed_tokens : {self.first_allowed_tokens}") | ||
# check token_id < vocab_size | ||
assert all(e < self.model.vocab_size for e in self.first_allowed_tokens) | ||
return | ||
|
||
def forward(self, batch_id, is_prefill): | ||
output_dict = {} | ||
batch: InferBatch = self.cache.pop(batch_id) | ||
if is_prefill: | ||
kwargs, run_reqs = prepare_prefill_inputs(batch, self.radix_cache, self.is_multimodal) | ||
else: | ||
kwargs, run_reqs = prepare_decode_inputs(batch, self.radix_cache) | ||
|
||
logits = self.model.forward(**kwargs) | ||
# first token constraint | ||
if is_prefill: | ||
mask = torch.ones_like(logits, dtype=torch.bool) | ||
mask[:, self.first_allowed_tokens] = False | ||
logits[mask] = -1000000.0 | ||
|
||
next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id) | ||
next_token_ids = next_token_ids.detach().cpu().numpy() | ||
next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() | ||
|
||
for req_obj, next_token_id, next_token_logprob in zip(run_reqs, next_token_ids, next_token_logprobs): | ||
# prefill and decode is same | ||
req_obj: InferReq = req_obj | ||
req_obj.cur_kv_len = len(req_obj.input_token_ids) | ||
req_obj.input_token_ids.append(next_token_id) | ||
req_obj.out_token_id_count[next_token_id] += 1 | ||
req_obj.update_finish_status(self.eos_id) | ||
|
||
metadata = { | ||
"id": int(next_token_id), | ||
"logprob": float(next_token_logprob), | ||
} | ||
output_dict[req_obj.r_id] = ( | ||
req_obj.req_status, | ||
req_obj.cur_kv_len, | ||
req_obj.get_output_len(), | ||
[(int(next_token_id), metadata)], | ||
req_obj.finish_status.value, # 转化为整数,避免传送大对象, | ||
None, | ||
) # 请求状态, 当前占用的kv的长度, 当前输出token的数量, 输出的token的id和元信息列表, 是否推理结束的状态, 额外保留参数 | ||
|
||
self.cache[batch.batch_id] = batch | ||
return output_dict |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters