Showing 19 changed files with 1,358 additions and 20 deletions.
Makefile
@@ -0,0 +1,8 @@
.PHONY: start
start:
	uvicorn main:app --reload --host 0.0.0.0 --port 8123

.PHONY: format
format:
	black .
	isort .
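With these two targets, `make start` serves the app with auto-reload on 0.0.0.0:8123 and `make format` runs black and isort over the tree; both assume the tools are installed in the active environment.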
configuration/default.py
@@ -0,0 +1,21 @@
# Retriever parameters
SEMANTIC_RETRIEVER_PATH = "../models/api_semantic_retrieval.joblib"
LEXICAL_RETRIEVER_PATH = "../models/api_lexical_retrieval.joblib"
CROSS_ENCODER_PATH = "cross-encoder/ms-marco-MiniLM-L-6-v2"
CROSS_ENCODER_THRESHOLD = 2.0
CROSS_ENCODER_MIN_TOP_K = 3
CROSS_ENCODER_MAX_TOP_K = 20

# Device parameters
DEVICE = "mps"

# LLM parameters
LLM_PATH = "../models/mistral-7b-instruct-v0.1.Q6_K.gguf"
TEMPERATURE = 0.1  # sampling temperature for each completion
TOP_P = 0.6
TOP_K = 40
REPETITION_PENALTY = 1.176
CONTEXT_TOKENS = 4096  # passed to llama-cpp as n_ctx
MAX_RESPONSE_TOKENS = 8192  # passed to the agent as max_tokens
N_THREADS = 6  # passed to llama-cpp as n_threads
GPU_LAYERS = 1  # passed to llama-cpp as n_gpu_layers
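Because the app resolves this module from the CONFIGURATION environment variable (falling back to "default"), an alternative profile only needs a sibling module under configuration/. A minimal sketch, assuming a hypothetical configuration/cpu.py that overrides just the hardware-specific constants and duplicates the rest from this file:

# configuration/cpu.py -- hypothetical CPU-only profile, not part of this commit
DEVICE = "cpu"  # CrossEncoder and Llama both receive this device string
GPU_LAYERS = 0  # keep every layer on the CPU
# ...all remaining constants copied from configuration/default.py...

The server would then pick it up with `CONFIGURATION=cpu make start` (make passes the variable through to the recipe's shell).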
main.py
@@ -0,0 +1,141 @@
"""Main entrypoint for the app.""" | ||
|
||
import json | ||
import logging | ||
import os | ||
import sys | ||
from importlib import import_module | ||
from pathlib import Path | ||
|
||
import joblib | ||
import resources as res | ||
from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect | ||
from fastapi.responses import FileResponse | ||
from fastapi.staticfiles import StaticFiles | ||
from fastapi.templating import Jinja2Templates | ||
from llama_cpp import Llama | ||
from schemas import WSMessage | ||
from sentence_transformers import CrossEncoder | ||
|
||
sys.path.append(str(Path(__file__).parent.parent)) | ||
from rag_based_llm.prompt import QueryAgent | ||
from rag_based_llm.retrieval import RetrieverReranker | ||
|
||
DEFAULT_PORT = 8123 | ||
|
||
app = FastAPI() | ||
app.mount("/static", StaticFiles(directory="static"), name="static") | ||
templates = Jinja2Templates(directory="templates") | ||
|
||
logging.basicConfig() | ||
logger = logging.getLogger() | ||
logger.setLevel(logging.INFO) | ||
|
||
|
||
config_module = ( | ||
os.getenv("CONFIGURATION") if os.getenv("CONFIGURATION") is not None else "default" | ||
) | ||
logger.info(f"Configuration: {config_module} ") | ||
conf = import_module(f"configuration.{config_module}") | ||
|
||
|
||
async def send(ws, msg: str, type: str): | ||
message = WSMessage(sender="bot", message=msg, type=type) | ||
await ws.send_json(message.dict()) | ||
|
||
|
||
@app.on_event("startup") | ||
async def startup_event(): | ||
global agent | ||
|
||
api_semantic_retriever = joblib.load(conf.SEMANTIC_RETRIEVER_PATH) | ||
api_lexical_retriever = joblib.load(conf.LEXICAL_RETRIEVER_PATH) | ||
|
||
cross_encoder = CrossEncoder(model_name=conf.CROSS_ENCODER_PATH, device=conf.DEVICE) | ||
retriever_reranker = RetrieverReranker( | ||
cross_encoder=cross_encoder, | ||
semantic_retriever=api_semantic_retriever, | ||
lexical_retriever=api_lexical_retriever, | ||
threshold=conf.CROSS_ENCODER_THRESHOLD, | ||
min_top_k=conf.CROSS_ENCODER_MIN_TOP_K, | ||
max_top_k=conf.CROSS_ENCODER_MAX_TOP_K, | ||
) | ||
|
||
llm = Llama( | ||
model_path=conf.LLM_PATH, | ||
device=conf.DEVICE, | ||
n_gpu_layers=conf.GPU_LAYERS, | ||
n_threads=conf.N_THREADS, | ||
n_ctx=conf.CONTEXT_TOKENS, | ||
) | ||
agent = QueryAgent(llm=llm, retriever=retriever_reranker) | ||
logging.info("Server started") | ||
|
||
|
||
@app.get("/favicon.ico", include_in_schema=False) | ||
async def favicon(): | ||
return FileResponse("static/favicon.ico") | ||
|
||
|
||
@app.get("/") | ||
async def get(request: Request): | ||
return templates.TemplateResponse( | ||
"index.html", {"request": request, "res": res, "conf": conf} | ||
) | ||
|
||
|
||
@app.get("/inference.js") | ||
async def get(request: Request): # noqa: F811 | ||
return templates.TemplateResponse( | ||
"inference.js", | ||
{"request": request, "wsurl": os.getenv("WSURL", ""), "res": res, "conf": conf}, | ||
) | ||
|
||
|
||
@app.websocket("/inference") | ||
async def websocket_endpoint(websocket: WebSocket): | ||
await websocket.accept() | ||
await send(websocket, "Question the scikit-learn bot!", "info") | ||
|
||
while True: | ||
try: | ||
response_complete = "" | ||
start_type = "" | ||
|
||
received_text = await websocket.receive_text() | ||
payload = json.loads(received_text) | ||
|
||
prompt = payload["query"] | ||
start_type = "start" | ||
logger.info(f"Temperature: {payload['temperature']}") | ||
logger.info(f"Prompt: {prompt}") | ||
|
||
await send(websocket, "Analyzing prompt...", "info") | ||
stream, sources = agent( | ||
prompt, | ||
echo=False, | ||
stream=True, | ||
max_tokens=conf.MAX_RESPONSE_TOKENS, | ||
temperature=conf.TEMPERATURE, | ||
) | ||
for i in stream: | ||
response_text = i.get("choices", [])[0].get("text", "") | ||
answer_type = start_type if response_complete == "" else "stream" | ||
response_complete += response_text | ||
await send(websocket, response_text, answer_type) | ||
response_complete += "\n\nSource(s):\n" + ";\n".join(sources) | ||
await send(websocket, response_complete, start_type) | ||
|
||
await send(websocket, "", "end") | ||
except WebSocketDisconnect: | ||
logging.info("websocket disconnect") | ||
break | ||
except Exception as e: | ||
logging.error(e) | ||
await send(websocket, "Sorry, something went wrong. Try again.", "error") | ||
|
||
|
||
if __name__ == "__main__": | ||
import uvicorn | ||
|
||
uvicorn.run(app, host="0.0.0.0", port=DEFAULT_PORT) |
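The /inference endpoint can be exercised end to end with a small client. A minimal sketch using the third-party websockets package (not part of this commit); the "query" and "temperature" keys match what the endpoint above reads, and each incoming frame is one serialized WSMessage:

import asyncio
import json

import websockets  # third-party client library, assumed installed


async def ask(question: str) -> None:
    async with websockets.connect("ws://localhost:8123/inference") as ws:
        await ws.send(json.dumps({"query": question, "temperature": 0.1}))
        while True:
            frame = json.loads(await ws.recv())  # one WSMessage per frame
            if frame["type"] == "end":
                break
            print(f"[{frame['type']}] {frame['message']}")


asyncio.run(ask("How do I customise a scikit-learn estimator?"))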
resources.py
@@ -0,0 +1,20 @@
BUTTON_PROCESSING = "Thinking"
BUTTON_TYPING = "Typing"
PROMPT = "(Enter your prompt here, type !help for a list of commands)"
BUTTON_SEND = "Ask the LLaMa"
BUTTON_WAIT = "Wait"
HELP = """<table style='margin-bottom:10px'>
<tr><th colspan=2>Server commands</th></tr>
<tr><td>!models</td><td>List available models</td></tr>
<tr><td>!model</td><td>Show currently loaded model</td></tr>
<tr><td>!model (filename)</td><td>Load a different model</td></tr>
<tr><td>!stop</td><td>List of currently set stop words</td></tr>
<tr><td>!stop ['word1',...] </td><td>Assign new stop words</td></tr>
<tr><td>!system</td><td>System state (used/free CPU and RAM)</td></tr>
</table>
<table>
<tr><th colspan=2>Prompt templates (press TAB to complete)</th></tr>
<tr><td>#vic</td><td>Helpful AI Vicuna 1.1 prompt template</td></tr>
<tr><td>#story </td><td>Storyteller Vicuna 1.1 prompt template</td></tr>
<tr><td>###</td><td>Instruct/Response prompt template</td></tr>
</table>"""
schemas.py
@@ -0,0 +1,32 @@
"""Schemas for the chat app.""" | ||
from pydantic import BaseModel, validator | ||
|
||
|
||
class WSMessage(BaseModel): | ||
"""Websocket Message schema.""" | ||
|
||
sender: str | ||
message: str | ||
type: str | ||
|
||
@validator("sender") | ||
def sender_must_be_bot_or_you(cls, v): | ||
if v not in ["bot", "you"]: | ||
raise ValueError("sender must be bot or you") | ||
return v | ||
|
||
@validator("type") | ||
def validate_message_type(cls, v): | ||
if v not in [ | ||
"question", | ||
"start", | ||
"restart", | ||
"stream", | ||
"end", | ||
"error", | ||
"info", | ||
"system", | ||
"done", | ||
]: | ||
raise ValueError("type must be start, stream or end") | ||
return v |
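Every frame exchanged over the websocket is one serialized WSMessage, so the wire format is easy to check in isolation; for instance:

WSMessage(sender="bot", message="Hello!", type="stream").dict()
# -> {'sender': 'bot', 'message': 'Hello!', 'type': 'stream'}
WSMessage(sender="server", message="x", type="stream")
# -> raises ValidationError: sender must be 'bot' or 'you'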