
Commit

add app
glemaitre committed Dec 12, 2023
1 parent c44b09c commit 89a49f8
Showing 19 changed files with 1,358 additions and 20 deletions.
19 changes: 19 additions & 0 deletions README.md
@@ -2,6 +2,25 @@

This repository contains all the components to build a RAG-based LLM for the scikit-learn documentation.

## Installation

```bash
pip install -r requirements.txt
```

## Starting the server

```bash
cd app
make start
```

The server can then be accessed at:

```bash
http://localhost:8123
```
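
The page served at that address talks to the `/inference` WebSocket endpoint defined in `app/main.py`. As a rough illustration, a minimal standalone client could look like the sketch below, assuming the third-party `websockets` package is installed; the `query` and `temperature` payload fields mirror what the handler reads from the incoming JSON:

```python
# Minimal client sketch for the /inference WebSocket endpoint.
# Assumes the third-party `websockets` package is available.
import asyncio
import json

import websockets


async def ask(query: str) -> None:
    async with websockets.connect("ws://localhost:8123/inference") as ws:
        await ws.send(json.dumps({"query": query, "temperature": 0.1}))
        while True:
            msg = json.loads(await ws.recv())  # a serialized WSMessage
            if msg["type"] == "end":
                break
            print(msg["message"], end="", flush=True)


asyncio.run(ask("How do I fit a RandomForestClassifier?"))
```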

## RAG-based LLM

We can represent a RAG-based LLM as follows [1]:
8 changes: 8 additions & 0 deletions app/Makefile
@@ -0,0 +1,8 @@
.PHONY: start
start:
	uvicorn main:app --reload --host 0.0.0.0 --port 8123

.PHONY: format
format:
	black .
	isort .
21 changes: 21 additions & 0 deletions app/configuration/default.py
@@ -0,0 +1,21 @@
# Retriever parameters
SEMANTIC_RETRIEVER_PATH = "../models/api_semantic_retrieval.joblib"
LEXICAL_RETRIEVER_PATH = "../models/api_lexical_retrieval.joblib"
CROSS_ENCODER_PATH = "cross-encoder/ms-marco-MiniLM-L-6-v2"
CROSS_ENCODER_THRESHOLD = 2.0
CROSS_ENCODER_MIN_TOP_K = 3
CROSS_ENCODER_MAX_TOP_K = 20

# Device parameters
DEVICE = "mps"

# LLM parameters
LLM_PATH = "../models/mistral-7b-instruct-v0.1.Q6_K.gguf"
TEMPERATURE = 0.1
TOP_P = 0.6
TOP_K = 40
REPETITION_PENALTY = 1.176
CONTEXT_TOKENS = 4096
MAX_RESPONSE_TOKENS = 8192
N_THREADS = 6
GPU_LAYERS = 1
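
Note that `app/main.py` imports this module through the `CONFIGURATION` environment variable, falling back to `default`, so alternative profiles can live next to it. As a sketch, a hypothetical CPU-only profile `app/configuration/cpu.py` (selected with `CONFIGURATION=cpu make start`) might look like:

```python
# Hypothetical app/configuration/cpu.py: reuse the defaults and
# override only the device-specific settings.
from configuration.default import *  # noqa: F401,F403

DEVICE = "cpu"  # no MPS acceleration
GPU_LAYERS = 0  # keep every LLM layer on the CPU
N_THREADS = 8   # assumption: tune to the host's core count
```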
141 changes: 141 additions & 0 deletions app/main.py
@@ -0,0 +1,141 @@
"""Main entrypoint for the app."""

import json
import logging
import os
import sys
from importlib import import_module
from pathlib import Path

import joblib
import resources as res
from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from llama_cpp import Llama
from schemas import WSMessage
from sentence_transformers import CrossEncoder

sys.path.append(str(Path(__file__).parent.parent))
from rag_based_llm.prompt import QueryAgent
from rag_based_llm.retrieval import RetrieverReranker

DEFAULT_PORT = 8123

app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")

logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)


config_module = os.getenv("CONFIGURATION", "default")
logger.info(f"Configuration: {config_module}")
conf = import_module(f"configuration.{config_module}")


async def send(ws, msg: str, type: str):
    message = WSMessage(sender="bot", message=msg, type=type)
    await ws.send_json(message.dict())


@app.on_event("startup")
async def startup_event():
    global agent

    api_semantic_retriever = joblib.load(conf.SEMANTIC_RETRIEVER_PATH)
    api_lexical_retriever = joblib.load(conf.LEXICAL_RETRIEVER_PATH)

    cross_encoder = CrossEncoder(model_name=conf.CROSS_ENCODER_PATH, device=conf.DEVICE)
    retriever_reranker = RetrieverReranker(
        cross_encoder=cross_encoder,
        semantic_retriever=api_semantic_retriever,
        lexical_retriever=api_lexical_retriever,
        threshold=conf.CROSS_ENCODER_THRESHOLD,
        min_top_k=conf.CROSS_ENCODER_MIN_TOP_K,
        max_top_k=conf.CROSS_ENCODER_MAX_TOP_K,
    )

    llm = Llama(
        model_path=conf.LLM_PATH,
        device=conf.DEVICE,
        n_gpu_layers=conf.GPU_LAYERS,
        n_threads=conf.N_THREADS,
        n_ctx=conf.CONTEXT_TOKENS,
    )
    agent = QueryAgent(llm=llm, retriever=retriever_reranker)
    logger.info("Server started")


@app.get("/favicon.ico", include_in_schema=False)
async def favicon():
return FileResponse("static/favicon.ico")


@app.get("/")
async def get(request: Request):
return templates.TemplateResponse(
"index.html", {"request": request, "res": res, "conf": conf}
)


@app.get("/inference.js")
async def get(request: Request): # noqa: F811
return templates.TemplateResponse(
"inference.js",
{"request": request, "wsurl": os.getenv("WSURL", ""), "res": res, "conf": conf},
)


@app.websocket("/inference")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    await send(websocket, "Question the scikit-learn bot!", "info")

    while True:
        try:
            response_complete = ""
            start_type = ""

            received_text = await websocket.receive_text()
            payload = json.loads(received_text)

            prompt = payload["query"]
            start_type = "start"
            logger.info(f"Temperature: {payload['temperature']}")
            logger.info(f"Prompt: {prompt}")

            await send(websocket, "Analyzing prompt...", "info")
            stream, sources = agent(
                prompt,
                echo=False,
                stream=True,
                max_tokens=conf.MAX_RESPONSE_TOKENS,
                temperature=conf.TEMPERATURE,
            )
            # Send the first chunk as "start" and every following chunk as
            # "stream" so the client can tell a new answer from its
            # continuation.
            for i in stream:
                response_text = i.get("choices", [])[0].get("text", "")
                answer_type = start_type if response_complete == "" else "stream"
                response_complete += response_text
                await send(websocket, response_text, answer_type)
            # Append the retrieved sources and re-send the full answer as a
            # single "start" message, then signal the end of the exchange.
            response_complete += "\n\nSource(s):\n" + ";\n".join(sources)
            await send(websocket, response_complete, start_type)

            await send(websocket, "", "end")
        except WebSocketDisconnect:
            logger.info("websocket disconnect")
            break
        except Exception as e:
            logger.error(e)
            await send(websocket, "Sorry, something went wrong. Try again.", "error")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=DEFAULT_PORT)
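
The same pipeline can be exercised without the WebSocket layer. A minimal offline sketch, assuming the joblib artifacts and GGUF model referenced in `configuration/default.py` exist on disk, and reusing only the calls shown above:

```python
# Offline sketch: rebuild the objects from startup_event() and stream an
# answer to stdout. Paths are the defaults and are assumed to exist.
import joblib
from llama_cpp import Llama
from sentence_transformers import CrossEncoder

from rag_based_llm.prompt import QueryAgent
from rag_based_llm.retrieval import RetrieverReranker

retriever_reranker = RetrieverReranker(
    cross_encoder=CrossEncoder(
        model_name="cross-encoder/ms-marco-MiniLM-L-6-v2", device="cpu"
    ),
    semantic_retriever=joblib.load("../models/api_semantic_retrieval.joblib"),
    lexical_retriever=joblib.load("../models/api_lexical_retrieval.joblib"),
    threshold=2.0,
    min_top_k=3,
    max_top_k=20,
)
llm = Llama(
    model_path="../models/mistral-7b-instruct-v0.1.Q6_K.gguf", n_ctx=4096
)
agent = QueryAgent(llm=llm, retriever=retriever_reranker)

stream, sources = agent(
    "What does cross_val_score do?",
    echo=False,
    stream=True,
    max_tokens=8192,
    temperature=0.1,
)
for chunk in stream:  # chunks follow llama-cpp's streaming dict format
    print(chunk.get("choices", [])[0].get("text", ""), end="", flush=True)
print("\n\nSource(s):\n" + ";\n".join(sources))
```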
20 changes: 20 additions & 0 deletions app/resources.py
@@ -0,0 +1,20 @@
BUTTON_PROCESSING = "Thinking"
BUTTON_TYPING = "Typing"
PROMPT = "(Enter your prompt here, type !help for a list of commands)"
BUTTON_SEND = "Ask the LLaMa"
BUTTON_WAIT = "Wait"
HELP = """<table style='margin-bottom:10px'>
<tr><th colspan=2>Server commands</th></tr>
<tr><td>!models</td><td>List available models</td></tr>
<tr><td>!model</td><td>Show currently loaded model</td></tr>
<tr><td>!model (filename)</td><td>Load a different model</td></tr>
<tr><td>!stop</td><td>List of currently set stop words</td></tr>
<tr><td>!stop ['word1',...]&nbsp;</td><td>Assign new stopwords</td></tr>
<tr><td>!system</td><td>System State (used/free CPU and RAM)</td></tr>
</table>
<table>
<tr><th colspan=2>Prompt templates (press TAB to complete)</th></tr>
<tr><td>#vic</td><td>Helpful AI Vicuna 1.1 prompt template</td></tr>
<tr><td>#story&nbsp;</td><td>Storyteller Vicuna 1.1 prompt template</td></tr>
<tr><td>###</td><td>Instruct/Response prompt template</td></tr>
</table>"""
32 changes: 32 additions & 0 deletions app/schemas.py
@@ -0,0 +1,32 @@
"""Schemas for the chat app."""
from pydantic import BaseModel, validator


class WSMessage(BaseModel):
    """Websocket Message schema."""

    sender: str
    message: str
    type: str

    @validator("sender")
    def sender_must_be_bot_or_you(cls, v):
        if v not in ["bot", "you"]:
            raise ValueError("sender must be bot or you")
        return v

    @validator("type")
    def validate_message_type(cls, v):
        if v not in [
            "question",
            "start",
            "restart",
            "stream",
            "end",
            "error",
            "info",
            "system",
            "done",
        ]:
            raise ValueError(
                "type must be one of: question, start, restart, stream, "
                "end, error, info, system, done"
            )
        return v
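
For reference, the validators make invalid messages fail fast. A quick sketch of the behavior (pydantic v1 style, matching the `@validator` decorators above):

```python
# Quick sketch of WSMessage validation behavior.
from schemas import WSMessage

msg = WSMessage(sender="bot", message="Analyzing prompt...", type="info")
print(msg.dict())
# {'sender': 'bot', 'message': 'Analyzing prompt...', 'type': 'info'}

try:
    WSMessage(sender="server", message="nope", type="info")
except ValueError as exc:
    print(exc)  # pydantic reports: sender must be bot or you
```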
7 changes: 7 additions & 0 deletions app/static/css/bootstrap.min.css

Large diffs are not rendered by default.

