diff --git a/.env.template b/.env.template
index 453e2c6d3..a4f03eea6 100644
--- a/.env.template
+++ b/.env.template
@@ -241,7 +241,8 @@ TONGYI_PROXY_API_KEY={your-tongyi-sk}
 #BAICHUAN_PROXY_API_SECRET={your-baichuan-sct}
 
 # Xunfei Spark
-#XUNFEI_SPARK_API_VERSION={version}
+#XUNFEI_SPARK_API_PASSWORD={your_api_password}
+#XUNFEI_SPARK_API_MODEL={version}
 #XUNFEI_SPARK_APPID={your_app_id}
 #XUNFEI_SPARK_API_KEY={your_api_key}
 #XUNFEI_SPARK_API_SECRET={your_api_secret}
diff --git a/dbgpt/model/proxy/llms/spark.py b/dbgpt/model/proxy/llms/spark.py
index adbb647eb..641ab856a 100644
--- a/dbgpt/model/proxy/llms/spark.py
+++ b/dbgpt/model/proxy/llms/spark.py
@@ -6,7 +6,7 @@
 from concurrent.futures import Executor
 from datetime import datetime
 from time import mktime
-from typing import Iterator, Optional
+from typing import Iterator, Optional, AsyncIterator
 from urllib.parse import urlencode, urlparse
 
 from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext
@@ -14,8 +14,6 @@
 from dbgpt.model.proxy.base import ProxyLLMClient
 from dbgpt.model.proxy.llms.proxy_model import ProxyModel
 
-SPARK_DEFAULT_API_VERSION = "v3"
-
 
 def getlength(text):
     length = 0
@@ -49,7 +47,7 @@ def spark_generate_stream(
         max_new_tokens=params.get("max_new_tokens"),
         stop=params.get("stop"),
     )
-    for r in client.sync_generate_stream(request):
+    for r in client.generate_stream(request):
         yield r
 
 
@@ -73,121 +71,60 @@ def get_response(request_url, data):
                 raise e
     yield result
 
-
-class SparkAPI:
-    def __init__(
-        self, appid: str, api_key: str, api_secret: str, spark_url: str
-    ) -> None:
-        self.appid = appid
-        self.api_key = api_key
-        self.api_secret = api_secret
-        self.host = urlparse(spark_url).netloc
-        self.path = urlparse(spark_url).path
-
-        self.spark_url = spark_url
-
-    def gen_url(self):
-        from wsgiref.handlers import format_date_time
-
-        # 生成RFC1123格式的时间戳
-        now = datetime.now()
-        date = format_date_time(mktime(now.timetuple()))
-
-        # 拼接字符串
-        signature_origin = "host: " + self.host + "\n"
-        signature_origin += "date: " + date + "\n"
-        signature_origin += "GET " + self.path + " HTTP/1.1"
-
-        # 进行hmac-sha256进行加密
-        signature_sha = hmac.new(
-            self.api_secret.encode("utf-8"),
-            signature_origin.encode("utf-8"),
-            digestmod=hashlib.sha256,
-        ).digest()
-
-        signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding="utf-8")
-
-        authorization_origin = f'api_key="{self.api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
-
-        authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(
-            encoding="utf-8"
-        )
-
-        # 将请求的鉴权参数组合为字典
-        v = {"authorization": authorization, "date": date, "host": self.host}
-        # 拼接鉴权参数,生成url
-        url = self.spark_url + "?" + urlencode(v)
-        # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
-        return url
-
+def extract_content(line: str):
+    if not line.strip():
+        return line
+    if line.startswith('data: '):
+        json_str = line[len('data: '):]
+    else:
+        raise ValueError(
+            "Error line content "
+        )
+    try:
+        data = json.loads(json_str)
+        if data == '[DONE]':
+            return ''
+
+        choices = data.get('choices', [])
+        if choices and isinstance(choices, list):
+            delta = choices[0].get('delta', {})
+            content = delta.get('content', '')
+            return content
+        else:
+            raise ValueError(
+                "Error line content "
+            )
+    except json.JSONDecodeError:
+        return ''
+
 
 class SparkLLMClient(ProxyLLMClient):
     def __init__(
         self,
         model: Optional[str] = None,
-        app_id: Optional[str] = None,
-        api_key: Optional[str] = None,
-        api_secret: Optional[str] = None,
-        api_base: Optional[str] = None,
-        api_domain: Optional[str] = None,
-        model_version: Optional[str] = None,
         model_alias: Optional[str] = "spark_proxyllm",
         context_length: Optional[int] = 4096,
         executor: Optional[Executor] = None,
     ):
         """
-        Tips: 星火大模型API当前有Lite、Pro、Pro-128K、Max、Max-32K和4.0 Ultra六个版本,各版本独立计量tokens。
-        传输协议 :ws(s),为提高安全性,强烈推荐wss
-
-        Spark4.0 Ultra 请求地址,对应的domain参数为4.0Ultra:
-        wss://spark-api.xf-yun.com/v4.0/chat
-
+        星火大模型API当前有Lite、Pro、Pro-128K、Max、Max-32K和4.0 Ultra六个版本
+        Spark4.0 Ultra 请求地址,对应的domain参数为4.0Ultra
         Spark Max-32K请求地址,对应的domain参数为max-32k
-        wss://spark-api.xf-yun.com/chat/max-32k
-
         Spark Max请求地址,对应的domain参数为generalv3.5
-        wss://spark-api.xf-yun.com/v3.5/chat
-
         Spark Pro-128K请求地址,对应的domain参数为pro-128k:
-        wss://spark-api.xf-yun.com/chat/pro-128k
-
         Spark Pro请求地址,对应的domain参数为generalv3:
-        wss://spark-api.xf-yun.com/v3.1/chat
-
         Spark Lite请求地址,对应的domain参数为lite:
-        wss://spark-api.xf-yun.com/v1.1/chat
+        https://www.xfyun.cn/doc/spark/HTTP%E8%B0%83%E7%94%A8%E6%96%87%E6%A1%A3.html#_3-%E8%AF%B7%E6%B1%82%E8%AF%B4%E6%98%8E
         """
