Add dispatch support #510

Closed · wants to merge 4 commits
53 changes: 53 additions & 0 deletions lightllm/models/internlm2_dispatcher/layer_infer/post_layer_infer.py
@@ -0,0 +1,53 @@
import torch
import torch.distributed as dist
import torch.nn.functional as F
import numpy as np

from lightllm.models.llama.infer_struct import LlamaInferStateInfo
from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer
from lightllm.models.internlm2_dispatcher.layer_weights.pre_and_post_layer_weight import (
Internlm2DispatcherPreAndPostLayerWeight,
)
from einops import rearrange


class Internlm2DispatcherPostLayerInfer(LlamaPostLayerInfer):
def cls_forward(
self, last_input, infer_state: LlamaInferStateInfo, layer_weight: Internlm2DispatcherPreAndPostLayerWeight
):
cls0_out = F.gelu(torch.mm(layer_weight.cls0_weight_, last_input) + layer_weight.cls0_bias_[:, None])
cls1_out = F.gelu(torch.mm(layer_weight.cls2_weight_, cls0_out) + layer_weight.cls2_bias_[:, None])
cls2_out = torch.mm(layer_weight.cls4_weight_, cls1_out) + layer_weight.cls4_bias_[:, None]
return cls2_out

def token_forward(
self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight: Internlm2DispatcherPreAndPostLayerWeight
):
last_input, token_num = self._slice_get_last_input(input_embdings, infer_state)
input_embdings_dtype = input_embdings.dtype
input_embdings = None
last_input = self._norm(last_input, infer_state, layer_weight)
last_input = rearrange(last_input, "batch embed_dim -> embed_dim batch").contiguous().reshape(-1, token_num)
logic_batch = torch.mm(layer_weight.lm_head_weight_, last_input)
cls_out = self.cls_forward(last_input, infer_state, layer_weight).permute(1, 0)
probs = torch.softmax(cls_out, dim=-1)[:, 1]

last_input = None
if self.world_size_ == 1:
gather_data = logic_batch
else:
gather_data = torch.empty(
(self.vocab_size_, token_num), device=logic_batch.device, dtype=input_embdings_dtype
)
split_indexes = np.linspace(0, self.vocab_size_, self.world_size_ + 1, dtype=np.int64)
dist.all_gather(
[gather_data[split_indexes[i] : split_indexes[i + 1], :] for i in range(self.world_size_)],
logic_batch,
group=None,
async_op=False,
)
logic_batch = None

ans_logics = gather_data.permute(1, 0).float()
gather_data = None
return ans_logics, probs
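For orientation, the head that `cls_forward` applies is a three-layer MLP with GELU between the linear layers, returning two logits per token whose softmax column 1 becomes `probs`. A minimal reference sketch, assuming layer sizes that are purely illustrative (the real dimensions come from the checkpoint, not from this diff):

```python
import torch
import torch.nn as nn

# Reference-only sketch of the head cls_forward implements; the layer indices
# mirror the checkpoint keys classifier.0 / classifier.2 / classifier.4, and
# the sizes below are assumptions for illustration.
classifier = nn.Sequential(
    nn.Linear(4096, 1024),  # classifier.0
    nn.GELU(),
    nn.Linear(1024, 256),   # classifier.2
    nn.GELU(),
    nn.Linear(256, 2),      # classifier.4
)

hidden = torch.randn(3, 4096)                          # final hidden states of 3 requests
logits = classifier(hidden)                            # shape (3, 2)
dispatch_score = torch.softmax(logits, dim=-1)[:, 1]   # same quantity as `probs` above
print(dispatch_score.shape)                            # torch.Size([3])
```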
37 changes: 37 additions & 0 deletions lightllm/models/internlm2_dispatcher/layer_weights/pre_and_post_layer_weight.py
@@ -0,0 +1,37 @@
import torch
import numpy as np
from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight


class Internlm2DispatcherPreAndPostLayerWeight(LlamaPreAndPostLayerWeight):
def __init__(self, tp_rank, world_size, data_type, network_config, mode):
super().__init__(tp_rank, world_size, data_type, network_config, mode)
return

def load_hf_weights(self, weights):
vob_size = self.network_config_["vocab_size"]
split_indexes = np.linspace(0, vob_size, self.world_size_ + 1, dtype=np.int64)
split_start = split_indexes[self.tp_rank_]
split_end = split_indexes[self.tp_rank_ + 1]
if "model.tok_embeddings.weight" in weights:
self.wte_weight_ = self._cuda(weights["model.tok_embeddings.weight"][split_start:split_end, :])
if "output.weight" in weights:
self.lm_head_weight_ = self._cuda(weights["output.weight"][split_start:split_end, :])
if "model.norm.weight" in weights:
self.final_norm_weight_ = self._cuda(weights["model.norm.weight"])

# load classifier weights
if "classifier.0.bias" in weights:
self.cls0_bias_ = self._cuda(weights["classifier.0.bias"])
if "classifier.0.weight" in weights:
self.cls0_weight_ = self._cuda(weights["classifier.0.weight"])
if "classifier.2.bias" in weights:
self.cls2_bias_ = self._cuda(weights["classifier.2.bias"])
if "classifier.2.weight" in weights:
self.cls2_weight_ = self._cuda(weights["classifier.2.weight"])
if "classifier.4.bias" in weights:
self.cls4_bias_ = self._cuda(weights["classifier.4.bias"])
if "classifier.4.weight" in weights:
self.cls4_weight_ = self._cuda(weights["classifier.4.weight"])

return
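The `np.linspace` split gives each tensor-parallel rank a contiguous block of vocabulary rows for `wte_weight_` and `lm_head_weight_`. A quick numeric illustration (the vocab size is an example value, not taken from this PR):

```python
import numpy as np

# Illustration of the tensor-parallel vocab split used above.
vob_size, world_size = 92544, 4          # example values only
split_indexes = np.linspace(0, vob_size, world_size + 1, dtype=np.int64)
print(split_indexes)                     # [    0 23136 46272 69408 92544]
for tp_rank in range(world_size):
    start, end = split_indexes[tp_rank], split_indexes[tp_rank + 1]
    print(f"rank {tp_rank} loads rows {start}:{end}")  # 23136 rows per rank
```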
18 changes: 18 additions & 0 deletions lightllm/models/internlm2_dispatcher/model.py
@@ -0,0 +1,18 @@
import os
import json
import torch
from lightllm.models.internlm2_dispatcher.layer_weights.pre_and_post_layer_weight import (
Internlm2DispatcherPreAndPostLayerWeight,
)
from lightllm.models.internlm2.model import Internlm2TpPartModel
from lightllm.models.internlm2_dispatcher.layer_infer.post_layer_infer import Internlm2DispatcherPostLayerInfer


class Internlm2DispatcherTpPartModel(Internlm2TpPartModel):
# weight class
pre_and_post_weight_class = Internlm2DispatcherPreAndPostLayerWeight

post_layer_infer_class = Internlm2DispatcherPostLayerInfer

def __init__(self, kvargs):
super().__init__(kvargs)
12 changes: 10 additions & 2 deletions lightllm/server/api_lightllm.py
@@ -31,7 +31,7 @@ async def lightllm_get_score(request: Request, g_id_gen, httpserver_manager) ->
return Response(content=json.dumps(ret, ensure_ascii=False).encode("utf-8"))


async def lightllm_generate(request: Request, g_id_gen, httpserver_manager) -> Response:
async def lightllm_generate(request: Request, g_id_gen, httpserver_manager, use_id=False) -> Response:

request_dict = await request.json()
prompt = request_dict.pop("inputs")
@@ -45,13 +45,14 @@ async def lightllm_generate(request: Request, g_id_gen, httpserver_manager) -> R

group_request_id = g_id_gen.generate_id()
results_generator = httpserver_manager.generate(
prompt, sampling_params, group_request_id, multimodal_params, request=request
prompt, sampling_params, group_request_id, multimodal_params, request=request, use_id=use_id
)

# Non-streaming case
final_output_dict = collections.defaultdict(list)
count_output_tokens_dict = collections.defaultdict(lambda: 0)
tokens_dict = collections.defaultdict(list)
out_token_ids_dict = collections.defaultdict(list)
finish_reason_dict = {}
prompt_logprobs = None
prompt_tokens = 0
@@ -78,12 +79,17 @@ async def lightllm_generate(request: Request, g_id_gen, httpserver_manager) -> R

if finish_status.is_finished():
finish_reason_dict[sub_req_id] = finish_status

if use_id:
out_token_ids_dict[sub_req_id].append(metadata["out_token_id"])

n = sampling_params.n
sub_ids = list(final_output_dict.keys())[:n]
final_output_list = ["".join(final_output_dict[sub_id]) for sub_id in sub_ids]
count_output_tokens_list = [count_output_tokens_dict[sub_id] for sub_id in sub_ids]
finish_reson_list = [finish_reason_dict[sub_id].get_finish_reason() for sub_id in sub_ids]
tokens_list = [tokens_dict[sub_id] for sub_id in sub_ids]
out_token_ids_list = [out_token_ids_dict[sub_id] for sub_id in sub_ids]
only_one = len(sub_ids) == 1

ret_data_format = lambda data_list: data_list[0] if only_one else data_list
@@ -100,6 +106,8 @@ async def lightllm_generate(request: Request, g_id_gen, httpserver_manager) -> R
ret["prompt_token_ids"] = prompt_token_ids
if prompt_logprobs is not None:
ret["prompt_logprobs"] = prompt_logprobs
if use_id:
ret["out_token_ids"] = ret_data_format(out_token_ids_list)
return Response(content=json.dumps(ret, ensure_ascii=False).encode("utf-8"))


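Because `ret_data_format` unwraps singleton lists, the new `out_token_ids` field comes back as a flat list of ids when there is a single sub-request and as a list of lists when `n > 1`. A small sketch of that shaping, written with `only_one` as an explicit argument (the ids are made up):

```python
# Sketch of how ret_data_format shapes out_token_ids in the handler above.
def ret_data_format(data_list, only_one):
    return data_list[0] if only_one else data_list

print(ret_data_format([[319, 13563, 1546]], only_one=True))   # [319, 13563, 1546]
print(ret_data_format([[319], [278]], only_one=False))        # [[319], [278]]
```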
13 changes: 13 additions & 0 deletions lightllm/server/api_server.py
@@ -162,6 +162,15 @@ async def get_score(request: Request) -> Response:
return create_error_response(HTTPStatus.EXPECTATION_FAILED, str(e))


@app.post("/id_generate")
async def id_generate(request: Request) -> Response:
first_set_handle_loop()
try:
return await g_generate_func(request, g_id_gen, httpserver_manager, use_id=True)
except Exception as e:
return create_error_response(HTTPStatus.EXPECTATION_FAILED, str(e))


@app.post("/")
async def compat_generate(request: Request) -> Response:
request_dict = await request.json()
@@ -452,6 +461,10 @@ def make_argument_parser() -> argparse.ArgumentParser:
"--enable_monitor_auth", action="store_true", help="Whether to open authentication for push_gateway"
)

parser.add_argument("--use_dispatcher_model", action="store_true")
parser.add_argument("--dispatch_threshold", type=float, default=0.8)
parser.add_argument("--dispatch_host", type=str, default="127.0.0.1")
parser.add_argument("--dispatch_port", type=int, default=12580)
return parser


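A hedged client sketch for the new `/id_generate` route: because the handler is called with `use_id=True`, the `inputs` field carries token ids rather than text. The host, port, token ids and sampling parameters below are illustrative only:

```python
import json
import requests

# Hypothetical client call against a locally running api_server (address assumed).
payload = {
    "inputs": [1, 2287, 1576, 4658],        # prompt as token ids, not text
    "parameters": {"max_new_tokens": 1},    # illustrative sampling parameters
}
resp = requests.post(
    "http://127.0.0.1:8000/id_generate",
    headers={"Content-Type": "application/json"},
    data=json.dumps(payload),
)
print(resp.json().get("out_token_ids"))     # ids of the generated tokens
```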
9 changes: 6 additions & 3 deletions lightllm/server/httpserver/manager.py
@@ -95,7 +95,7 @@ def tokens(self, prompt):
return len(prompt_ids)

async def generate(
self, prompt, sampling_params: SamplingParams, group_request_id, multimodal_params, request=None
self, prompt, sampling_params: SamplingParams, group_request_id, multimodal_params, request=None, use_id=False
):
# Record information about the request's arrival
if request is not None:
@@ -120,9 +120,9 @@ async def generate(
if self.enable_multimodal:
assert len(multimodal_params.images) <= self.args.cache_capacity, "too many images!"
await self._alloc_multimodal_resources(multimodal_params)
prompt_ids = self.tokenizer.encode(prompt, multimodal_params)
prompt_ids = prompt if use_id else self.tokenizer.encode(prompt, multimodal_params)
else:
prompt_ids = self.tokenizer.encode(prompt)
prompt_ids = prompt if use_id else self.tokenizer.encode(prompt)
prompt_tokens = len(prompt_ids)
# Monitoring
self.metric_client.histogram_observe("lightllm_request_input_length", prompt_tokens)
@@ -192,6 +192,9 @@ async def generate(
first_token_cost_ms = (time.time() - start_time) * 1000 if is_first_token else first_token_cost_ms
is_first_token = False

if use_id:
metadata["out_token_id"] = self.tokenizer.encode(out_str, add_special_tokens=False)[0]

yield sub_req_id, out_str, metadata, finish_status

# If any sub-request has finished, update the count
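On the `use_id` path, `prompt` is already a list of token ids, and the per-step `out_token_id` is reconstructed by re-encoding the detokenized piece and keeping the first id. A minimal sketch of that round trip with a hypothetical tokenizer (the model name is illustrative, and the `[0]` assumes the piece maps back to a single id):

```python
from transformers import AutoTokenizer

# Illustrative round trip behind metadata["out_token_id"]; the tokenizer is an assumption.
tokenizer = AutoTokenizer.from_pretrained("internlm/internlm2-chat-7b", trust_remote_code=True)

prompt_ids = tokenizer.encode("A chat between a user and an assistant.")  # what a client sends as "inputs"
out_str = " Sure"                                                         # one detokenized output piece
out_token_id = tokenizer.encode(out_str, add_special_tokens=False)[0]     # id forwarded in metadata
print(prompt_ids[:4], out_token_id)
```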
3 changes: 2 additions & 1 deletion lightllm/server/io_struct.py
@@ -19,9 +19,10 @@ class FinishStatus(enum.Enum):
FINISHED_STOP = 1  # finished because a STOP token was encountered
FINISHED_LENGTH = 2  # finished because the maximum length was reached
FINISHED_ABORT = 3  # finished because the request was aborted
FINISHED_DISPATCH = 4  # finished because the request was dispatched

def is_finished(self):
return 1 <= self.value <= 3
return 1 <= self.value <= 4

def is_aborted(self):
return self == FinishStatus.FINISHED_ABORT
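A quick check of the widened `is_finished` range: the new status counts as finished (so the request leaves the running batch and can be intercepted by the router) but is not an abort.

```python
from lightllm.server.io_struct import FinishStatus

# The new dispatch status is "finished" but not "aborted".
assert FinishStatus.FINISHED_DISPATCH.is_finished()
assert not FinishStatus.FINISHED_DISPATCH.is_aborted()
assert FinishStatus(4) is FinishStatus.FINISHED_DISPATCH
```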
68 changes: 63 additions & 5 deletions lightllm/server/router/manager.py
@@ -4,6 +4,8 @@
import uvloop
import asyncio
import rpyc
import aiohttp
import json

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
import zmq
@@ -70,6 +72,13 @@ def __init__(self, args, router_port, detokenization_port, model_rpc_ports, metr

self.stats_tool = Stats(not args.disable_log_stats, args.log_stats_interval)
self.metric_client = MetricClient(metric_port)

self.use_dispatcher_model = args.use_dispatcher_model
if self.use_dispatcher_model:
self.dispatch_host = args.dispatch_host
self.dispatch_port = args.dispatch_port
self.need_dispatch_reqs_que = None
self.dispatch_info = [0, 0] # [dispatch_num, out_token_count]
return

async def wait_to_model_ready(self):
Expand Down Expand Up @@ -101,6 +110,8 @@ async def wait_to_model_ready(self):
"eos_id": self.eos_id,
"beam_mode": self.args.beam_mode,
"diverse_mode": self.args.diverse_mode,
"use_dispatcher_model": self.args.use_dispatcher_model,
"dispatch_threshold": self.args.dispatch_threshold,
}
init_model_ret.append(self.model_rpcs[rank_id].init_model(kvargs))

@@ -290,7 +301,7 @@ async def _prefill_batch(self, batch: Batch):

self._update_out_status_to_batch(batch, req_to_out_status)
unfinished_req_ids, finished_req_ids = batch.mark_and_get_finished_req_and_preupdate_status()
self._send_to_detokenization_proc(batch, req_to_out_status)
await self._send_to_detokenization_proc(batch, req_to_out_status)
batch.filter_out_finished_req(unfinished_req_ids, finished_req_ids)
await self._handle_finish_req(batch, unfinished_req_ids, finished_req_ids)
self.metric_client.histogram_observe(
Expand All @@ -310,7 +321,7 @@ async def _decode_batch(self, batch: Batch):

self._update_out_status_to_batch(batch, req_to_out_status)
unfinished_req_ids, finished_req_ids = batch.mark_and_get_finished_req_and_preupdate_status()
self._send_to_detokenization_proc(batch, req_to_out_status)
await self._send_to_detokenization_proc(batch, req_to_out_status)
batch.filter_out_finished_req(unfinished_req_ids, finished_req_ids)
await self._handle_finish_req(batch, unfinished_req_ids, finished_req_ids)
self.metric_client.histogram_observe(
@@ -392,10 +403,18 @@ def _update_out_status_to_batch(self, batch: Batch, req_to_out_status):
def _can_decode(self, batch: Batch):
return batch.batch_decode_need_tokens + self.get_used_tokens() <= self.max_total_token_num

def _send_to_detokenization_proc(self, batch: Batch, req_ans):
async def _send_to_detokenization_proc(self, batch: Batch, req_ans):
batch_out = BatchTokenIdOut()
for req_id, (_, _, _, token_info_list, _, _) in req_ans.items():
req = batch.id_to_reqs[req_id]
for req_id, (_, _, _, token_info_list, finish_status_value, _) in req_ans.items():
req: Req = batch.id_to_reqs[req_id]
if FinishStatus(finish_status_value) == FinishStatus.FINISHED_DISPATCH:
await self.need_dispatch_reqs_que.put(copy.deepcopy(req))
continue

# for logger
if self.use_dispatcher_model and FinishStatus(finish_status_value).is_finished():
self.dispatch_info[1] += req.cur_output_len

for idx, (new_token_id, new_gen_metadata) in enumerate(token_info_list):
# req.finish_status is transmitted as its value rather than the enum object, which keeps the serialized payload small.
if idx == len(token_info_list) - 1:
@@ -430,6 +449,42 @@ async def loop_for_netio_req(self):
else:
assert False, f"Error Req Inf {recv_req}"

async def loop_for_dispatch(self):
while True:
req = await self.need_dispatch_reqs_que.get()
self.dispatch_info[0] += 1
self.dispatch_info[1] += req.cur_output_len
logger.info(
f"dispatch prob: {self.dispatch_info[0]/self.dispatch_info[1] if self.dispatch_info[1] != 0 else 0}"
)
req.prompt_ids = req.prompt_ids[:-1]
_max_new_tokens = req.sample_params.max_new_tokens - req.cur_output_len
req.sample_params.max_new_tokens = 1
data = {
"inputs": [int(input_id) for input_id in req.prompt_ids],
"parameters": req.sample_params.to_dict(),
}
url = f"http://{self.dispatch_host}:{self.dispatch_port}/id_generate"

async with aiohttp.ClientSession() as session:
async with session.post(
url, headers={"Content-Type": "application/json"}, data=json.dumps(data)
) as response:
out_data = await response.text()
out_data = json.loads(out_data)
new_token_id = out_data["out_token_ids"][0]
req.prompt_ids.append(new_token_id)
req.sample_params.max_new_tokens = _max_new_tokens
new_req = NormalReq(
req.request_id, copy.deepcopy(req.prompt_ids), req.sample_params, req.multimodal_params
)
self.req_queue.back_to_wait_list([new_req])
batch_out = BatchTokenIdOut()
batch_out.reqs_infs.append(
(new_req.request_id, new_token_id, {"id": new_token_id}, FinishStatus.NO_FINISH)
)
self.send_to_detokenization.send_pyobj(batch_out)

def clean_up(self):
for model_rpc in self.model_rpcs:
model_rpc.rpc_server_process.kill()
@@ -469,5 +524,8 @@ def start_router_process(args, router_port, detokenization_port, model_rpc_ports
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.create_task(router.loop_for_fwd())
if args.use_dispatcher_model:
router.need_dispatch_reqs_que = asyncio.Queue()
loop.create_task(router.loop_for_dispatch())
loop.run_until_complete(router.loop_for_netio_req())
return
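The `DispatcherModelBackend` that actually marks requests with `FINISHED_DISPATCH` is imported below but not included in this diff. A minimal sketch of the decision it presumably makes with the `dispatch_threshold` forwarded through `init_model`, where the input is the per-request dispatch score produced by `Internlm2DispatcherPostLayerInfer.token_forward` (function name and exact wiring are assumptions):

```python
import torch

def should_dispatch(dispatch_probs: torch.Tensor, dispatch_threshold: float) -> torch.Tensor:
    """Hypothetical helper: True for requests whose dispatch score crosses the
    threshold, i.e. candidates to hand over via FINISHED_DISPATCH."""
    return dispatch_probs >= dispatch_threshold

probs = torch.tensor([0.12, 0.93, 0.55])               # scores for a 3-request batch
print(should_dispatch(probs, dispatch_threshold=0.8))  # tensor([False,  True, False])
```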
@@ -1,6 +1,7 @@
from .continues_batch.impl import ContinuesBatchBackend
from .continues_batch.impl_for_return_all_prompt_logprobs import ReturnPromptLogProbBackend
from .continues_batch.impl_for_reward_model import RewardModelBackend
from .continues_batch.impl_for_dispatcher_model import DispatcherModelBackend
from .splitfuse.impl import SplitFuseBackend
from .beamsearch.impl import BeamSearchBackend
from .diverse_backend.impl import DiversehBackend
@@ -30,6 +30,7 @@
from lightllm.models.internlm.model import InternlmTpPartModel
from lightllm.models.stablelm.model import StablelmTpPartModel
from lightllm.models.internlm2.model import Internlm2TpPartModel
from lightllm.models.internlm2_dispatcher.model import Internlm2DispatcherTpPartModel
from lightllm.models.internlm2_reward.model import Internlm2RewardTpPartModel
from lightllm.models.internlm_wquant.model import InternlmTpPartModelWQuant
from lightllm.models.internlm2_wquant.model import Internlm2TpPartModelWQuant
@@ -170,6 +171,8 @@ def init_model(self, kvargs):
elif self.model_type == "internlm2":
if model_cfg["architectures"][0] == "InternLM2ForRewardModel":
self.model = Internlm2RewardTpPartModel(model_kvargs)
elif model_cfg["architectures"][0] == "MLP_Dispatcher":
self.model = Internlm2DispatcherTpPartModel(model_kvargs)
else:
self.model = Internlm2TpPartModel(model_kvargs)
elif self.model_type == "Yi":