Skip to content

Commit

Permalink
format code
Browse files Browse the repository at this point in the history
  • Loading branch information
hiyizi committed Nov 18, 2024
1 parent 3fc18a7 commit 478c36f
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 65 deletions.
82 changes: 41 additions & 41 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,41 +1,41 @@
# Please run command `pre-commit install` to install pre-commit hook
repos:
- repo: local
hooks:
- id: python-fmt
name: Python Format
entry: make fmt-check
language: system
exclude: '^dbgpt/app/static/|^web/'
types: [python]
stages: [commit]
pass_filenames: false
args: []
- id: python-test
name: Python Unit Test
entry: make test
language: system
exclude: '^dbgpt/app/static/|^web/'
types: [python]
stages: [commit]
pass_filenames: false
args: []
- id: python-test-doc
name: Python Doc Test
entry: make test-doc
language: system
exclude: '^dbgpt/app/static/|^web/'
types: [python]
stages: [commit]
pass_filenames: false
args: []
- id: python-lint-mypy
name: Python Lint mypy
entry: make mypy
language: system
exclude: '^dbgpt/app/static/|^web/'
types: [python]
stages: [commit]
pass_filenames: false
args: []

# Please run command `pre-commit install` to install pre-commit hook
repos:
- repo: local
hooks:
- id: python-fmt
name: Python Format
entry: make fmt-check
language: system
exclude: '^dbgpt/app/static/|^web/'
types: [python]
stages: [commit]
pass_filenames: false
args: []
- id: python-test
name: Python Unit Test
entry: make test
language: system
exclude: '^dbgpt/app/static/|^web/'
types: [python]
stages: [commit]
pass_filenames: false
args: []
- id: python-test-doc
name: Python Doc Test
entry: make test-doc
language: system
exclude: '^dbgpt/app/static/|^web/'
types: [python]
stages: [commit]
pass_filenames: false
args: []
- id: python-lint-mypy
name: Python Lint mypy
entry: make mypy
language: system
exclude: '^dbgpt/app/static/|^web/'
types: [python]
stages: [commit]
pass_filenames: false
args: []

5 changes: 3 additions & 2 deletions dbgpt/_private/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,11 @@ def __init__(self) -> None:
self.spark_proxy_api_password = os.getenv("XUNFEI_SPARK_API_PASSWORD")
self.spark_proxy_api_model = os.getenv("XUNFEI_SPARK_API_MODEL")
if self.spark_proxy_api_model and self.spark_proxy_api_password:
os.environ["spark_proxyllm_proxy_api_password"] = self.spark_proxy_api_password
os.environ[
"spark_proxyllm_proxy_api_password"
] = self.spark_proxy_api_password
os.environ["spark_proxyllm_proxy_api_model"] = self.spark_proxy_api_model


# baichuan proxy
self.bc_proxy_api_key = os.getenv("BAICHUAN_PROXY_API_KEY")
self.bc_model_name = os.getenv("BAICHUN_MODEL_NAME", "Baichuan2-Turbo-192k")
Expand Down
3 changes: 2 additions & 1 deletion dbgpt/app/scene/base_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,8 @@ async def _build_model_request(self) -> ModelRequest:
)
node = AppChatComposerOperator(
model=self.llm_model,
temperature=self._chat_param.get("temperature") or float(self.prompt_template.temperature),
temperature=self._chat_param.get("temperature")
or float(self.prompt_template.temperature),
max_new_tokens=int(self.prompt_template.max_new_tokens),
prompt=self.prompt_template.prompt,
message_version=self._message_version,
Expand Down
39 changes: 18 additions & 21 deletions dbgpt/model/proxy/llms/spark.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import os
from concurrent.futures import Executor
from typing import Optional, AsyncIterator
from typing import AsyncIterator, Optional

from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext
from dbgpt.model.parameter import ProxyModelParameters
Expand Down Expand Up @@ -65,32 +65,30 @@ def get_response(request_url, data):
raise e
yield result


def extract_content(line: str):
if not line.strip():
return line
if line.startswith('data: '):
json_str = line[len('data: '):]
if line.startswith("data: "):
json_str = line[len("data: ") :]
else:
raise ValueError(
"Error line content "
)
raise ValueError("Error line content ")

try:
data = json.loads(json_str)
if data == '[DONE]':
return ''
if data == "[DONE]":
return ""

choices = data.get('choices', [])
choices = data.get("choices", [])
if choices and isinstance(choices, list):
delta = choices[0].get('delta', {})
content = delta.get('content', '')
delta = choices[0].get("delta", {})
content = delta.get("content", "")
return content
else:
raise ValueError(
"Error line content "
)
raise ValueError("Error line content ")
except json.JSONDecodeError:
return ''
return ""


class SparkLLMClient(ProxyLLMClient):
def __init__(
Expand Down Expand Up @@ -143,7 +141,6 @@ def new_client(
def default_model(self) -> str:
return self._model


def generate_stream(
self,
request: ModelRequest,
Expand All @@ -166,8 +163,8 @@ def generate_stream(
data = {
"model": self._model, # 指定请求的模型
"messages": messages,
"temperature" : request.temperature,
"stream": True
"temperature": request.temperature,
"stream": True,
}
header = {
"Authorization": f"Bearer {self._api_password}" # 注意此处替换自己的APIPassword
Expand All @@ -177,11 +174,11 @@ def generate_stream(
response.encoding = "utf-8"
try:
content = ""
#data: {"code":0,"message":"Success","sid":"cha000bf865@dx19307263c06b894532","id":"cha000bf865@dx19307263c06b894532","created":1730991766,"choices":[{"delta":{"role":"assistant","content":"你好"},"index":0}]}
#data: [DONE]
# data: {"code":0,"message":"Success","sid":"cha000bf865@dx19307263c06b894532","id":"cha000bf865@dx19307263c06b894532","created":1730991766,"choices":[{"delta":{"role":"assistant","content":"你好"},"index":0}]}
# data: [DONE]
for line in response.iter_lines(decode_unicode=True):
print("llm out:", line)
content=content + extract_content(line)
content = content + extract_content(line)
yield ModelOutput(text=content, error_code=0)
except Exception as e:
return ModelOutput(
Expand Down

0 comments on commit 478c36f

Please sign in to comment.