Cache embeddings and completions. #17

Open · wants to merge 3 commits into main
3 changes: 2 additions & 1 deletion athena-app/src/App.js
@@ -26,7 +26,8 @@ function App() {
body: JSON.stringify({ input: userInput }),
});
const data = await response.json();
const athenaMessage = { text: data.response, isAthena: true };
const json = JSON.parse(data.response).choices[0];
const athenaMessage = { text: json.text ? json.text : json.message.content, isAthena: true };
setMessages((prevMessages) => [...prevMessages, athenaMessage]);
setUserInput('');
};
5 changes: 1 addition & 4 deletions athena/api.py
@@ -49,10 +49,7 @@ def main(log_level) -> None:
input_extension = InputPipelineExtension(app)
app.extensions["input_pipeline"] = input_extension
app.register_blueprint(api, url_prefix="/api/v1")
app.run(
host="0.0.0.0", port=5000,
debug=log_level == "DEBUG"
)
app.run(host="0.0.0.0", port=5000, debug=log_level == "DEBUG")


if __name__ == "__main__":
49 changes: 49 additions & 0 deletions athena/api_views/llm.py
@@ -0,0 +1,49 @@
from typing import Any, Dict, Tuple, Union

from flask import current_app, request
from flask_restx import Namespace, Resource, fields
from loguru import logger

from athena.input_processor import process_input

llm_api = Namespace("llm", description="llm related operations")

openai_completion = llm_api.model(
"OpenAICompletion",
{
"response": fields.String(required=True, description="The response"),
"error": fields.String(required=True, description="The error"),
},
)


@llm_api.route("/openai/completion")
class OpenAICompletion(Resource):
@llm_api.doc("openai_completion")
@llm_api.marshal_list_with(openai_completion)
def post(self) -> Union[Dict[str, Any], Tuple[Dict[str, Any], int]]:
"""API endpoint to receive user input and return the response from Athena.

Returns:
A JSON response with the response message or an error message.
"""
logger.debug("Received request to chat with Athena")
data = request.get_json(force=True)
user_input = data.get("input")
username = data.get("username", None)
logger.debug(f"User input: {user_input} Username: {username}")
input_pipeline_extension = current_app.extensions["input_pipeline"]
try:
if user_input:
response = process_input(
input_pipeline_extension.intent_pipeline,
input_pipeline_extension.entity_pipeline,
user_input,
username,
)
return {"response": response, "error": None}
else:
return {"response": None, "error": "Missing input"}, 400
except Exception as e:
logger.exception(f"Error in processing user input: {e}")
return {"response": None, "error": "Error in processing user input"}, 500
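For reference, a minimal client-side sketch of exercising this endpoint, assuming the app from athena/api.py is running locally on port 5000 so the llm namespace resolves to /api/v1/llm/openai/completion; the host, username, and example input below are placeholders:

import requests

# Post user input and unpack the marshalled {"response": ..., "error": ...}
# payload defined by the openai_completion model above.
resp = requests.post(
    "http://localhost:5000/api/v1/llm/openai/completion",
    json={"input": "What is the weather in Athens?", "username": "demo"},
)
payload = resp.json()
if resp.ok and payload.get("error") is None:
    print(payload["response"])
else:
    print(f"Request failed ({resp.status_code}): {payload.get('error')}")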
1 change: 0 additions & 1 deletion athena/celery.py
@@ -1,7 +1,6 @@
import os

from celery import Celery

from dotenv import load_dotenv

load_dotenv()
Empty file added athena/crud/__init__.py
16 changes: 16 additions & 0 deletions athena/crud/response.py
@@ -0,0 +1,16 @@
from __future__ import annotations

from sqlalchemy.orm import Session

from athena.models.api import Response


def get_response(session: Session, response_id: int) -> Response | None:
return session.query(Response).filter(Response.id == response_id).first()


def create_response(session: Session, response: Response) -> Response:
session.add(response)
session.commit()
session.refresh(response)
return response
15 changes: 15 additions & 0 deletions athena/db.py
@@ -1,4 +1,5 @@
import os
from contextlib import contextmanager

from sqlalchemy import MetaData, create_engine
from sqlalchemy.ext.declarative import declarative_base
@@ -28,3 +29,17 @@ def get_db():
yield db
finally:
db.close()


@contextmanager
def session_scope():
"""Provide a transactional scope around a series of operations."""
session = SessionLocal()
try:
yield session
session.commit()
except:
session.rollback()
raise
finally:
session.close()
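A hypothetical sketch of how session_scope could be combined with the new CRUD helpers in athena/crud/response.py; the Response model's columns are not shown in this diff, so the constructor arguments below are assumed purely for illustration:

from athena.crud.response import create_response, get_response
from athena.db import session_scope
from athena.models.api import Response

with session_scope() as session:
    cached = get_response(session, response_id=42)
    if cached is None:
        # Assumed columns (prompt, text); the real model may differ.
        cached = create_response(session, Response(prompt="Hello", text="Hi there!"))
    print(cached.id)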
58 changes: 46 additions & 12 deletions athena/input_processor.py
@@ -1,11 +1,14 @@
import json

from loguru import logger

import plugins
from athena.llm.fast_chat.chat_completion import fastchat_chat_completion
from athena.llm.openai.completion import openai_completion
from athena.plugins.authentication_plugin import AuthenticationPlugin
from athena.plugins.plugin_base import PluginBase
from athena.plugins.plugin_manager import PluginManager
from athena.prompt import SYSTEM_PROMPT
from athena.prompt import ROLE_ASSISTANT_PROMPT, ROLE_SYSTEM_PROMPT, SYSTEM_PROMPT
from athena.user_manager import UserManager

user_manager = UserManager()
@@ -20,7 +23,7 @@ def process_input(
completion_callback=None,
):
if not completion_callback:
completion_callback = openai_completion
completion_callback = fastchat_chat_completion
logger.info(f"User input received: {user_input}")
if user_input == "":
return "I'm sorry, I didn't receive any input. Can you please try again?"
@@ -51,33 +54,64 @@ def process_input(
logger.debug("No personalization data found.")

response = plugin_manager.process_input(user_input)
prompt = f"{SYSTEM_PROMPT}{user_input}"
if completion_callback == openai_completion:
prompt = f"{SYSTEM_PROMPT}{user_input}"
else:
prompt = [
{"role": "system", "content": ROLE_SYSTEM_PROMPT},
{"role": "assistant", "content": ROLE_ASSISTANT_PROMPT},
{"role": "user", "content": user_input},
]
if response is None and intent_confidence < 0.65:
logger.debug(
"No plugin was able to process the input and the intent confidence is low. Using GPT-3 to generate a response."
)
response = completion_callback(prompt)
if completion_callback == fastchat_chat_completion:
response = completion_callback(
"fastchat-t5-3b-v1.0", prompt, temperature=0.8, max_tokens=512
)
else:
response = completion_callback(prompt)

if response is None:
logger.debug("No plugin was able to process the input. Using default logic.")
if intent == "greeting":
response = "Hello! How can I help you today?"
response = json.dumps(
{"choices": [{"text": "Hello! How can I help you today?"}]}
)
elif intent == "goodbye":
response = "Goodbye! Have a great day!"
response = json.dumps({"choices": [{"text": "Goodbye! Have a great day!"}]})
elif intent == "current_state":
response = "I've been busy!"
response = json.dumps({"choices": [{"text": "I've been busy!"}]})
elif intent == "name":
response = "My name is Athena!"
response = json.dumps({"choices": [{"text": "My name is Athena!"}]})
elif intent == "weather":
location = "unknown"
if entities:
for entity in entities:
if entity[0] == "GPE":
location = entities[0][-1]
response = f"I'm not currently able to check the weather, but you asked about {location}."
response = json.dumps(
{
"choices": [
{
"text": f"I'm not currently able to check the weather, but you asked about {location}."
}
]
}
)
else:
logger.debug("No intent was detected. Using GPT-3 to generate a response.")

response = completion_callback(prompt)
if completion_callback == fastchat_chat_completion:
logger.debug(
"No intent was detected. Using fast-chat to generate a response."
)
response = completion_callback(
"fastchat-t5-3b-v1.0", prompt, temperature=0.8, max_tokens=512
)
else:
logger.debug(
"No intent was detected. Using GPT-3 to generate a response."
)
response = completion_callback(prompt)
logger.info(f"Response generated: {response}")
return response
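With this change the built-in fallback replies are returned as raw OpenAI-style JSON strings rather than plain text, which is what the updated App.js parses on the client. As a minimal sketch, assuming the value handed back is one of those JSON strings, a consumer could recover the display text like this:

import json

def extract_text(raw_response: str) -> str:
    # Completion-style payloads carry choices[0]["text"]; chat-style payloads
    # carry choices[0]["message"]["content"], mirroring the App.js change.
    choice = json.loads(raw_response)["choices"][0]
    return choice.get("text") or choice["message"]["content"]

print(extract_text('{"choices": [{"text": "Hello! How can I help you today?"}]}'))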
Empty file.
19 changes: 19 additions & 0 deletions athena/llm/fast_chat/chat_completion.py
@@ -0,0 +1,19 @@
import requests


def fastchat_chat_completion(
model,
messages,
temperature=0.8,
max_tokens=512,
):
resp = requests.post(
"http://fastchat-api-server:8000/v1/chat/completions",
json={
"model": model,
"messages": messages,
"temperature": temperature,
"max_tokens": max_tokens,
},
)
return resp.text
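A usage sketch, assuming the fastchat-api-server host in the URL above is reachable; the message contents are placeholders, while the model name and parameter values mirror how input_processor.py calls this helper:

from athena.llm.fast_chat.chat_completion import fastchat_chat_completion

messages = [
    {"role": "system", "content": "You are Athena, a helpful assistant."},
    {"role": "user", "content": "Say hello."},
]
# Returns the raw JSON string from the FastChat OpenAI-compatible API.
raw = fastchat_chat_completion(
    "fastchat-t5-3b-v1.0", messages, temperature=0.8, max_tokens=512
)
print(raw)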
55 changes: 55 additions & 0 deletions athena/llm/openai/chat_completion.py
@@ -1,7 +1,62 @@
import os

import openai
from loguru import logger
from tenacity import retry, stop_after_attempt, wait_random_exponential


@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def chat_completion_with_backoff(**kwargs):
return openai.ChatCompletion.create(**kwargs)


def openai_chat_completion(
prompt: str,
model: str = "gpt-3.5-turbo-0301",
max_tokens: int = 512,
temperature: float = 0.8,
stop: str = "\nHuman:",
n: int = 1,
best_of: int = 1,
) -> str:
"""Generates OpenAI Completion using the provided parameters.

Args:
prompt (str): The prompt for generating the completion.
model (str, optional): The name of the model to use for generating the completion. Defaults to "gpt-3.5-turbo-0301".
max_tokens (int, optional): The maximum number of tokens to generate in the completion. Defaults to 512.
temperature (float, optional): The temperature to use for generating the completion. Defaults to 0.8.
stop (str, optional): The sequence where the model should stop generating further tokens. Defaults to "\nHuman:".
n (int, optional): The number of completions to generate. Defaults to 1.
best_of (int, optional): The number of best completions to return. Defaults to 1.

Returns:
str: The generated completion.
"""

logger.debug(f"Generating OpenAI Completion using prompt: {prompt}")
logger.debug(
f"Model: {model}, max_tokens: {max_tokens}, "
f"temperature: {temperature}, stop: {stop}, n: {n}"
)

if not os.environ.get("OPENAI_API_KEY"):
raise ValueError("OpenAI API key is not available in the environment variable.")

openai.api_key = os.environ["OPENAI_API_KEY"]
try:
response = chat_completion_with_backoff(
model=model,
messages=[{"role": "user", "content": prompt}],
max_tokens=max_tokens,
n=n,
stop=stop,
temperature=temperature,
)
result = response.choices[0].message.content.strip()
logger.debug(f"OpenAI Completion: {result}")
return result
except Exception as e:
logger.exception(f"Error in generating OpenAI Completion: {e}")
return "Error"
21 changes: 15 additions & 6 deletions athena/llm/openai/completion.py
@@ -4,6 +4,12 @@
from loguru import logger
from tenacity import retry, stop_after_attempt, wait_random_exponential

from athena.llm.openai.embedding import get_openai_embedding


def openai_completion_cache_key(**kwargs):
pass


@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def completion_with_backoff(**kwargs):
@@ -18,7 +24,7 @@ def openai_completion(
stop: str = "\nHuman:",
n: int = 1,
best_of: int = 1,
) -> str:
):
"""Generates OpenAI Completion using the provided parameters.

Args:
@@ -31,7 +37,7 @@
best_of (int, optional): The number of best completions to return. Defaults to 1.

Returns:
str: The generated completion.
response: The generated completion.
"""

logger.debug(f"Generating OpenAI Completion using prompt: {prompt}")
@@ -44,6 +50,10 @@
raise ValueError("OpenAI API key is not available in the environment variable.")

openai.api_key = os.environ["OPENAI_API_KEY"]
split_prompt = prompt.split("Human:")
embeddings = get_openai_embedding(split_prompt[-1])
logger.debug(embeddings)

try:
response = completion_with_backoff(
engine=model,
@@ -54,9 +64,8 @@
best_of=best_of,
temperature=temperature,
)
result = response.choices[0].text.strip()
logger.debug(f"OpenAI Completion: {result}")
return result
logger.debug(f"OpenAI Completion: {response}")
return response
except Exception as e:
logger.exception(f"Error in generating OpenAI Completion: {e}")
return "Error"
raise e from e
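The openai_completion_cache_key helper above is left as a stub in this diff. Purely as a hypothetical illustration of the kind of deterministic key the PR title points toward, and not what the PR implements, one option is to hash the canonicalized call arguments:

import hashlib
import json

def example_completion_cache_key(**kwargs) -> str:
    # Hash a canonical JSON form of the kwargs so identical calls map to one key.
    canonical = json.dumps(kwargs, sort_keys=True, default=str)
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()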