-
Notifications
You must be signed in to change notification settings - Fork 205
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into add_qwen2-vl_readme
- Loading branch information
Showing
15 changed files
with
491 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import sys | ||
import time | ||
|
||
sys.path.append("./") | ||
import argparse | ||
from datetime import timedelta | ||
from flask import Flask, render_template, request, session | ||
import uuid | ||
|
||
app = Flask(__name__) | ||
app.secret_key = "12312321" # 应该使用更加安全的密钥 | ||
app.config["PERMANENT_SESSION_LIFETIME"] = timedelta(minutes=10) | ||
global_chat_id_to_chat_obj = {} | ||
|
||
|
||
@app.route("/") | ||
def index(): | ||
session["chat_id"] = str(uuid.uuid4()) | ||
from qabot import QaBot | ||
|
||
global_chat_id_to_chat_obj[session["chat_id"]] = (QaBot(args.llm_url), time.time()) | ||
|
||
# 删除过期的qabot 对象 | ||
dels_keys = [] | ||
for key, (_, time_mark) in global_chat_id_to_chat_obj.items(): | ||
if time.time() - time_mark >= 10 * 60: | ||
dels_keys.append(key) | ||
|
||
for key in dels_keys: | ||
del global_chat_id_to_chat_obj[key] | ||
|
||
return render_template("chat.html") | ||
|
||
|
||
@app.route("/chat") | ||
def chat(): | ||
user_input = request.args.get("message", "") | ||
print("get", user_input) | ||
print("type", type(user_input)) | ||
qabot, _ = global_chat_id_to_chat_obj[session["chat_id"]] | ||
global_chat_id_to_chat_obj[session["chat_id"]] = (qabot, time.time()) | ||
return qabot.answer(user_input) | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser(description="demo") | ||
parser.add_argument("--llm_url", type=str, default="http://localhost:8017/generate", help="llm url") | ||
parser.add_argument("--port", type=int, default=8088, help="port") | ||
args = parser.parse_args() | ||
app.run(debug=True, port=args.port) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,175 @@ | ||
import os | ||
import sys | ||
import json | ||
from pydantic import BaseModel, constr, conlist | ||
from enum import Enum | ||
from typing import List | ||
|
||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) | ||
|
||
from format_out.impl import ChatSession | ||
from format_out.impl import SamplingParams | ||
|
||
# server model is Meta-Llama-3-8B-Instruct | ||
system_prompt = """<|begin_of_text|><|start_header_id|>system<|end_header_id|> | ||
你是一个人工智能助手,只会做以下的事情: | ||
1. 知识问答 (根据已有的知识库信息,回答对应的问题) | ||
2. 查询占用端口的进程号 (通过生成指令,然后由系统执行返回结果,你做总结) | ||
3. 其他类型 (所有不是上面两种类型的问题) | ||
你在回答相关问题的过程中,需要按照指引,生成相关的json格式输出。 | ||
<|eot_id|>""" | ||
user_start = """<|start_header_id|>user<|end_header_id|>""" | ||
user_end = """<|eot_id|>""" | ||
assistant_start = """<|start_header_id|>assistant<|end_header_id|>""" | ||
assistant_end = """<|eot_id|>""" | ||
knowledge_start = """<|start_header_id|>knowledge<|end_header_id|>""" | ||
knowledge_end = """<|eot_id|>""" | ||
|
||
|
||
class QaBot: | ||
def __init__(self, llm_url="http://localhost:8017/generate"): | ||
chat_session = ChatSession(chat_his=system_prompt, url=llm_url, sampling_param=SamplingParams(do_sample=False)) | ||
chat_session.sampling_param.top_p = 0.7 | ||
chat_session.sampling_param.top_k = 12 | ||
chat_session.disable_log = True | ||
# 修改采样参数 | ||
chat_session.sampling_param.stop_sequences = [assistant_end, "<|end_of_text|>"] | ||
|
||
# 添加知识库 | ||
chat_session.add_prompt(knowledge_start) | ||
|
||
title = "lightllm的仓库链接" | ||
content = "https://github.com/ModelTC/lightllm" | ||
chat_session.add_prompt(f"<item><title>{title}</title><content>{content}</content></item>\n") | ||
title = "lightllm的文档链接" | ||
content = "https://github.com/ModelTC/lightllm/tree/main/docs" | ||
chat_session.add_prompt(f"<item><title>{title}</title><content>{content}</content></item>\n") | ||
title = "steam 账号" | ||
content = "account:12312321312, password:xxxxxxx" | ||
chat_session.add_prompt(f"<item><title>{title}</title><content>{content}</content></item>\n") | ||
title = "今天的天气" | ||
content = "天气很热" | ||
chat_session.add_prompt(f"<item><title>{title}</title><content>{content}</content></item>\n") | ||
|
||
chat_session.add_prompt(knowledge_end) | ||
|
||
self.chat_session = chat_session | ||
|
||
def answer(self, question_str: str): | ||
self.chat_session.add_prompt(user_start) | ||
self.chat_session.add_prompt(question_str) | ||
self.chat_session.add_prompt(user_end) | ||
|
||
self.chat_session.add_prompt(assistant_start) | ||
self.chat_session.add_prompt("先判断该问题的类型:") | ||
|
||
class QAType(Enum): | ||
Q1 = "知识问答" | ||
Q2 = "查询占用端口的进程号" | ||
Q3 = "其他类型" | ||
|
||
class ClassQuestion(BaseModel): | ||
thoughts: List[str] | ||
question_type: QAType | ||
|
||
json_ans_str = self.chat_session.gen_json_object(ClassQuestion, max_new_tokens=2048, prefix_regex=r"[\s]{0,20}") | ||
json_ans_str: str = json_ans_str.strip() | ||
json_ans_str = json_ans_str.replace("”", '"') # 修复 json 格式问题 | ||
print(json_ans_str) | ||
# 修复中文格式问题 | ||
json_ans_str = json_ans_str.encode("unicode_escape").decode() | ||
json_ans_str = json_ans_str.replace(r"\\u", r"\u") | ||
json_ans_str = json_ans_str.replace(r"\\\u", r"\u") | ||
json_ans_str = json_ans_str.replace(r"\\\\u", r"\u") | ||
json_ans_str = json_ans_str.encode("utf-8").decode("unicode_escape") | ||
print(json_ans_str) | ||
json_ans = json.loads(json_ans_str) | ||
formatted_json = json.dumps(json_ans, indent=4, ensure_ascii=False) | ||
print(formatted_json) | ||
self.chat_session.add_prompt(formatted_json + "\n") | ||
|
||
class_ans = ClassQuestion(**json_ans) | ||
if class_ans.question_type == QAType.Q3: | ||
ans_str = "对不起,我无法处理这个问题, 我只会下列问题:" "1. 知识问答 (根据已有的知识库信息,回答对应的问题)" "2. 查询占用端口的进程号 (通过生成指令,然后由系统执行返回结)" | ||
self.chat_session.add_prompt("给用户回答:" + ans_str) | ||
self.chat_session.add_prompt(assistant_end) | ||
return ans_str | ||
|
||
elif class_ans.question_type == QAType.Q1: | ||
return self.handle_qa() | ||
elif class_ans.question_type == QAType.Q2: | ||
return self.query_port_pid() | ||
|
||
def handle_qa(self): | ||
self.chat_session.add_prompt("通过知识库来会的这个问题:") | ||
|
||
class Result(BaseModel): | ||
finded_relevant_knowledge: conlist(str, min_length=0, max_length=10) | ||
can_answer: bool | ||
preliminary_results: str | ||
summary_result: constr(min_length=0, max_length=1000) | ||
|
||
json_ans_str = self.chat_session.gen_json_object(Result, max_new_tokens=2048, prefix_regex=r"[\s]{0,20}") | ||
json_ans_str: str = json_ans_str.strip() | ||
json_ans_str = json_ans_str.replace("”", '"') # 修复 json 格式问题 | ||
|
||
json_ans_str = json_ans_str.encode("unicode_escape").decode() | ||
json_ans_str = json_ans_str.replace(r"\\u", r"\u") | ||
json_ans_str = json_ans_str.replace(r"\\\u", r"\u") | ||
json_ans_str = json_ans_str.replace(r"\\\\u", r"\u") | ||
json_ans_str = json_ans_str.encode("utf-8").decode("unicode_escape") | ||
|
||
json_ans = json.loads(json_ans_str) | ||
formatted_json = json.dumps(json_ans, indent=4, ensure_ascii=False) | ||
print(formatted_json) | ||
self.chat_session.add_prompt(formatted_json + "\n") | ||
result = Result(**json_ans) | ||
if result.can_answer is False: | ||
ans_str = "对不起,我无法处理这个" | ||
self.chat_session.add_prompt("给用户回答:" + ans_str) | ||
self.chat_session.add_prompt(assistant_end) | ||
return ans_str | ||
else: | ||
ans_str = result.summary_result | ||
self.chat_session.add_prompt("给用户回答:" + ans_str) | ||
self.chat_session.add_prompt(assistant_end) | ||
return ans_str | ||
|
||
def query_port_pid(self): | ||
self.chat_session.add_prompt("收集需要用到的命令信息,如端口号:") | ||
|
||
class Result(BaseModel): | ||
thoughts: List[str] | ||
port_can_be_determined: bool | ||
port: str | ||
|
||
json_ans_str = self.chat_session.gen_json_object(Result, max_new_tokens=2048, prefix_regex=r"[\s]{0,20}") | ||
json_ans_str: str = json_ans_str.strip() | ||
json_ans_str = json_ans_str.replace("”", '"') # 修复 json 格式问题 | ||
|
||
json_ans_str = json_ans_str.encode("unicode_escape").decode() | ||
json_ans_str = json_ans_str.replace(r"\\u", r"\u") | ||
json_ans_str = json_ans_str.replace(r"\\\u", r"\u") | ||
json_ans_str = json_ans_str.replace(r"\\\\u", r"\u") | ||
json_ans_str = json_ans_str.encode("utf-8").decode("unicode_escape") | ||
|
||
json_ans = json.loads(json_ans_str) | ||
formatted_json = json.dumps(json_ans, indent=4, ensure_ascii=False) | ||
print(formatted_json) | ||
self.chat_session.add_prompt(formatted_json + "\n") | ||
result = Result(**json_ans) | ||
if result.port_can_be_determined is False: | ||
ans_str = "对不起,我无法确认端口号,请给出明确的端口号信息" | ||
self.chat_session.add_prompt("给用户回答:" + ans_str) | ||
self.chat_session.add_prompt(assistant_end) | ||
return ans_str | ||
else: | ||
import subprocess | ||
|
||
command = f"netstat -tunlp | grep {result.port}" | ||
result = subprocess.run(command, capture_output=True, text=True, shell=True) | ||
|
||
ans_str = str(result.stdout) + "\n" + str(result.stderr) | ||
self.chat_session.add_prompt("给用户回答:" + ans_str) | ||
self.chat_session.add_prompt(assistant_end) | ||
return ans_str |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
<!DOCTYPE html> | ||
<html lang="en"> | ||
<head> | ||
<meta charset="UTF-8"> | ||
<meta name="viewport" content="width=device-width, initial-scale=1.0"> | ||
<title>两人多轮对话窗口</title> | ||
<style> | ||
body { | ||
font-family: 'Arial', sans-serif; | ||
background-color: #f0f0f0; | ||
margin: 0; | ||
padding: 0; | ||
display: flex; | ||
justify-content: center; | ||
align-items: center; | ||
height: 100vh; | ||
} | ||
.dialog-container { | ||
width: 1200px; | ||
background-color: #fff; | ||
border-radius: 8px; | ||
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1); | ||
} | ||
.header { | ||
background-color: #007BFF; | ||
color: #fff; | ||
padding: 10px; | ||
border-top-left-radius: 8px; | ||
border-top-right-radius: 8px; | ||
} | ||
.header h3 { | ||
margin: 0; | ||
font-size: 18px; | ||
} | ||
.messages { | ||
padding: 10px; | ||
height: 600px; | ||
overflow-y: auto; | ||
} | ||
.message { | ||
margin-bottom: 10px; | ||
} | ||
.message p { | ||
display: inline-block; | ||
background-color: #e6e6e6; | ||
padding: 5px 10px; | ||
border-radius: 20px; | ||
max-width: 200px; | ||
overflow-wrap: break-word; | ||
} | ||
.user1 { | ||
text-align: left; | ||
} | ||
.user2 { | ||
text-align: right; | ||
} | ||
.user2 p { | ||
background-color: #007BFF; | ||
color: #fff; | ||
} | ||
.input-area { | ||
padding: 10px; | ||
border-top: 1px solid #e6e6e6; | ||
} | ||
.input-area input { | ||
width: 100%; | ||
border: none; | ||
background-color: #f0f0f0; | ||
padding: 10px; | ||
border-radius: 20px; | ||
} | ||
.input-area input:focus { | ||
outline: none; | ||
} | ||
</style> | ||
</head> | ||
<body> | ||
<div class="dialog-container"> | ||
<div class="header"> | ||
<h3>两人多轮对话</h3> | ||
</div> | ||
<div class="messages"> | ||
<!-- 对话内容将通过JavaScript添加到这里 --> | ||
</div> | ||
<div class="input-area"> | ||
<input type="text" id="messageInput" placeholder="输入你的消息"> | ||
</div> | ||
</div> | ||
<script> | ||
// 示例对话内容 | ||
const dialogs = [ | ||
{ user: '1', message: '你好!' }, | ||
{ user: '1', message: '你好,请问有什么可以帮助你的吗?' }, | ||
]; | ||
|
||
const messagesContainer = document.querySelector('.messages'); | ||
|
||
// 添加对话到页面 | ||
function addMessageToPage(user, message) { | ||
const messageDiv = document.createElement('div'); | ||
messageDiv.classList.add('message', `user${user}`); | ||
messageDiv.innerHTML = `<p>${message}</p>`; | ||
messagesContainer.appendChild(messageDiv); | ||
messagesContainer.scrollTop = messagesContainer.scrollHeight; | ||
} | ||
|
||
// 填充示例对话 | ||
dialogs.forEach(dialog => { | ||
addMessageToPage(dialog.user, dialog.message); | ||
}); | ||
|
||
// 监听消息输入 | ||
const messageInput = document.getElementById('messageInput'); | ||
messageInput.addEventListener('keypress', event => { | ||
if (event.key === 'Enter') { | ||
const user = '2'; | ||
const message = messageInput.value; | ||
addMessageToPage(user, message); | ||
messageInput.value = ''; | ||
|
||
// 使用Fetch API获取接口响应 | ||
fetch(`/chat?message=${encodeURIComponent(message)}`) | ||
.then(response => response.text()) | ||
.then(data => { | ||
// 显示响应 | ||
addMessageToPage('1', data); | ||
}).catch(error => { | ||
console.error('获取接口响应时出错:', error); | ||
addMessageToPage('1', '获取接口响应时出错'); | ||
}); | ||
|
||
messageInput.value = ''; | ||
} | ||
}); | ||
</script> | ||
</body> | ||
</html> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
一些应用demo的目录 |
Oops, something went wrong.