diff --git a/lightllm/models/internlm2_dispatcher/__init__.py b/lightllm/models/internlm2_dispatcher/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lightllm/models/internlm2_dispatcher/layer_infer/__init__.py b/lightllm/models/internlm2_dispatcher/layer_infer/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lightllm/models/internlm2_dispatcher/layer_infer/post_layer_infer.py b/lightllm/models/internlm2_dispatcher/layer_infer/post_layer_infer.py new file mode 100644 index 00000000..ea0363b8 --- /dev/null +++ b/lightllm/models/internlm2_dispatcher/layer_infer/post_layer_infer.py @@ -0,0 +1,53 @@ +import torch +import torch.distributed as dist +import torch.nn.functional as F +import numpy as np + +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer +from lightllm.models.internlm2_dispatcher.layer_weights.pre_and_post_layer_weight import ( + Internlm2DispatcherPreAndPostLayerWeight, +) +from einops import rearrange + + +class Internlm2DispatcherPostLayerInfer(LlamaPostLayerInfer): + def cls_forward( + self, last_input, infer_state: LlamaInferStateInfo, layer_weight: Internlm2DispatcherPreAndPostLayerWeight + ): + cls0_out = F.gelu(torch.mm(layer_weight.cls0_weight_, last_input) + layer_weight.cls0_bias_[:, None]) + cls1_out = F.gelu(torch.mm(layer_weight.cls2_weight_, cls0_out) + layer_weight.cls2_bias_[:, None]) + cls2_out = torch.mm(layer_weight.cls4_weight_, cls1_out) + layer_weight.cls4_bias_[:, None] + return cls2_out + + def token_forward( + self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight: Internlm2DispatcherPreAndPostLayerWeight + ): + last_input, token_num = self._slice_get_last_input(input_embdings, infer_state) + input_embdings_dtype = input_embdings.dtype + input_embdings = None + last_input = self._norm(last_input, infer_state, layer_weight) + last_input = rearrange(last_input, "batch embed_dim -> embed_dim batch").contiguous().reshape(-1, token_num) + logic_batch = torch.mm(layer_weight.lm_head_weight_, last_input) + cls_out = self.cls_forward(last_input, infer_state, layer_weight).permute(1, 0) + probs = torch.softmax(cls_out, dim=-1)[:, 1] + + last_input = None + if self.world_size_ == 1: + gather_data = logic_batch + else: + gather_data = torch.empty( + (self.vocab_size_, token_num), device=logic_batch.device, dtype=input_embdings_dtype + ) + split_indexes = np.linspace(0, self.vocab_size_, self.world_size_ + 1, dtype=np.int64) + dist.all_gather( + [gather_data[split_indexes[i] : split_indexes[i + 1], :] for i in range(self.world_size_)], + logic_batch, + group=None, + async_op=False, + ) + logic_batch = None + + ans_logics = gather_data.permute(1, 0).float() + gather_data = None + return ans_logics, probs diff --git a/lightllm/models/internlm2_dispatcher/layer_weights/__init__.py b/lightllm/models/internlm2_dispatcher/layer_weights/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lightllm/models/internlm2_dispatcher/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/internlm2_dispatcher/layer_weights/pre_and_post_layer_weight.py new file mode 100644 index 00000000..5d6b0078 --- /dev/null +++ b/lightllm/models/internlm2_dispatcher/layer_weights/pre_and_post_layer_weight.py @@ -0,0 +1,37 @@ +import torch +import numpy as np +from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight + + +class 
Internlm2DispatcherPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): + def __init__(self, tp_rank, world_size, data_type, network_config, mode): + super().__init__(tp_rank, world_size, data_type, network_config, mode) + return + + def load_hf_weights(self, weights): + vob_size = self.network_config_["vocab_size"] + split_indexes = np.linspace(0, vob_size, self.world_size_ + 1, dtype=np.int64) + split_start = split_indexes[self.tp_rank_] + split_end = split_indexes[self.tp_rank_ + 1] + if "model.tok_embeddings.weight" in weights: + self.wte_weight_ = self._cuda(weights["model.tok_embeddings.weight"][split_start:split_end, :]) + if "output.weight" in weights: + self.lm_head_weight_ = self._cuda(weights["output.weight"][split_start:split_end, :]) + if "model.norm.weight" in weights: + self.final_norm_weight_ = self._cuda(weights["model.norm.weight"]) + + # load classifier weight + if "classifier.0.bias" in weights: + self.cls0_bias_ = self._cuda(weights["classifier.0.bias"]) + if "classifier.0.weight" in weights: + self.cls0_weight_ = self._cuda(weights["classifier.0.weight"]) + if "classifier.2.bias" in weights: + self.cls2_bias_ = self._cuda(weights["classifier.2.bias"]) + if "classifier.2.weight" in weights: + self.cls2_weight_ = self._cuda(weights["classifier.2.weight"]) + if "classifier.4.bias" in weights: + self.cls4_bias_ = self._cuda(weights["classifier.4.bias"]) + if "classifier.4.weight" in weights: + self.cls4_weight_ = self._cuda(weights["classifier.4.weight"]) + + return diff --git a/lightllm/models/internlm2_dispatcher/model.py b/lightllm/models/internlm2_dispatcher/model.py new file mode 100644 index 00000000..d1862a02 --- /dev/null +++ b/lightllm/models/internlm2_dispatcher/model.py @@ -0,0 +1,18 @@ +import os +import json +import torch +from lightllm.models.internlm2_dispatcher.layer_weights.pre_and_post_layer_weight import ( + Internlm2DispatcherPreAndPostLayerWeight, +) +from lightllm.models.internlm2.model import Internlm2TpPartModel +from lightllm.models.internlm2_dispatcher.layer_infer.post_layer_infer import Internlm2DispatcherPostLayerInfer + + +class Internlm2DispatcherTpPartModel(Internlm2TpPartModel): + # weight class + pre_and_post_weight_class = Internlm2DispatcherPreAndPostLayerWeight + + post_layer_infer_class = Internlm2DispatcherPostLayerInfer + + def __init__(self, kvargs): + super().__init__(kvargs) diff --git a/lightllm/server/api_lightllm.py b/lightllm/server/api_lightllm.py index 4ed21711..7a0b5e54 100644 --- a/lightllm/server/api_lightllm.py +++ b/lightllm/server/api_lightllm.py @@ -31,7 +31,7 @@ async def lightllm_get_score(request: Request, g_id_gen, httpserver_manager) -> return Response(content=json.dumps(ret, ensure_ascii=False).encode("utf-8")) -async def lightllm_generate(request: Request, g_id_gen, httpserver_manager) -> Response: +async def lightllm_generate(request: Request, g_id_gen, httpserver_manager, use_id=False) -> Response: request_dict = await request.json() prompt = request_dict.pop("inputs") @@ -45,13 +45,14 @@ async def lightllm_generate(request: Request, g_id_gen, httpserver_manager) -> R group_request_id = g_id_gen.generate_id() results_generator = httpserver_manager.generate( - prompt, sampling_params, group_request_id, multimodal_params, request=request + prompt, sampling_params, group_request_id, multimodal_params, request=request, use_id=use_id ) # Non-streaming case final_output_dict = collections.defaultdict(list) count_output_tokens_dict = collections.defaultdict(lambda: 0) tokens_dict = collections.defaultdict(list) +
out_token_ids_dict = collections.defaultdict(list) finish_reason_dict = {} prompt_logprobs = None prompt_tokens = 0 @@ -78,12 +79,17 @@ async def lightllm_generate(request: Request, g_id_gen, httpserver_manager) -> R if finish_status.is_finished(): finish_reason_dict[sub_req_id] = finish_status + + if use_id: + out_token_ids_dict[sub_req_id].append(metadata["out_token_id"]) + n = sampling_params.n sub_ids = list(final_output_dict.keys())[:n] final_output_list = ["".join(final_output_dict[sub_id]) for sub_id in sub_ids] count_output_tokens_list = [count_output_tokens_dict[sub_id] for sub_id in sub_ids] finish_reson_list = [finish_reason_dict[sub_id].get_finish_reason() for sub_id in sub_ids] tokens_list = [tokens_dict[sub_id] for sub_id in sub_ids] + out_token_ids_list = [out_token_ids_dict[sub_id] for sub_id in sub_ids] only_one = len(sub_ids) == 1 ret_data_format = lambda data_list: data_list[0] if only_one else data_list @@ -100,6 +106,8 @@ async def lightllm_generate(request: Request, g_id_gen, httpserver_manager) -> R ret["prompt_token_ids"] = prompt_token_ids if prompt_logprobs is not None: ret["prompt_logprobs"] = prompt_logprobs + if use_id: + ret["out_token_ids"] = ret_data_format(out_token_ids_list) return Response(content=json.dumps(ret, ensure_ascii=False).encode("utf-8")) diff --git a/lightllm/server/api_server.py b/lightllm/server/api_server.py index a1c6a126..40cae1bd 100755 --- a/lightllm/server/api_server.py +++ b/lightllm/server/api_server.py @@ -162,6 +162,15 @@ async def get_score(request: Request) -> Response: return create_error_response(HTTPStatus.EXPECTATION_FAILED, str(e)) +@app.post("/id_generate") +async def id_generate(request: Request) -> Response: + first_set_handle_loop() + try: + return await g_generate_func(request, g_id_gen, httpserver_manager, use_id=True) + except Exception as e: + return create_error_response(HTTPStatus.EXPECTATION_FAILED, str(e)) + + @app.post("/") async def compat_generate(request: Request) -> Response: request_dict = await request.json() @@ -452,6 +461,10 @@ def make_argument_parser() -> argparse.ArgumentParser: "--enable_monitor_auth", action="store_true", help="Whether to open authentication for push_gateway" ) + parser.add_argument("--use_dispatcher_model", action="store_true") + parser.add_argument("--dispatch_threshold", type=float, default=0.8) + parser.add_argument("--dispatch_host", type=str, default="127.0.0.1") + parser.add_argument("--dispatch_port", type=int, default=12580) return parser diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 09d3a102..c460ac49 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -95,7 +95,7 @@ def tokens(self, prompt): return len(prompt_ids) async def generate( - self, prompt, sampling_params: SamplingParams, group_request_id, multimodal_params, request=None + self, prompt, sampling_params: SamplingParams, group_request_id, multimodal_params, request=None, use_id=False ): # 记录请求到达的相关信息 if request is not None: @@ -120,9 +120,9 @@ async def generate( if self.enable_multimodal: assert len(multimodal_params.images) <= self.args.cache_capacity, "too many images!" 
await self._alloc_multimodal_resources(multimodal_params) - prompt_ids = self.tokenizer.encode(prompt, multimodal_params) + prompt_ids = prompt if use_id else self.tokenizer.encode(prompt, multimodal_params) else: - prompt_ids = self.tokenizer.encode(prompt) + prompt_ids = prompt if use_id else self.tokenizer.encode(prompt) prompt_tokens = len(prompt_ids) # 监控 self.metric_client.histogram_observe("lightllm_request_input_length", prompt_tokens) @@ -192,6 +192,9 @@ async def generate( first_token_cost_ms = (time.time() - start_time) * 1000 if is_first_token else first_token_cost_ms is_first_token = False + if use_id: + metadata["out_token_id"] = self.tokenizer.encode(out_str, add_special_tokens=False)[0] + yield sub_req_id, out_str, metadata, finish_status # 如果有子请求完成,就更新计数 diff --git a/lightllm/server/io_struct.py b/lightllm/server/io_struct.py index 92f2d01d..c14041c7 100644 --- a/lightllm/server/io_struct.py +++ b/lightllm/server/io_struct.py @@ -19,9 +19,10 @@ class FinishStatus(enum.Enum): FINISHED_STOP = 1 # 因为遇到了STOP token 而结束 FINISHED_LENGTH = 2 # 因为长度达到了最大长度而结束 FINISHED_ABORT = 3 # 因为请求被中止而结束 + FINISHED_DISPATCH = 4 # 因为请求被分发而结束 def is_finished(self): - return 1 <= self.value <= 3 + return 1 <= self.value <= 4 def is_aborted(self): return self == FinishStatus.FINISHED_ABORT diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index b1f87e69..91fb1ea0 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -4,6 +4,8 @@ import uvloop import asyncio import rpyc +import aiohttp +import json asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) import zmq @@ -70,6 +72,13 @@ def __init__(self, args, router_port, detokenization_port, model_rpc_ports, metr self.stats_tool = Stats(not args.disable_log_stats, args.log_stats_interval) self.metric_client = MetricClient(metric_port) + + self.use_dispatcher_model = args.use_dispatcher_model + if self.use_dispatcher_model: + self.dispatch_host = args.dispatch_host + self.dispatch_port = args.dispatch_port + self.need_dispatch_reqs_que = None + self.dispatch_info = [0, 0] # [dispatch_num, out_token_count] return async def wait_to_model_ready(self): @@ -101,6 +110,8 @@ async def wait_to_model_ready(self): "eos_id": self.eos_id, "beam_mode": self.args.beam_mode, "diverse_mode": self.args.diverse_mode, + "use_dispatcher_model": self.args.use_dispatcher_model, + "dispatch_threshold": self.args.dispatch_threshold, } init_model_ret.append(self.model_rpcs[rank_id].init_model(kvargs)) @@ -290,7 +301,7 @@ async def _prefill_batch(self, batch: Batch): self._update_out_status_to_batch(batch, req_to_out_status) unfinished_req_ids, finished_req_ids = batch.mark_and_get_finished_req_and_preupdate_status() - self._send_to_detokenization_proc(batch, req_to_out_status) + await self._send_to_detokenization_proc(batch, req_to_out_status) batch.filter_out_finished_req(unfinished_req_ids, finished_req_ids) await self._handle_finish_req(batch, unfinished_req_ids, finished_req_ids) self.metric_client.histogram_observe( @@ -310,7 +321,7 @@ async def _decode_batch(self, batch: Batch): self._update_out_status_to_batch(batch, req_to_out_status) unfinished_req_ids, finished_req_ids = batch.mark_and_get_finished_req_and_preupdate_status() - self._send_to_detokenization_proc(batch, req_to_out_status) + await self._send_to_detokenization_proc(batch, req_to_out_status) batch.filter_out_finished_req(unfinished_req_ids, finished_req_ids) await self._handle_finish_req(batch, unfinished_req_ids, finished_req_ids) 
self.metric_client.histogram_observe( @@ -392,10 +403,18 @@ def _update_out_status_to_batch(self, batch: Batch, req_to_out_status): def _can_decode(self, batch: Batch): return batch.batch_decode_need_tokens + self.get_used_tokens() <= self.max_total_token_num - def _send_to_detokenization_proc(self, batch: Batch, req_ans): + async def _send_to_detokenization_proc(self, batch: Batch, req_ans): batch_out = BatchTokenIdOut() - for req_id, (_, _, _, token_info_list, _, _) in req_ans.items(): - req = batch.id_to_reqs[req_id] + for req_id, (_, _, _, token_info_list, finish_status_value, _) in req_ans.items(): + req: Req = batch.id_to_reqs[req_id] + if FinishStatus(finish_status_value) == FinishStatus.FINISHED_DISPATCH: + await self.need_dispatch_reqs_que.put(copy.deepcopy(req)) + continue + + # for logger + if self.use_dispatcher_model and FinishStatus(finish_status_value).is_finished(): + self.dispatch_info[1] += req.cur_output_len + for idx, (new_token_id, new_gen_metadata) in enumerate(token_info_list): # req.finish_status 传输 value值 不传送对象,可以减少序列化对象的大小。 if idx == len(token_info_list) - 1: @@ -430,6 +449,42 @@ async def loop_for_netio_req(self): else: assert False, f"Error Req Inf {recv_req}" + async def loop_for_dispatch(self): + while True: + req = await self.need_dispatch_reqs_que.get() + self.dispatch_info[0] += 1 + self.dispatch_info[1] += req.cur_output_len + logger.info( + f"dispatch prob: {self.dispatch_info[0]/self.dispatch_info[1] if self.dispatch_info[1] != 0 else 0}" + ) + req.prompt_ids = req.prompt_ids[:-1] + _max_new_tokens = req.sample_params.max_new_tokens - req.cur_output_len + req.sample_params.max_new_tokens = 1 + data = { + "inputs": [int(input_id) for input_id in req.prompt_ids], + "parameters": req.sample_params.to_dict(), + } + url = f"http://{self.dispatch_host}:{self.dispatch_port}/id_generate" + + async with aiohttp.ClientSession() as session: + async with session.post( + url, headers={"Content-Type": "application/json"}, data=json.dumps(data) + ) as response: + out_data = await response.text() + out_data = json.loads(out_data) + new_token_id = out_data["out_token_ids"][0] + req.prompt_ids.append(new_token_id) + req.sample_params.max_new_tokens = _max_new_tokens + new_req = NormalReq( + req.request_id, copy.deepcopy(req.prompt_ids), req.sample_params, req.multimodal_params + ) + self.req_queue.back_to_wait_list([new_req]) + batch_out = BatchTokenIdOut() + batch_out.reqs_infs.append( + (new_req.request_id, new_token_id, {"id": new_token_id}, FinishStatus.NO_FINISH) + ) + self.send_to_detokenization.send_pyobj(batch_out) + def clean_up(self): for model_rpc in self.model_rpcs: model_rpc.rpc_server_process.kill() @@ -469,5 +524,8 @@ def start_router_process(args, router_port, detokenization_port, model_rpc_ports loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) loop.create_task(router.loop_for_fwd()) + if args.use_dispatcher_model: + router.need_dispatch_reqs_que = asyncio.Queue() + loop.create_task(router.loop_for_dispatch()) loop.run_until_complete(router.loop_for_netio_req()) return diff --git a/lightllm/server/router/model_infer/mode_backend/__init__.py b/lightllm/server/router/model_infer/mode_backend/__init__.py index d0be3d78..47865901 100644 --- a/lightllm/server/router/model_infer/mode_backend/__init__.py +++ b/lightllm/server/router/model_infer/mode_backend/__init__.py @@ -1,6 +1,7 @@ from .continues_batch.impl import ContinuesBatchBackend from .continues_batch.impl_for_return_all_prompt_logprobs import ReturnPromptLogProbBackend from 
.continues_batch.impl_for_reward_model import RewardModelBackend +from .continues_batch.impl_for_dispatcher_model import DispatcherModelBackend from .splitfuse.impl import SplitFuseBackend from .beamsearch.impl import BeamSearchBackend from .diverse_backend.impl import DiversehBackend diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 902b8adf..90278ae8 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -30,6 +30,7 @@ from lightllm.models.internlm.model import InternlmTpPartModel from lightllm.models.stablelm.model import StablelmTpPartModel from lightllm.models.internlm2.model import Internlm2TpPartModel +from lightllm.models.internlm2_dispatcher.model import Internlm2DispatcherTpPartModel from lightllm.models.internlm2_reward.model import Internlm2RewardTpPartModel from lightllm.models.internlm_wquant.model import InternlmTpPartModelWQuant from lightllm.models.internlm2_wquant.model import Internlm2TpPartModelWQuant @@ -170,6 +171,8 @@ def init_model(self, kvargs): elif self.model_type == "internlm2": if model_cfg["architectures"][0] == "InternLM2ForRewardModel": self.model = Internlm2RewardTpPartModel(model_kvargs) + elif model_cfg["architectures"][0] == "MLP_Dispatcher": + self.model = Internlm2DispatcherTpPartModel(model_kvargs) else: self.model = Internlm2TpPartModel(model_kvargs) elif self.model_type == "Yi": diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_dispatcher_model.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_dispatcher_model.py new file mode 100644 index 00000000..648b80c5 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_dispatcher_model.py @@ -0,0 +1,58 @@ +import torch +from .impl import ContinuesBatchBackend +from lightllm.server.router.model_infer.infer_batch import InferBatch, InferReq, InferSamplingParams, requests_mapping +from .pre_process import prepare_prefill_inputs, prepare_decode_inputs +from lightllm.server.io_struct import ReqRunStatus, FinishStatus +from .post_process import sample + + +class DispatcherModelBackend(ContinuesBatchBackend): + def __init__(self, dispatch_threshold=0.5) -> None: + super().__init__() + self.dispatch_threshold = dispatch_threshold + + def forward(self, batch_id, is_prefill): + # special code for the dispatcher model: the forward pass also returns per-token dispatch probabilities + output_dict = {} + batch: InferBatch = self.cache.pop(batch_id) + if is_prefill: + kwargs, run_reqs = prepare_prefill_inputs(batch, self.radix_cache, self.is_multimodal) + else: + kwargs, run_reqs = prepare_decode_inputs(batch, self.radix_cache) + + logits, dispatcher_probs = self.model.forward(**kwargs) + + next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id) + next_token_ids = next_token_ids.detach().cpu().numpy() + next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() + dispatcher_probs = dispatcher_probs.detach().cpu().numpy() + + for req_obj, next_token_id, next_token_logprob, dispatcher_prob in zip( + run_reqs, next_token_ids, next_token_logprobs, dispatcher_probs + ): + # prefill and decode are handled the same way + req_obj: InferReq = req_obj + req_obj.cur_kv_len = len(req_obj.input_token_ids) + req_obj.input_token_ids.append(next_token_id) + req_obj.out_token_id_count[next_token_id] += 1 + req_obj.update_finish_status(self.eos_id) + print(f"dispatcher_prob: 
{dispatcher_prob}") + if not req_obj.finish_status.is_finished() and dispatcher_prob > self.dispatch_threshold: + req_obj.finish_status = FinishStatus.FINISHED_DISPATCH + + metadata = { + "id": int(next_token_id), + "logprob": float(next_token_logprob), + "dispatcher_prob": float(dispatcher_prob), + } + output_dict[req_obj.r_id] = ( + req_obj.req_status, + req_obj.cur_kv_len, + req_obj.get_output_len(), + [(int(next_token_id), metadata)], + req_obj.finish_status.value, # 转化为整数,避免传送大对象, + None, + ) # 请求状态, 当前占用的kv的长度, 当前输出token的数量, 输出的token的id和元信息列表, 是否推理结束的状态, 额外保留参数 + + self.cache[batch.batch_id] = batch + return output_dict diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py index 68009c5c..8253c46b 100644 --- a/lightllm/server/router/model_infer/model_rpc.py +++ b/lightllm/server/router/model_infer/model_rpc.py @@ -10,6 +10,7 @@ BeamSearchBackend, DiversehBackend, RewardModelBackend, + DispatcherModelBackend, ) from lightllm.utils.log_utils import init_logger @@ -29,8 +30,12 @@ def exposed_init_model(self, kvargs): beam_mode = kvargs.get("beam_mode", False) diverse_mode = kvargs.get("diverse_mode", False) # use_dynamic_prompt_cache = kvargs.get("use_dynamic_prompt_cache", False) + use_dispatcher_model = kvargs.get("use_dispatcher_model", False) - if use_reward_model: + if use_dispatcher_model: + dispatch_threshold = kvargs.get("dispatch_threshold", 0.5) + self.backend = DispatcherModelBackend(dispatch_threshold) + elif use_reward_model: self.backend = RewardModelBackend() elif is_splitfuse_mode: self.backend = SplitFuseBackend()