support to read eos_id from config.json #569

Merged
merged 1 commit on Oct 18, 2024
24 changes: 16 additions & 8 deletions lightllm/server/api_server.py
@@ -365,7 +365,9 @@ def make_argument_parser() -> argparse.ArgumentParser:
default=None,
help="max tokens num for new cat batch, it control prefill batch size to Preventing OOM",
)
parser.add_argument("--eos_id", nargs="+", type=int, default=[2], help="eos stop token id")
parser.add_argument(
"--eos_id", nargs="+", type=int, default=None, help="eos stop token id, if None, will load from config.json"
)
parser.add_argument(
"--running_max_req_size", type=int, default=1000, help="the max size for forward requests in the same time"
)
@@ -505,7 +507,7 @@ def main():
assert args.max_req_input_len < args.max_req_total_len
assert args.max_req_total_len <= args.max_total_token_num
assert not (args.beam_mode and args.use_dynamic_prompt_cache), "Beam mode incompatible with dynamic prompt cache"

# splitfuse_mode and cuda_graph cannot be enabled at the same time
if args.splitfuse_mode:
assert args.disable_cudagraph
@@ -537,6 +539,18 @@ def main():
batch_max_tokens = max(batch_max_tokens, args.splitfuse_block_size)
args.batch_max_tokens = batch_max_tokens

# help to manage data stored on Ceph
if "s3://" in args.model_dir:
from lightllm.utils.petrel_helper import s3_model_prepare

s3_model_prepare(args.model_dir)

# if args.eos_id is None, read the eos_token_id info from config.json and assign it to args
if args.eos_id is None:
from lightllm.utils.config_utils import get_eos_token_ids

args.eos_id = get_eos_token_ids(args.model_dir)

logger.info(f"all start args:{args}")

can_use_ports = alloc_can_use_network_port(num=6 + args.tp, used_nccl_port=args.nccl_port)
@@ -560,12 +574,6 @@
global metric_client
metric_client = MetricClient(metric_port)

# help to manage data stored on Ceph
if "s3://" in args.model_dir:
from lightllm.utils.petrel_helper import s3_model_prepare

s3_model_prepare(args.model_dir)

global httpserver_manager
httpserver_manager = HttpServerManager(
args,
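For context, a minimal standalone sketch of the changed argument behavior (an illustration, not part of the diff): nargs="+" collects one or more ids from the command line, and the None default is what later triggers the config.json fallback in main().

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--eos_id", nargs="+", type=int, default=None,
    help="eos stop token id; if None, it will be loaded from config.json",
)

# one or more ids can be passed explicitly
args = parser.parse_args(["--eos_id", "2", "32000"])
assert args.eos_id == [2, 32000]

# with no flag the default stays None, which triggers the config.json fallback
args = parser.parse_args([])
assert args.eos_id is None
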
18 changes: 18 additions & 0 deletions lightllm/utils/config_utils.py
@@ -0,0 +1,18 @@
import json
import os


def get_config_json(model_path: str):
with open(os.path.join(model_path, "config.json"), "r") as file:
json_obj = json.load(file)
return json_obj


def get_eos_token_ids(model_path: str):
config_json = get_config_json(model_path)
eos_token_id = config_json["eos_token_id"]
if isinstance(eos_token_id, int):
return [eos_token_id]
if isinstance(eos_token_id, list):
return eos_token_id
assert False, "unexpected eos_token_id format in config.json"
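
A quick usage sketch of the new helper (illustrative only; assumes lightllm is importable and uses throwaway config.json files): get_eos_token_ids normalizes both shapes that config.json may use, a single int or a list of ints, into a list.

import json
import os
import tempfile

from lightllm.utils.config_utils import get_eos_token_ids

# config.json stores eos_token_id as a plain int -> wrapped in a list
with tempfile.TemporaryDirectory() as model_dir:
    with open(os.path.join(model_dir, "config.json"), "w") as f:
        json.dump({"eos_token_id": 2}, f)
    assert get_eos_token_ids(model_dir) == [2]

# config.json stores eos_token_id as a list -> returned as-is
with tempfile.TemporaryDirectory() as model_dir:
    with open(os.path.join(model_dir, "config.json"), "w") as f:
        json.dump({"eos_token_id": [128001, 128009]}, f)
    assert get_eos_token_ids(model_dir) == [128001, 128009]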