Skip to content

Commit

Permalink
support to read eos_id from config.json (#569)
Browse files Browse the repository at this point in the history
Co-authored-by: wangzaijun <[email protected]>
  • Loading branch information
hiworldwzj and wangzaijun authored Oct 18, 2024
1 parent 71d1208 commit 7ca5b21
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 8 deletions.
24 changes: 16 additions & 8 deletions lightllm/server/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,9 @@ def make_argument_parser() -> argparse.ArgumentParser:
default=None,
help="max tokens num for new cat batch, it control prefill batch size to Preventing OOM",
)
parser.add_argument("--eos_id", nargs="+", type=int, default=[2], help="eos stop token id")
parser.add_argument(
"--eos_id", nargs="+", type=int, default=None, help="eos stop token id, if None, will load from config.json"
)
parser.add_argument(
"--running_max_req_size", type=int, default=1000, help="the max size for forward requests in the same time"
)
Expand Down Expand Up @@ -505,7 +507,7 @@ def main():
assert args.max_req_input_len < args.max_req_total_len
assert args.max_req_total_len <= args.max_total_token_num
assert not (args.beam_mode and args.use_dynamic_prompt_cache), "Beam mode incompatible with dynamic prompt cache"

# splitfuse_mode and cuda_graph cannot both be enabled at the same time
if args.splitfuse_mode:
assert args.disable_cudagraph
Expand Down Expand Up @@ -537,6 +539,18 @@ def main():
batch_max_tokens = max(batch_max_tokens, args.splitfuse_block_size)
args.batch_max_tokens = batch_max_tokens

# help to manage data stored on Ceph
if "s3://" in args.model_dir:
from lightllm.utils.petrel_helper import s3_model_prepare

s3_model_prepare(args.model_dir)

# If args.eos_id is None, read the eos_token_id info from config.json and assign it to args
if args.eos_id is None:
from lightllm.utils.config_utils import get_eos_token_ids

args.eos_id = get_eos_token_ids(args.model_dir)

logger.info(f"all start args:{args}")

can_use_ports = alloc_can_use_network_port(num=6 + args.tp, used_nccl_port=args.nccl_port)
Expand All @@ -560,12 +574,6 @@ def main():
global metric_client
metric_client = MetricClient(metric_port)

# help to manage data stored on Ceph
if "s3://" in args.model_dir:
from lightllm.utils.petrel_helper import s3_model_prepare

s3_model_prepare(args.model_dir)

global httpserver_manager
httpserver_manager = HttpServerManager(
args,
Expand Down
18 changes: 18 additions & 0 deletions lightllm/utils/config_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import json
import os


def get_config_json(model_path: str) -> dict:
    """Load and return the parsed ``config.json`` located in *model_path*."""
    # Explicit encoding avoids platform-dependent default-codec surprises.
    with open(os.path.join(model_path, "config.json"), "r", encoding="utf-8") as file:
        json_obj = json.load(file)
    return json_obj


def get_eos_token_ids(model_path: str) -> list:
    """Return the model's EOS token id(s) as a list of ints.

    ``eos_token_id`` in config.json may be either a single int or a list of
    ints (some models define several stop tokens); both forms are normalized
    to a list.

    Raises:
        KeyError: if ``eos_token_id`` is absent from config.json.
        ValueError: if ``eos_token_id`` has an unexpected type.
    """
    config_json = get_config_json(model_path)
    eos_token_id = config_json["eos_token_id"]
    if isinstance(eos_token_id, int):
        return [eos_token_id]
    if isinstance(eos_token_id, list):
        return eos_token_id
    # NOTE: `assert` statements are stripped under `python -O`, which would
    # silently return None here — raise explicitly instead.
    raise ValueError("error eos_token_id format in config.json")

0 comments on commit 7ca5b21

Please sign in to comment.