
Commit

feat: xunfei spark api use http instead of ws
hiyizi committed Nov 7, 2024
1 parent cfce1ac commit 77f9387
Showing 2 changed files with 72 additions and 127 deletions.
3 changes: 2 additions & 1 deletion .env.template
@@ -241,7 +241,8 @@ TONGYI_PROXY_API_KEY={your-tongyi-sk}
#BAICHUAN_PROXY_API_SECRET={your-baichuan-sct}

# Xunfei Spark
#XUNFEI_SPARK_API_VERSION={version}
#XUNFEI_SPARK_API_PASSWORD={your_api_password}
#XUNFEI_SPARK_API_MODEL={version}
#XUNFEI_SPARK_APPID={your_app_id}
#XUNFEI_SPARK_API_KEY={your_api_key}
#XUNFEI_SPARK_API_SECRET={your_api_secret}
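As a minimal sketch of how the new settings are consumed, the snippet below mirrors the os.getenv calls in the spark.py changes further down. The variable names come from the diff; the validation loop and comments are illustrative only and are not part of the committed code, and PROXY_SERVER_URL is assumed to point at the Spark HTTP chat endpoint configured elsewhere in .env.template.

import os

# Variable names are taken from the updated .env.template and spark.py;
# everything else here is illustrative.
model = os.getenv("XUNFEI_SPARK_API_MODEL")            # e.g. "lite", "generalv3.5", "4.0Ultra"
api_password = os.getenv("XUNFEI_SPARK_API_PASSWORD")  # the Spark APIPassword used as a Bearer token
api_base = os.getenv("PROXY_SERVER_URL")               # assumed: the Spark HTTP chat endpoint

for name, value in (
    ("XUNFEI_SPARK_API_MODEL", model),
    ("XUNFEI_SPARK_API_PASSWORD", api_password),
    ("PROXY_SERVER_URL", api_base),
):
    if not value:
        raise ValueError(f"{name} can't be empty")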
196 changes: 70 additions & 126 deletions dbgpt/model/proxy/llms/spark.py
@@ -6,16 +6,14 @@
from concurrent.futures import Executor
from datetime import datetime
from time import mktime
from typing import Iterator, Optional
from typing import Iterator, Optional, AsyncIterator
from urllib.parse import urlencode, urlparse

from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel

SPARK_DEFAULT_API_VERSION = "v3"


def getlength(text):
length = 0
@@ -49,7 +49,7 @@ def spark_generate_stream(
max_new_tokens=params.get("max_new_tokens"),
stop=params.get("stop"),
)
for r in client.sync_generate_stream(request):
for r in client.generate_stream(request):
yield r


@@ -73,121 +71,60 @@ def get_response(request_url, data):
raise e
yield result


class SparkAPI:
def __init__(
self, appid: str, api_key: str, api_secret: str, spark_url: str
) -> None:
self.appid = appid
self.api_key = api_key
self.api_secret = api_secret
self.host = urlparse(spark_url).netloc
self.path = urlparse(spark_url).path

self.spark_url = spark_url

def gen_url(self):
from wsgiref.handlers import format_date_time

# Generate an RFC 1123 formatted timestamp
now = datetime.now()
date = format_date_time(mktime(now.timetuple()))

# Build the signature string
signature_origin = "host: " + self.host + "\n"
signature_origin += "date: " + date + "\n"
signature_origin += "GET " + self.path + " HTTP/1.1"

# Sign with HMAC-SHA256
signature_sha = hmac.new(
self.api_secret.encode("utf-8"),
signature_origin.encode("utf-8"),
digestmod=hashlib.sha256,
).digest()

signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding="utf-8")

authorization_origin = f'api_key="{self.api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'

authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(
encoding="utf-8"
def extract_content(line: str):
if not line.strip():
return line
if line.startswith('data: '):
json_str = line[len('data: '):]
else:
raise ValueError(
"Error line content "
)

# Combine the authentication parameters into a dict
v = {"authorization": authorization, "date": date, "host": self.host}
# Append the authentication parameters to build the url
url = self.spark_url + "?" + urlencode(v)
# This is the url used to establish the connection; when following this demo, uncomment the print above to check that the url generated with the same parameters matches the one from your own code
return url

try:
data = json.loads(json_str)
if data == '[DONE]':
return ''

choices = data.get('choices', [])
if choices and isinstance(choices, list):
delta = choices[0].get('delta', {})
content = delta.get('content', '')
return content
else:
raise ValueError(
"Error line content "
)
except json.JSONDecodeError:
return ''

class SparkLLMClient(ProxyLLMClient):
def __init__(
self,
model: Optional[str] = None,
app_id: Optional[str] = None,
api_key: Optional[str] = None,
api_secret: Optional[str] = None,
api_base: Optional[str] = None,
api_domain: Optional[str] = None,
model_version: Optional[str] = None,
model_alias: Optional[str] = "spark_proxyllm",
context_length: Optional[int] = 4096,
executor: Optional[Executor] = None,
):
"""
Tips: The Spark LLM API currently has six versions: Lite, Pro, Pro-128K, Max, Max-32K and 4.0 Ultra; each version meters tokens separately.
Transport protocol: ws(s); for better security, wss is strongly recommended
Spark 4.0 Ultra request address, with domain parameter 4.0Ultra:
wss://spark-api.xf-yun.com/v4.0/chat
The Spark LLM API currently has six versions: Lite, Pro, Pro-128K, Max, Max-32K and 4.0 Ultra
Spark 4.0 Ultra request address, with domain parameter 4.0Ultra
Spark Max-32K request address, with domain parameter max-32k
wss://spark-api.xf-yun.com/chat/max-32k
Spark Max request address, with domain parameter generalv3.5
wss://spark-api.xf-yun.com/v3.5/chat
Spark Pro-128K request address, with domain parameter pro-128k:
wss://spark-api.xf-yun.com/chat/pro-128k
Spark Pro request address, with domain parameter generalv3:
wss://spark-api.xf-yun.com/v3.1/chat
Spark Lite request address, with domain parameter lite:
wss://spark-api.xf-yun.com/v1.1/chat
https://www.xfyun.cn/doc/spark/HTTP%E8%B0%83%E7%94%A8%E6%96%87%E6%A1%A3.html#_3-%E8%AF%B7%E6%B1%82%E8%AF%B4%E6%98%8E
"""
if not model_version:
model_version = model or os.getenv("XUNFEI_SPARK_API_VERSION")
if not api_base:
if model_version == SPARK_DEFAULT_API_VERSION:
api_base = "ws://spark-api.xf-yun.com/v3.1/chat"
domain = "generalv3"
elif model_version == "v4.0":
api_base = "ws://spark-api.xf-yun.com/v4.0/chat"
domain = "4.0Ultra"
elif model_version == "v3.5":
api_base = "ws://spark-api.xf-yun.com/v3.5/chat"
domain = "generalv3.5"
else:
api_base = "ws://spark-api.xf-yun.com/v1.1/chat"
domain = "lite"
if not api_domain:
api_domain = domain
self._model = model
self._model_version = model_version
self._api_base = api_base
self._domain = api_domain
self._app_id = app_id or os.getenv("XUNFEI_SPARK_APPID")
self._api_secret = api_secret or os.getenv("XUNFEI_SPARK_API_SECRET")
self._api_key = api_key or os.getenv("XUNFEI_SPARK_API_KEY")

if not self._app_id:
raise ValueError("app_id can't be empty")
if not self._api_key:
raise ValueError("api_key can't be empty")
if not self._api_secret:
raise ValueError("api_secret can't be empty")
self._model = model or os.getenv("XUNFEI_SPARK_API_MODEL")
self._api_base = os.getenv("PROXY_SERVER_URL")
self._api_password = os.getenv("XUNFEI_SPARK_API_PASSWORD")
if not self._model:
raise ValueError("model can't be empty")
if not self._api_base:
raise ValueError("api_base can't be empty")
if not self._api_password:
raise ValueError("api_password can't be empty")

super().__init__(
model_names=[model, model_alias],
@@ -203,10 +140,6 @@ def new_client(
) -> "SparkLLMClient":
return cls(
model=model_params.proxyllm_backend,
app_id=model_params.proxy_api_app_id,
api_key=model_params.proxy_api_key,
api_secret=model_params.proxy_api_secret,
api_base=model_params.proxy_api_base,
model_alias=model_params.model_name,
context_length=model_params.max_context_size,
executor=default_executor,
@@ -216,35 +149,46 @@ def new_client(
def default_model(self) -> str:
return self._model

def sync_generate_stream(

def generate_stream(
self,
request: ModelRequest,
message_converter: Optional[MessageConverter] = None,
) -> Iterator[ModelOutput]:
) -> AsyncIterator[ModelOutput]:
"""
reference:
https://www.xfyun.cn/doc/spark/HTTP%E8%B0%83%E7%94%A8%E6%96%87%E6%A1%A3.html#_3-%E8%AF%B7%E6%B1%82%E8%AF%B4%E6%98%8E
"""
request = self.local_covert_message(request, message_converter)
messages = request.to_common_messages(support_system_role=False)
request_id = request.context.request_id or "1"
try:
import requests
except ImportError as e:
raise ValueError(
"Could not import python package: requests "
"Please install requests by command `pip install requests"
) from e

data = {
"header": {"app_id": self._app_id, "uid": request_id},
"parameter": {
"chat": {
"domain": self._domain,
"random_threshold": 0.5,
"max_tokens": request.max_new_tokens,
"auditing": "default",
"temperature": request.temperature,
}
},
"payload": {"message": {"text": messages}},
"model": self._model, # 指定请求的模型
"messages": messages,
"temperature" : request.temperature,
"stream": True
}

spark_api = SparkAPI(
self._app_id, self._api_key, self._api_secret, self._api_base
)
request_url = spark_api.gen_url()
header = {
"Authorization": f"Bearer {self._api_password}" # 注意此处替换自己的APIPassword
}
response = requests.post(self._api_base, headers=header, json=data, stream=True)
# Example of parsing the streaming response
response.encoding = "utf-8"
try:
for text in get_response(request_url, data):
yield ModelOutput(text=text, error_code=0)
content = ""
#data: {"code":0,"message":"Success","sid":"cha000bf865@dx19307263c06b894532","id":"cha000bf865@dx19307263c06b894532","created":1730991766,"choices":[{"delta":{"role":"assistant","content":"你好"},"index":0}]}
#data: [DONE]
for line in response.iter_lines(decode_unicode=True):
print("llm out:", line)
content = content + extract_content(line)
yield ModelOutput(text=content, error_code=0)
except Exception as e:
return ModelOutput(
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
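For readers following the switch from WebSocket to HTTP, here is a condensed, self-contained sketch of the flow the new generate_stream implements: a chat-completions style JSON body, a Bearer APIPassword header, and line-by-line parsing of the "data: ..." server-sent events. Everything below is an illustrative restatement of the diff rather than the committed code; the default endpoint URL and the helper names (parse_sse_line, stream_spark_chat) are placeholders.

import json
import os

import requests


def parse_sse_line(line: str) -> str:
    """Return the delta content carried by one 'data: ...' line, or '' otherwise."""
    if not line or not line.startswith("data: "):
        return ""
    payload = line[len("data: "):]
    if payload.strip() == "[DONE]":
        return ""
    try:
        chunk = json.loads(payload)
    except json.JSONDecodeError:
        return ""
    choices = chunk.get("choices", [])
    if not choices:
        return ""
    return choices[0].get("delta", {}).get("content", "") or ""


def stream_spark_chat(messages, model, api_base, api_password, temperature=0.5):
    """Yield the accumulated answer after each streamed chunk, like the new loop does."""
    headers = {"Authorization": f"Bearer {api_password}"}
    data = {"model": model, "messages": messages, "temperature": temperature, "stream": True}
    response = requests.post(api_base, headers=headers, json=data, stream=True)
    response.encoding = "utf-8"
    content = ""
    for line in response.iter_lines(decode_unicode=True):
        content += parse_sse_line(line)
        yield content


if __name__ == "__main__":
    # Placeholder endpoint and credentials; set the same variables as .env.template.
    for partial in stream_spark_chat(
        messages=[{"role": "user", "content": "Hello"}],
        model=os.getenv("XUNFEI_SPARK_API_MODEL", "lite"),
        api_base=os.getenv("PROXY_SERVER_URL", "https://example.invalid/v1/chat/completions"),
        api_password=os.getenv("XUNFEI_SPARK_API_PASSWORD", ""),
    ):
        print(partial)

Each yield returns the answer accumulated so far, matching how the new code concatenates extract_content output into content before yielding ModelOutput.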

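For comparison, the authentication that this commit removes: the WebSocket API signed each connection URL with HMAC-SHA256 over the host, date, and request line. The sketch below condenses the deleted SparkAPI.gen_url logic; credentials and the wss URL are placeholders, and sign_ws_url is an illustrative name, not part of the original code.

import base64
import hashlib
import hmac
from datetime import datetime
from time import mktime
from urllib.parse import urlencode, urlparse
from wsgiref.handlers import format_date_time


def sign_ws_url(spark_url: str, api_key: str, api_secret: str) -> str:
    """Build the signed wss:// URL the old SparkAPI.gen_url produced."""
    host, path = urlparse(spark_url).netloc, urlparse(spark_url).path
    date = format_date_time(mktime(datetime.now().timetuple()))  # RFC 1123 timestamp
    signature_origin = f"host: {host}\ndate: {date}\nGET {path} HTTP/1.1"
    signature = base64.b64encode(
        hmac.new(api_secret.encode("utf-8"), signature_origin.encode("utf-8"), hashlib.sha256).digest()
    ).decode("utf-8")
    authorization_origin = (
        f'api_key="{api_key}", algorithm="hmac-sha256", '
        f'headers="host date request-line", signature="{signature}"'
    )
    authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode("utf-8")
    return spark_url + "?" + urlencode({"authorization": authorization, "date": date, "host": host})


# Example with placeholder credentials:
# sign_ws_url("wss://spark-api.xf-yun.com/v3.1/chat", "my_api_key", "my_api_secret")

The HTTP API replaces all of this with the single Bearer APIPassword header shown in the new code above.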