pixi multienv
glemaitre committed Feb 4, 2024
1 parent 8924921 commit a4aebe7
Showing 6 changed files with 2,956 additions and 11,307 deletions.
app/configuration/default.py (3 changes: 0 additions & 3 deletions)
@@ -12,9 +12,6 @@
 CROSS_ENCODER_MIN_TOP_K = 3
 CROSS_ENCODER_MAX_TOP_K = 20
 
-# Device parameters
-DEVICE = "mps"
-
 # LLM parameters
 LLM_PATH = "../models/mistral-7b-instruct-v0.1.Q6_K.gguf"
 TEMPERATURE = 0.1
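With the hardcoded DEVICE = "mps" removed from the configuration module, the device is now taken from the environment in app/main.py (see the diff below). As a purely illustrative sketch, not part of this commit, of how a caller could pick a sensible value for that variable at runtime, assuming PyTorch is available (it is pulled in by the sentence-transformers stack used for the CrossEncoder):

import os

import torch  # assumption: present via the sentence-transformers dependency


def resolve_device() -> str:
    # Hypothetical helper: an explicit DEVICE value wins, then an Apple-silicon
    # GPU ("mps"), then CUDA, and finally CPU as the safe default.
    if os.getenv("DEVICE"):
        return os.environ["DEVICE"]
    if torch.backends.mps.is_available():
        return "mps"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


os.environ.setdefault("DEVICE", resolve_device())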
app/main.py (15 changes: 5 additions & 10 deletions)
@@ -22,8 +22,6 @@
 from ragger_duck.prompt import BasicPromptingStrategy
 from ragger_duck.retrieval import RetrieverReranker
 
-DEFAULT_PORT = 8123
-
 app = FastAPI()
 app.mount("/static", StaticFiles(directory="static"), name="static")
 templates = Jinja2Templates(directory="templates")
@@ -35,6 +33,8 @@
 )
 logging.info(f"Configuration: {config_module}")
 conf = import_module(f"configuration.{config_module}")
+DEVICE = os.getenv("DEVICE", "cpu")
+logging.info(f"Device intended to be used: {DEVICE}")
 
 
 async def send(ws, msg: str, type: str):
@@ -50,7 +50,7 @@ async def startup_event():
     api_lexical_retriever = joblib.load(conf.API_LEXICAL_RETRIEVER_PATH)
     user_guide_semantic_retriever = joblib.load(conf.API_SEMANTIC_RETRIEVER_PATH)
     user_guide_lexical_retriever = joblib.load(conf.API_LEXICAL_RETRIEVER_PATH)
-    cross_encoder = CrossEncoder(model_name=conf.CROSS_ENCODER_PATH, device=conf.DEVICE)
+    cross_encoder = CrossEncoder(model_name=conf.CROSS_ENCODER_PATH, device=DEVICE)
     retriever = RetrieverReranker(
         retrievers=[
             api_semantic_retriever.set_params(top_k=conf.API_SEMANTIC_TOP_K),
@@ -70,7 +70,7 @@ async def startup_event():
 
     llm = Llama(
         model_path=conf.LLM_PATH,
-        device=conf.DEVICE,
+        device=DEVICE,
         n_gpu_layers=conf.GPU_LAYERS,
         n_threads=conf.N_THREADS,
         n_ctx=conf.CONTEXT_TOKENS,
@@ -116,6 +116,7 @@ async def websocket_endpoint(websocket: WebSocket):
 
             prompt = payload["query"]
            start_type = "start"
+            logging.info(f"Getting info from websocket: {payload}")
 
             await send(websocket, "Analyzing prompt...", "info")
             agent.set_params(
@@ -145,9 +146,3 @@ async def websocket_endpoint(websocket: WebSocket):
         except Exception as e:
             logging.error(e)
             await send(websocket, "Sorry, something went wrong. Try again.", "error")
-
-
-if __name__ == "__main__":
-    import uvicorn
-
-    uvicorn.run(app, host="0.0.0.0", port=DEFAULT_PORT)
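The if __name__ == "__main__" block and DEFAULT_PORT are gone, so the server presumably has to be started by an external runner now. A minimal launcher sketch, assuming uvicorn is still installed and that the old port 8123 is still wanted (both assumptions, not part of this commit):

import os

# main.py reads DEVICE at import time, so set it before the app module loads.
os.environ.setdefault("DEVICE", "cpu")  # e.g. "mps" or "cuda" where available

import uvicorn

if __name__ == "__main__":
    # The import-string form lets uvicorn import app/main.py itself, after the
    # environment variable above has been set.
    uvicorn.run("main:app", host="0.0.0.0", port=8123)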