-        if not model_version:
-            model_version = model or os.getenv("XUNFEI_SPARK_API_VERSION")
-        if not api_base:
-            if model_version == SPARK_DEFAULT_API_VERSION:
-                api_base = "ws://spark-api.xf-yun.com/v3.1/chat"
-                domain = "generalv3"
-            elif model_version == "v4.0":
-                api_base = "ws://spark-api.xf-yun.com/v4.0/chat"
-                domain = "4.0Ultra"
-            elif model_version == "v3.5":
-                api_base = "ws://spark-api.xf-yun.com/v3.5/chat"
-                domain = "generalv3.5"
-            else:
-                api_base = "ws://spark-api.xf-yun.com/v1.1/chat"
-                domain = "lite"
-        if not api_domain:
-            api_domain = domain
-        self._model = model
-        self._model_version = model_version
-        self._api_base = api_base
-        self._domain = api_domain
-        self._app_id = app_id or os.getenv("XUNFEI_SPARK_APPID")
-        self._api_secret = api_secret or os.getenv("XUNFEI_SPARK_API_SECRET")
-        self._api_key = api_key or os.getenv("XUNFEI_SPARK_API_KEY")
-
-        if not self._app_id:
-            raise ValueError("app_id can't be empty")
-        if not self._api_key:
-            raise ValueError("api_key can't be empty")
-        if not self._api_secret:
-            raise ValueError("api_secret can't be empty")
+        self._model = model or os.getenv("XUNFEI_SPARK_API_MODEL")
+        self._api_base = os.getenv("PROXY_SERVER_URL")
+        self._api_password = os.getenv("XUNFEI_SPARK_API_PASSWORD")
+        if not self._model:
+            raise ValueError("model can't be empty")
+        if not self._api_base:
+            raise ValueError("api_base can't be empty")
+        if not self._api_password:
+            raise ValueError("api_password can't be empty")
 
         super().__init__(
             model_names=[model, model_alias],
@@ -203,10 +140,6 @@ def new_client(
     ) -> "SparkLLMClient":
         return cls(
             model=model_params.proxyllm_backend,
-            app_id=model_params.proxy_api_app_id,
-            api_key=model_params.proxy_api_key,
-            api_secret=model_params.proxy_api_secret,
-            api_base=model_params.proxy_api_base,
             model_alias=model_params.model_name,
             context_length=model_params.max_context_size,
             executor=default_executor,
@@ -216,35 +149,46 @@ def new_client(
     def default_model(self) -> str:
         return self._model
 
-    def sync_generate_stream(
+
+    def generate_stream(
         self,
         request: ModelRequest,
         message_converter: Optional[MessageConverter] = None,
-    ) -> Iterator[ModelOutput]:
+    ) -> AsyncIterator[ModelOutput]:
+        """
+        reference:
+        https://www.xfyun.cn/doc/spark/HTTP%E8%B0%83%E7%94%A8%E6%96%87%E6%A1%A3.html#_3-%E8%AF%B7%E6%B1%82%E8%AF%B4%E6%98%8E
+        """
         request = self.local_covert_message(request, message_converter)
         messages = request.to_common_messages(support_system_role=False)
-        request_id = request.context.request_id or "1"
+        try:
+            import requests
+        except ImportError as e:
+            raise ValueError(
+                "Could not import python package: requests "
+                "Please install requests by command `pip install requests`"
+            ) from e
+
         data = {
-            "header": {"app_id": self._app_id, "uid": request_id},
-            "parameter": {
-                "chat": {
-                    "domain": self._domain,
-                    "random_threshold": 0.5,
-                    "max_tokens": request.max_new_tokens,
-                    "auditing": "default",
-                    "temperature": request.temperature,
-                }
-            },
-            "payload": {"message": {"text": messages}},
+            "model": self._model,  # 指定请求的模型
+            "messages": messages,
+            "temperature" : request.temperature,
+            "stream": True
         }
-
-        spark_api = SparkAPI(
-            self._app_id, self._api_key, self._api_secret, self._api_base
-        )
-        request_url = spark_api.gen_url()
+        header = {
+            "Authorization": f"Bearer {self._api_password}"  # 注意此处替换自己的APIPassword
+        }
+        response = requests.post(self._api_base, headers=header, json=data, stream=True)
+
+        # 流式响应解析示例
+        response.encoding = "utf-8"
         try:
-            for text in get_response(request_url, data):
-                yield ModelOutput(text=text, error_code=0)
+            content = ""
+            #data: {"code":0,"message":"Success","sid":"cha000bf865@dx19307263c06b894532","id":"cha000bf865@dx19307263c06b894532","created":1730991766,"choices":[{"delta":{"role":"assistant","content":"你好"},"index":0}]}
+            #data: [DONE]
+            for line in response.iter_lines(decode_unicode=True):
+                print("llm out:", line)
+                content=content + extract_content(line)
+                yield ModelOutput(text=content, error_code=0)
         except Exception as e:
            return ModelOutput(
                 text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",