From de8efea752d6f6fa075792ac07ece89187544ef0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mence=20Lesn=C3=A9?= Date: Fri, 7 Jul 2023 20:37:06 +0200 Subject: [PATCH 01/16] Feat: Add usage metric Refacto: Extract Azure Content Safety and Azure OpenAI from "main.py" file --- README.md | 2 +- src/conversation-api/ai/contentsafety.py | 79 ++++++ src/conversation-api/ai/openai.py | 132 ++++++++++ src/conversation-api/main.py | 272 +++++++------------- src/conversation-api/models/conversation.py | 2 +- src/conversation-api/models/message.py | 2 +- src/conversation-api/models/usage.py | 13 + src/conversation-api/persistence/cosmos.py | 5 + src/conversation-api/persistence/isearch.py | 4 +- src/conversation-api/persistence/istore.py | 5 + src/conversation-api/persistence/qdrant.py | 48 +--- src/conversation-api/persistence/redis.py | 10 +- 12 files changed, 342 insertions(+), 232 deletions(-) create mode 100644 src/conversation-api/ai/contentsafety.py create mode 100644 src/conversation-api/ai/openai.py create mode 100644 src/conversation-api/models/usage.py diff --git a/README.md b/README.md index 23706533..e943fdf3 100644 --- a/README.md +++ b/README.md @@ -65,7 +65,7 @@ db = 0 host = "localhost" [cosmos] -# Containers "conversation" (/user_id), "message" (/conversation_id) and "user" (/dummy) must exist +# Containers "conversation" (/user_id), "message" (/conversation_id), "user" (/dummy), "usage" (/user_id) must exist url = "https://private-gpt.documents.azure.com:443" database = "private-gpt" ``` diff --git a/src/conversation-api/ai/contentsafety.py b/src/conversation-api/ai/contentsafety.py new file mode 100644 index 00000000..ffbde09b --- /dev/null +++ b/src/conversation-api/ai/contentsafety.py @@ -0,0 +1,79 @@ +# Import utils +from utils import (build_logger, get_config) + +# Import misc +from azure.core.credentials import AzureKeyCredential +from fastapi import HTTPException, status +from tenacity import retry, stop_after_attempt, wait_random_exponential 
+import azure.ai.contentsafety as azure_cs
+import azure.core.exceptions as azure_exceptions
+
+
+###
+# Init misc
+###
+
+logger = build_logger(__name__)
+
+###
+# Init Azure Content Safety
+###
+
+# Scores are as follows: 0 - Safe, 2 - Low, 4 - Medium, 6 - High
+# See: https://review.learn.microsoft.com/en-us/azure/cognitive-services/content-safety/concepts/harm-categories?branch=release-build-content-safety#severity-levels
+ACS_SEVERITY_THRESHOLD = 2
+ACS_API_BASE = get_config("acs", "api_base", str, required=True)
+ACS_API_TOKEN = get_config("acs", "api_token", str, required=True)
+ACS_MAX_LENGTH = get_config("acs", "max_length", int, required=True)
+logger.info(f"Connected Azure Content Safety to {ACS_API_BASE}")
+acs_client = azure_cs.ContentSafetyClient(
+    ACS_API_BASE, AzureKeyCredential(ACS_API_TOKEN)
+)
+
+
+class ContentSafety:
+    @retry(
+        reraise=True,
+        stop=stop_after_attempt(3),
+        wait=wait_random_exponential(multiplier=0.5, max=30),
+    )
+    async def is_moderated(self, prompt: str) -> bool:
+        logger.debug(f"Checking moderation for text: {prompt}")
+
+        if len(prompt) > ACS_MAX_LENGTH:
+            logger.info(f"Message ({len(prompt)}) too long for moderation")
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail="Message too long",
+            )
+
+        req = azure_cs.models.AnalyzeTextOptions(
+            text=prompt,
+            categories=[
+                azure_cs.models.TextCategory.HATE,
+                azure_cs.models.TextCategory.SELF_HARM,
+                azure_cs.models.TextCategory.SEXUAL,
+                azure_cs.models.TextCategory.VIOLENCE,
+            ],
+        )
+
+        try:
+            res = acs_client.analyze_text(req)
+        except azure_exceptions.ClientAuthenticationError as e:
+            logger.exception(e)
+            return False
+
+        is_moderated = any(
+            cat.severity >= ACS_SEVERITY_THRESHOLD
+            for cat in [
+                res.hate_result,
+                res.self_harm_result,
+                res.sexual_result,
+                res.violence_result,
+            ]
+        )
+        if is_moderated:
+            logger.info(f"Message is moderated: {prompt}")
+            logger.debug(f"Moderation result: {res}")
+
+        return is_moderated
diff --git 
a/src/conversation-api/ai/openai.py b/src/conversation-api/ai/openai.py
new file mode 100644
index 00000000..4434267a
--- /dev/null
+++ b/src/conversation-api/ai/openai.py
@@ -0,0 +1,132 @@
+# Import utils
+from uuid import UUID
+from utils import (build_logger, get_config, hash_token)
+
+# Import misc
+from azure.identity import DefaultAzureCredential
+from models.user import UserModel
+from tenacity import retry, stop_after_attempt, wait_random_exponential
+from typing import Any, Dict, List, AsyncGenerator, Union
+import asyncio
+import openai
+
+
+###
+# Init misc
+###
+
+logger = build_logger(__name__)
+
+###
+# Init OpenIA
+###
+
+async def refresh_oai_token():
+    """
+    Refresh OpenAI token every 15 minutes.
+
+    The OpenAI SDK does not support token refresh, so we need to do it manually. We manually pass the token to the SDK. Azure AD tokens are valid for 30 mins, but we refresh every 15 minutes to be safe.
+
+    See: https://github.com/openai/openai-python/pull/350#issuecomment-1489813285
+    """
+    while True:
+        logger.info("Refreshing OpenAI token")
+        oai_cred = DefaultAzureCredential()
+        oai_token = oai_cred.get_token("https://cognitiveservices.azure.com/.default")
+        openai.api_key = oai_token.token
+        # Execute every 15 minutes
+        await asyncio.sleep(15 * 60)
+
+
+openai.api_base = get_config("openai", "api_base", str, required=True)
+openai.api_type = "azure_ad"
+openai.api_version = "2023-05-15"
+logger.info(f"Using Aure private service ({openai.api_base})")
+asyncio.create_task(refresh_oai_token())
+
+OAI_GPT_DEPLOY_ID = get_config("openai", "gpt_deploy_id", str, required=True)
+OAI_GPT_MAX_TOKENS = get_config("openai", "gpt_max_tokens", int, required=True)
+OAI_GPT_MODEL = get_config(
+    "openai", "gpt_model", str, default="gpt-3.5-turbo", required=True
+)
+logger.info(
+    f'Using OpenAI GPT model "{OAI_GPT_MODEL}" ({OAI_GPT_DEPLOY_ID}) with {OAI_GPT_MAX_TOKENS} tokens max'
+)
+
+OAI_ADA_DEPLOY_ID = get_config("openai", "ada_deploy_id", str, required=True)
+OAI_ADA_MAX_TOKENS = get_config("openai", "ada_max_tokens", int, required=True) +OAI_ADA_MODEL = get_config( + "openai", "ada_model", str, default="text-embedding-ada-002", required=True +) +logger.info( + f'Using OpenAI ADA model "{OAI_ADA_MODEL}" ({OAI_ADA_DEPLOY_ID}) with {OAI_ADA_MAX_TOKENS} tokens max' +) + + +class OpenAI: + @retry( + reraise=True, + stop=stop_after_attempt(3), + wait=wait_random_exponential(multiplier=0.5, max=30), + ) + async def vector_from_text(self, prompt: str, user_id: UUID) -> List[float]: + logger.debug(f"Getting vector for text: {prompt}") + try: + res = openai.Embedding.create( + deployment_id=OAI_ADA_DEPLOY_ID, + input=prompt, + model=OAI_ADA_MODEL, + user=user_id.hex, + ) + except openai.error.AuthenticationError as e: + logger.exception(e) + return [] + + return res.data[0].embedding + + @retry( + reraise=True, + stop=stop_after_attempt(3), + wait=wait_random_exponential(multiplier=0.5, max=30), + ) + async def completion(self, messages: List[Dict[str, str]], current_user: UserModel) -> Union[str, None]: + try: + # Use chat completion to get a more natural response and lower the usage cost + completion = openai.ChatCompletion.create( + deployment_id=OAI_GPT_DEPLOY_ID, + messages=messages, + model=OAI_GPT_MODEL, + presence_penalty=1, # Increase the model's likelihood to talk about new topics + user=hash_token(current_user.id.bytes).hex, + ) + content = completion["choices"][0].message.content + except openai.error.AuthenticationError as e: + logger.exception(e) + return + + return content + + @retry( + reraise=True, + stop=stop_after_attempt(3), + wait=wait_random_exponential(multiplier=0.5, max=30), + ) + async def completion_stream(self, messages: List[Dict[str, str]], current_user: UserModel) -> AsyncGenerator[Any, None]: + try: + # Use chat completion to get a more natural response and lower the usage cost + chunks = openai.ChatCompletion.create( + deployment_id=OAI_GPT_DEPLOY_ID, + messages=messages, + model=OAI_GPT_MODEL, 
+ presence_penalty=1, # Increase the model's likelihood to talk about new topics + stream=True, + user=hash_token(current_user.id.bytes).hex, + ) + except openai.error.AuthenticationError as e: + logger.exception(e) + return + + for chunk in chunks: + content = chunk["choices"][0].get("delta", {}).get("content") + if content is not None: + yield content diff --git a/src/conversation-api/main.py b/src/conversation-api/main.py index 48ff6bcc..f286ff05 100644 --- a/src/conversation-api/main.py +++ b/src/conversation-api/main.py @@ -9,8 +9,8 @@ ) # Import misc -from azure.core.credentials import AzureKeyCredential -from azure.identity import DefaultAzureCredential +from ai.contentsafety import ContentSafety +from ai.openai import OpenAI, OAI_GPT_MODEL, OAI_GPT_MAX_TOKENS, OAI_ADA_MODEL, OAI_ADA_MAX_TOKENS from datetime import datetime from fastapi import FastAPI, HTTPException, status, Request, Depends from fastapi.middleware.cors import CORSMiddleware @@ -19,20 +19,17 @@ from models.message import MessageModel, MessageRole, StoredMessageModel from models.prompt import StoredPromptModel, ListPromptsModel from models.search import SearchModel +from models.usage import UsageModel from models.user import UserModel from persistence.isearch import SearchImplementation from persistence.istore import StoreImplementation from persistence.istream import StreamImplementation from sse_starlette.sse import EventSourceResponse -from tenacity import retry, stop_after_attempt, wait_random_exponential from typing import Annotated, Dict, List, Optional from uuid import UUID from uuid import uuid4 import asyncio -import azure.ai.contentsafety as azure_cs -import azure.core.exceptions as azure_exceptions import csv -import openai ### @@ -44,6 +41,7 @@ ### # Init persistence ### + store_impl = get_config("persistence", "store", StoreImplementation, required=True) if store_impl == StoreImplementation.COSMOS: logger.info("Using CosmosDB store") @@ -72,58 +70,6 @@ else: raise 
ValueError(f"Unknown stream implementation: {stream_impl}") -### -# Init OpenAI -### - - -async def refresh_oai_token(): - """ - Refresh OpenAI token every 15 minutes. - - The OpenAI SDK does not support token refresh, so we need to do it manually. We passe manually the token to the SDK. Azure AD tokens are valid for 30 mins, but we refresh every 15 minutes to be safe. - - See: https://github.com/openai/openai-python/pull/350#issuecomment-1489813285 - """ - while True: - logger.info("Refreshing OpenAI token") - oai_cred = DefaultAzureCredential() - oai_token = oai_cred.get_token("https://cognitiveservices.azure.com/.default") - openai.api_key = oai_token.token - # Execute every 20 minutes - await asyncio.sleep(15 * 60) - - -OAI_GPT_DEPLOY_ID = get_config("openai", "gpt_deploy_id", str, required=True) -OAI_GPT_MAX_TOKENS = get_config("openai", "gpt_max_tokens", int, required=True) -OAI_GPT_MODEL = get_config( - "openai", "gpt_model", str, default="gpt-3.5-turbo", required=True -) -logger.info( - f'Using OpenAI ADA model "{OAI_GPT_MODEL}" ({OAI_GPT_DEPLOY_ID}) with {OAI_GPT_MAX_TOKENS} tokens max' -) - -openai.api_base = get_config("openai", "api_base", str, required=True) -openai.api_type = "azure_ad" -openai.api_version = "2023-05-15" -logger.info(f"Using Aure private service ({openai.api_base})") -asyncio.create_task(refresh_oai_token()) - -### -# Init Azure Content Safety -### - -# Score are following: 0 - Safe, 2 - Low, 4 - Medium, 6 - High -# See: https://review.learn.microsoft.com/en-us/azure/cognitive-services/content-safety/concepts/harm-categories?branch=release-build-content-safety#severity-levels -ACS_SEVERITY_THRESHOLD = 2 -ACS_API_BASE = get_config("acs", "api_base", str, required=True) -ACS_API_TOKEN = get_config("acs", "api_token", str, required=True) -ACS_MAX_LENGTH = get_config("acs", "max_length", int, required=True) -logger.info(f"Connected Azure Content Safety to {ACS_API_BASE}") -acs_client = azure_cs.ContentSafetyClient( - ACS_API_BASE, 
AzureKeyCredential(ACS_API_TOKEN) -) - ### # Init FastAPI ### @@ -159,6 +105,8 @@ async def refresh_oai_token(): # Init Generative AI ### +openai = OpenAI() +content_safety = ContentSafety() def get_ai_prompt() -> Dict[UUID, StoredPromptModel]: prompts = {} @@ -312,7 +260,7 @@ async def message_post( conversation_id: Optional[UUID] = None, prompt_id: Optional[UUID] = None, ) -> GetConversationModel: - if await is_moderated(content): + if await content_safety.is_moderated(content): logger.info(f"Message content is moderated: {content}") raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -348,24 +296,22 @@ async def message_post( token=uuid4(), ) - # Validate message length - tokens_nb = oai_tokens_nb( - message.content - + "".join([m.content for m in store.message_list(message.conversation_id)]), - OAI_GPT_MODEL, - ) + tokens_nb = await _validate_message_length(message=message) - logger.debug(f"{tokens_nb} tokens in the conversation") - if tokens_nb > OAI_GPT_MAX_TOKENS: - logger.info(f"Message ({tokens_nb}) too long for conversation") - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Conversation history is too long", - ) + # Build usage + usage = UsageModel( + ai_model=OAI_GPT_MODEL, + conversation_id=conversation_id, + created_at=datetime.now(), + id=uuid4(), + tokens=tokens_nb, + user_id=current_user.id, + ) + store.usage_set(usage) # Update conversation store.message_set(message) - index.message_index(message, current_user.id) + await _message_index(message, current_user) conversation = store.conversation_get(conversation_id, current_user.id) if not conversation: logger.warn("ACID error: conversation not found after testing existence") @@ -381,15 +327,7 @@ async def message_post( detail="Prompt ID not found", ) - # Validate message length - tokens_nb = oai_tokens_nb(content, OAI_GPT_MODEL) - logger.debug(f"{tokens_nb} tokens in the conversation") - if tokens_nb > OAI_GPT_MAX_TOKENS: - logger.info(f"Message ({tokens_nb}) 
too long for conversation") - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Conversation history is too long", - ) + tokens_nb = await _validate_message_length(content=content) # Build conversation conversation = StoredConversationModel( @@ -400,6 +338,17 @@ async def message_post( ) store.conversation_set(conversation) + # Build usage + usage = UsageModel( + ai_model=OAI_GPT_MODEL, + conversation_id=conversation.id, + created_at=datetime.now(), + id=uuid4(), + tokens=tokens_nb, + user_id=current_user.id, + ) + store.usage_set(usage) + # Build message message = StoredMessageModel( content=content, @@ -411,19 +360,15 @@ async def message_post( token=uuid4(), ) store.message_set(message) - index.message_index(message, current_user.id) + await _message_index(message, current_user) messages = store.message_list(conversation.id) if conversation.title is None: - asyncio.get_running_loop().run_in_executor( - None, lambda: guess_title(conversation, messages, current_user) - ) + asyncio.create_task(_guess_title(conversation, messages, current_user)) # Execute the message completion - asyncio.get_running_loop().run_in_executor( - None, lambda: completion_from_conversation(conversation, messages, current_user) - ) + asyncio.create_task(_completion_from_conversation(conversation, messages, current_user)) return GetConversationModel( **conversation.dict(), @@ -433,10 +378,10 @@ async def message_post( @api.get("/message/{id}") async def message_get(id: UUID, token: UUID, req: Request) -> EventSourceResponse: - return EventSourceResponse(read_message_sse(req, token)) + return EventSourceResponse(_read_message_sse(req, token)) -async def read_message_sse(req: Request, message_id: UUID): +async def _read_message_sse(req: Request, message_id: UUID): def clean(): logger.info(f"Cleared message cache (message_id={message_id})") stream.clean(message_id) @@ -463,15 +408,10 @@ async def loop_func() -> bool: async def message_search( q: str, current_user: 
Annotated[UserModel, Depends(get_current_user)] ) -> SearchModel: - return index.message_search(q, current_user.id) + return await index.message_search(q, current_user.id) -@retry( - reraise=True, - stop=stop_after_attempt(3), - wait=wait_random_exponential(multiplier=0.5, max=30), -) -def completion_from_conversation( +async def _completion_from_conversation( conversation: StoredConversationModel, messages: List[MessageModel], current_user: UserModel, @@ -496,28 +436,12 @@ def completion_from_conversation( logger.debug(f"Completion messages: {completion_messages}") - try: - # Use chat completion to get a more natural response and lower the usage cost - chunks = openai.ChatCompletion.create( - deployment_id=OAI_GPT_DEPLOY_ID, - messages=completion_messages, - model=OAI_GPT_MODEL, - presence_penalty=1, # Increase the model's likelihood to talk about new topics - stream=True, - user=hash_token(current_user.id.bytes).hex, - ) - except openai.error.AuthenticationError as e: - logger.exception(e) - return - content_full = "" - for chunk in chunks: - content = chunk["choices"][0].get("delta", {}).get("content") - if content is not None: - logger.debug(f"Completion result: {content}") - # Add content to the redis stream cache_key - stream.push(content, last_message.token) - content_full += content + async for content in openai.completion_stream(completion_messages, current_user): + logger.debug(f"Completion result: {content}") + # Add content to the redis stream cache_key + stream.push(content, last_message.token) + content_full += content # First, store the updated conversation in Redis res_message = StoredMessageModel( @@ -529,18 +453,53 @@ def completion_from_conversation( secret=last_message.secret, ) store.message_set(res_message) - index.message_index(res_message, current_user.id) + await _message_index(res_message, current_user) # Then, send the end of stream message stream.push(STREAM_STOPWORD, last_message.token) -@retry( - reraise=True, - 
stop=stop_after_attempt(3), - wait=wait_random_exponential(multiplier=0.5, max=30), -) -def guess_title( +async def _message_index(message: StoredMessageModel, current_user: UserModel) -> None: + usage = UsageModel( + ai_model=OAI_ADA_MODEL, + conversation_id=message.conversation_id, + created_at=datetime.now(), + id=uuid4(), + tokens=oai_tokens_nb(message.content, OAI_ADA_MODEL), + user_id=current_user.id, + ) + store.usage_set(usage) + await index.message_index(message, current_user.id) + + +async def _validate_message_length( + message: Optional[StoredMessageModel] = None, + content: Optional[str] = None, +) -> int: + if content: + tokens_nb = oai_tokens_nb(content, OAI_GPT_MODEL) + elif message: + tokens_nb = oai_tokens_nb( + message.content + + "".join([m.content for m in store.message_list(message.conversation_id)]), + OAI_GPT_MODEL, + ) + else: + raise ValueError('Either message or content must be provided to "validate_usage"') + + logger.debug(f"{tokens_nb} tokens in the conversation") + + if tokens_nb > OAI_GPT_MAX_TOKENS: + logger.info(f"Message ({tokens_nb}) too long for conversation") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Conversation history is too long", + ) + + return tokens_nb + + +async def _guess_title( conversation: StoredConversationModel, messages: List[MessageModel], current_user: UserModel, @@ -556,67 +515,8 @@ def guess_title( logger.debug(f"Completion messages: {completion_messages}") - try: - # Use chat completion to get a more natural response and lower the usage cost - completion = openai.ChatCompletion.create( - deployment_id=OAI_GPT_DEPLOY_ID, - messages=completion_messages, - model=OAI_GPT_MODEL, - presence_penalty=1, # Increase the model's likelihood to talk about new topics - user=hash_token(current_user.id.bytes).hex, - ) - content = completion["choices"][0].message.content - except openai.error.AuthenticationError as e: - logger.exception(e) - return + content = await 
openai.completion(completion_messages, current_user) # Store the updated conversation in Redis conversation.title = content store.conversation_set(conversation) - - -@retry( - reraise=True, - stop=stop_after_attempt(3), - wait=wait_random_exponential(multiplier=0.5, max=30), -) -async def is_moderated(prompt: str) -> bool: - logger.debug(f"Checking moderation for text: {prompt}") - - if len(prompt) > ACS_MAX_LENGTH: - logger.info(f"Message ({len(prompt)}) too long for moderation") - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Message too long", - ) - - req = azure_cs.models.AnalyzeTextOptions( - text=prompt, - categories=[ - azure_cs.models.TextCategory.HATE, - azure_cs.models.TextCategory.SELF_HARM, - azure_cs.models.TextCategory.SEXUAL, - azure_cs.models.TextCategory.VIOLENCE, - ], - ) - - try: - res = acs_client.analyze_text(req) - except azure_exceptions.ClientAuthenticationError as e: - logger.exception(e) - return False - - is_moderated = any( - cat.severity >= ACS_SEVERITY_THRESHOLD - for cat in [ - res.hate_result, - res.self_harm_result, - res.sexual_result, - res.violence_result, - ] - ) - if is_moderated: - logger.info(f"Message is moderated: {prompt}") - logger.debug(f"Moderation result: {res}") - - return is_moderated diff --git a/src/conversation-api/models/conversation.py b/src/conversation-api/models/conversation.py index 6ea5ae8e..09f3db4e 100644 --- a/src/conversation-api/models/conversation.py +++ b/src/conversation-api/models/conversation.py @@ -10,7 +10,7 @@ class BaseConversationModel(BaseModel): created_at: datetime id: UUID title: Optional[str] = None - user_id: UUID + user_id: UUID # Partition key class StoredConversationModel(BaseConversationModel): diff --git a/src/conversation-api/models/message.py b/src/conversation-api/models/message.py index ebb73411..d3b123fc 100644 --- a/src/conversation-api/models/message.py +++ b/src/conversation-api/models/message.py @@ -21,7 +21,7 @@ class MessageModel(BaseModel): 
class StoredMessageModel(MessageModel): - conversation_id: UUID + conversation_id: UUID # Partition key class IndexMessageModel(BaseModel): diff --git a/src/conversation-api/models/usage.py b/src/conversation-api/models/usage.py new file mode 100644 index 00000000..94483266 --- /dev/null +++ b/src/conversation-api/models/usage.py @@ -0,0 +1,13 @@ +from datetime import datetime +from pydantic import BaseModel +from uuid import UUID + + +class UsageModel(BaseModel): + ai_model: str + conversation_id: UUID + created_at: datetime + id: UUID + tokens: int + user_id: UUID # Partition key + diff --git a/src/conversation-api/persistence/cosmos.py b/src/conversation-api/persistence/cosmos.py index a1eca0a9..8438d7d6 100644 --- a/src/conversation-api/persistence/cosmos.py +++ b/src/conversation-api/persistence/cosmos.py @@ -11,6 +11,7 @@ from models.conversation import StoredConversationModel, StoredConversationModel from models.message import MessageModel, IndexMessageModel, StoredMessageModel from models.user import UserModel +from models.usage import UsageModel from typing import (Any, Dict, List, Union) from uuid import UUID @@ -30,6 +31,7 @@ conversation_client = database.get_container_client("conversation") message_client = database.get_container_client("message") user_client = database.get_container_client("user") +usage_client = database.get_container_client("usage") logger.info(f'Connected to Cosmos DB at "{DB_URL}"') @@ -106,6 +108,9 @@ def message_list(self, conversation_id: UUID) -> List[MessageModel]: items = message_client.query_items(query=query, enable_cross_partition_query=True) return [MessageModel(**item) for item in items] + def usage_set(self, usage: UsageModel) -> None: + usage_client.upsert_item(body=self._sanitize_before_insert(usage.dict())) + def _sanitize_before_insert(self, item: dict) -> Dict[str, Union[str, int, float, bool]]: for key, value in item.items(): if isinstance(value, UUID): diff --git a/src/conversation-api/persistence/isearch.py 
b/src/conversation-api/persistence/isearch.py index f399ae04..fdb69111 100644 --- a/src/conversation-api/persistence/isearch.py +++ b/src/conversation-api/persistence/isearch.py @@ -15,9 +15,9 @@ def __init__(self, store: IStore): self.store = store @abstractmethod - def message_search(self, query: str, user_id: UUID) -> SearchModel[MessageModel]: + async def message_search(self, query: str, user_id: UUID) -> SearchModel[MessageModel]: pass @abstractmethod - def message_index(self, message: StoredMessageModel, user_id: UUID) -> None: + async def message_index(self, message: StoredMessageModel, user_id: UUID) -> None: pass diff --git a/src/conversation-api/persistence/istore.py b/src/conversation-api/persistence/istore.py index 0433db74..5bd4a76b 100644 --- a/src/conversation-api/persistence/istore.py +++ b/src/conversation-api/persistence/istore.py @@ -3,6 +3,7 @@ from models.conversation import GetConversationModel, StoredConversationModel from models.message import MessageModel, IndexMessageModel, StoredMessageModel from models.user import UserModel +from models.usage import UsageModel from typing import List, Union from uuid import UUID @@ -58,3 +59,7 @@ def message_set(self, message: StoredMessageModel) -> None: @abstractmethod def message_list(self, conversation_id: UUID) -> List[MessageModel]: pass + + @abstractmethod + def usage_set(self, usage: UsageModel) -> None: + pass diff --git a/src/conversation-api/persistence/qdrant.py b/src/conversation-api/persistence/qdrant.py index a70c768c..0a10a0d9 100644 --- a/src/conversation-api/persistence/qdrant.py +++ b/src/conversation-api/persistence/qdrant.py @@ -4,21 +4,20 @@ # Import misc from .isearch import ISearch from .istore import IStore +from ai.openai import OpenAI from datetime import datetime from models.message import MessageModel, IndexMessageModel, StoredMessageModel from models.search import SearchModel, SearchStatsModel, SearchAnswerModel from qdrant_client import QdrantClient -from tenacity import 
retry, stop_after_attempt, wait_random_exponential -from typing import List from uuid import UUID import asyncio -import openai import qdrant_client.http.models as qmodels import textwrap import time logger = build_logger(__name__) +openai = OpenAI() QD_COLLECTION = "messages" QD_DIMENSION = 1536 QD_HOST = get_config("qd", "host", str, required=True) @@ -27,15 +26,6 @@ client = QdrantClient(host=QD_HOST, port=6333) logger.info(f'Connected to Qdrant at "{QD_HOST}:{QD_PORT}"') -OAI_ADA_DEPLOY_ID = get_config("openai", "ada_deploy_id", str, required=True) -OAI_ADA_MAX_TOKENS = get_config("openai", "ada_max_tokens", int, required=True) -OAI_ADA_MODEL = get_config( - "openai", "ada_model", str, default="text-embedding-ada-002", required=True -) -logger.info( - f'Using OpenAI ADA model "{OAI_ADA_MODEL}" ({OAI_ADA_DEPLOY_ID}) with {OAI_ADA_MAX_TOKENS} tokens max' -) - class QdrantSearch(ISearch): def __init__(self, store: IStore): @@ -55,11 +45,11 @@ def __init__(self, store: IStore): ), ) - def message_search(self, q: str, user_id: UUID) -> SearchModel[MessageModel]: + async def message_search(self, q: str, user_id: UUID) -> SearchModel[MessageModel]: logger.debug(f"Searching for: {q}") start = time.monotonic() - vector = self._vector_from_text( + vector = await openai.vector_from_text( textwrap.dedent( f""" Today, we are the {datetime.now()}. 
{q.capitalize()} @@ -103,20 +93,18 @@ def message_search(self, q: str, user_id: UUID) -> SearchModel[MessageModel]: stats=SearchStatsModel(total=total, time=time.monotonic() - start), ) - def message_index( + async def message_index( self, message: StoredMessageModel, user_id: UUID ) -> None: logger.debug(f"Indexing message: {message.id}") - self._loop.run_in_executor( - None, lambda: self._index_worker(message, user_id) - ) + self._loop.create_task(self._index_worker(message, user_id)) - def _index_worker( + async def _index_worker( self, message: StoredMessageModel, user_id: UUID ) -> None: logger.debug(f"Starting indexing worker for message: {message.id}") - vector = self._vector_from_text(message.content, user_id) + vector = await openai.vector_from_text(message.content, user_id) index = IndexMessageModel( conversation_id=message.conversation_id, id=message.id, @@ -131,23 +119,3 @@ def _index_worker( vectors=[vector], ), ) - - @retry( - reraise=True, - stop=stop_after_attempt(3), - wait=wait_random_exponential(multiplier=0.5, max=30), - ) - def _vector_from_text(self, prompt: str, user_id: UUID) -> List[float]: - logger.debug(f"Getting vector for text: {prompt}") - try: - res = openai.Embedding.create( - deployment_id=OAI_ADA_DEPLOY_ID, - input=prompt, - model=OAI_ADA_MODEL, - user=user_id.hex, - ) - except openai.error.AuthenticationError as e: - logger.exception(e) - return [] - - return res.data[0].embedding diff --git a/src/conversation-api/persistence/redis.py b/src/conversation-api/persistence/redis.py index 4f2e0c74..1c2ab4f1 100644 --- a/src/conversation-api/persistence/redis.py +++ b/src/conversation-api/persistence/redis.py @@ -7,6 +7,7 @@ from models.conversation import StoredConversationModel, StoredConversationModel from models.message import MessageModel, IndexMessageModel, StoredMessageModel from models.user import UserModel +from models.usage import UsageModel from redis import Redis from typing import (Any, AsyncGenerator, Callable, Awaitable, 
List, Literal, Optional, Union) from uuid import UUID @@ -18,11 +19,12 @@ # Configuration CONVERSATION_PREFIX = "conversation" -MESSAGE_PREFIX = "message" DB_HOST = get_config("redis", "host", str, required=True) DB_PORT = 6379 +MESSAGE_PREFIX = "message" STREAM_PREFIX = "stream" STREAM_STOPWORD = "STOP" +USAGE_PREFIX = "usage" USER_PREFIX = "user" # Redis client @@ -129,6 +131,12 @@ def message_list(self, conversation_id: UUID) -> List[MessageModel]: messages.sort(key=lambda x: x.created_at) return messages + def usage_set(self, usage: UsageModel) -> None: + client.set(self._usage_cache_key(usage.user_id), usage.json()) + + def _usage_cache_key(self, user_id: UUID) -> str: + return f"{USAGE_PREFIX}:{user_id.hex}" + def _conversation_cache_key( self, user_id: UUID, conversation_id: Optional[UUID] = None ) -> str: From 5c7fffb18fd92a6c81639af2c1e83e47ce172408 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mence=20Lesn=C3=A9?= Date: Fri, 7 Jul 2023 20:37:26 +0200 Subject: [PATCH 02/16] Doc: Make deps log WARN by default --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index e943fdf3..b0f4e4fc 100644 --- a/README.md +++ b/README.md @@ -49,7 +49,7 @@ max_length = 1000 [logging] app_level = "DEBUG" -sys_level = "INFO" +sys_level = "WARN" [oidc] algorithms = ["RS256"] From 2cf97da50f4479c846c10ec36bb302640700f68a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mence=20Lesn=C3=A9?= Date: Fri, 7 Jul 2023 20:37:40 +0200 Subject: [PATCH 03/16] UX: Make chat load a little quicker --- src/conversation-api/persistence/redis.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/conversation-api/persistence/redis.py b/src/conversation-api/persistence/redis.py index 1c2ab4f1..98c33e1c 100644 --- a/src/conversation-api/persistence/redis.py +++ b/src/conversation-api/persistence/redis.py @@ -196,7 +196,8 @@ async def get( if message_loop: yield message_loop - await asyncio.sleep(0.25) + # 8 messages per 
second, enough for give a good user experience, but not too much for not using the thread too much + await asyncio.sleep(0.125) # Send the end of stream message yield STREAM_STOPWORD From 01b412b266cd9f9329f883a3f524ffff0c5c21c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mence=20Lesn=C3=A9?= Date: Fri, 7 Jul 2023 20:50:42 +0200 Subject: [PATCH 04/16] Quality: Add tokens length test for Ada embeddings --- src/conversation-api/main.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/conversation-api/main.py b/src/conversation-api/main.py index f286ff05..de0e3529 100644 --- a/src/conversation-api/main.py +++ b/src/conversation-api/main.py @@ -460,6 +460,14 @@ async def _completion_from_conversation( async def _message_index(message: StoredMessageModel, current_user: UserModel) -> None: + tokens_nb = oai_tokens_nb(message.content, OAI_ADA_MODEL) + if tokens_nb > OAI_ADA_MAX_TOKENS: + logger.info(f"Message ({tokens_nb}) too long for indexing") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Conversation history is too long", + ) + usage = UsageModel( ai_model=OAI_ADA_MODEL, conversation_id=message.conversation_id, From 3145cf2c1a2c41cfd6883bee03945113e500b3d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mence=20Lesn=C3=A9?= Date: Fri, 7 Jul 2023 20:50:59 +0200 Subject: [PATCH 05/16] Refacto: Make clearer function name --- src/conversation-api/ai/openai.py | 4 ++-- src/conversation-api/main.py | 8 ++++---- src/conversation-api/persistence/qdrant.py | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/conversation-api/ai/openai.py b/src/conversation-api/ai/openai.py index 4434267a..1b812b3f 100644 --- a/src/conversation-api/ai/openai.py +++ b/src/conversation-api/ai/openai.py @@ -21,7 +21,7 @@ # Init OpenIA ### -async def refresh_oai_token(): +async def refresh_oai_token_background(): """ Refresh OpenAI token every 15 minutes. 
@@ -42,7 +42,7 @@ async def refresh_oai_token(): openai.api_type = "azure_ad" openai.api_version = "2023-05-15" logger.info(f"Using Aure private service ({openai.api_base})") -asyncio.create_task(refresh_oai_token()) +asyncio.create_task(refresh_oai_token_background()) OAI_GPT_DEPLOY_ID = get_config("openai", "gpt_deploy_id", str, required=True) OAI_GPT_MAX_TOKENS = get_config("openai", "gpt_max_tokens", int, required=True) diff --git a/src/conversation-api/main.py b/src/conversation-api/main.py index de0e3529..1895e538 100644 --- a/src/conversation-api/main.py +++ b/src/conversation-api/main.py @@ -365,10 +365,10 @@ async def message_post( messages = store.message_list(conversation.id) if conversation.title is None: - asyncio.create_task(_guess_title(conversation, messages, current_user)) + asyncio.create_task(_guess_title_background(conversation, messages, current_user)) # Execute the message completion - asyncio.create_task(_completion_from_conversation(conversation, messages, current_user)) + asyncio.create_task(_generate_completion_background(conversation, messages, current_user)) return GetConversationModel( **conversation.dict(), @@ -411,7 +411,7 @@ async def message_search( return await index.message_search(q, current_user.id) -async def _completion_from_conversation( +async def _generate_completion_background( conversation: StoredConversationModel, messages: List[MessageModel], current_user: UserModel, @@ -507,7 +507,7 @@ async def _validate_message_length( return tokens_nb -async def _guess_title( +async def _guess_title_background( conversation: StoredConversationModel, messages: List[MessageModel], current_user: UserModel, diff --git a/src/conversation-api/persistence/qdrant.py b/src/conversation-api/persistence/qdrant.py index 0a10a0d9..1f1979d2 100644 --- a/src/conversation-api/persistence/qdrant.py +++ b/src/conversation-api/persistence/qdrant.py @@ -97,9 +97,9 @@ async def message_index( self, message: StoredMessageModel, user_id: UUID ) -> None: 
logger.debug(f"Indexing message: {message.id}") - self._loop.create_task(self._index_worker(message, user_id)) + self._loop.create_task(self._index_background(message, user_id)) - async def _index_worker( + async def _index_background( self, message: StoredMessageModel, user_id: UUID ) -> None: logger.debug(f"Starting indexing worker for message: {message.id}") From 7a8efdffd08ad510912db7b232170db0a284b461 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mence=20Lesn=C3=A9?= Date: Fri, 7 Jul 2023 22:02:00 +0200 Subject: [PATCH 06/16] Fix: Helm chart after new persistence config --- cicd/helm/private-gpt/templates/conversation-api-config.yaml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/cicd/helm/private-gpt/templates/conversation-api-config.yaml b/cicd/helm/private-gpt/templates/conversation-api-config.yaml index 5be68d57..e3ae590a 100644 --- a/cicd/helm/private-gpt/templates/conversation-api-config.yaml +++ b/cicd/helm/private-gpt/templates/conversation-api-config.yaml @@ -7,6 +7,11 @@ metadata: app.kubernetes.io/component: conversation-api data: config.toml: | + [persistence] + search = "qdrant" + store = "redis" + stream = "redis" + [api] root_path = "/{{ include "private-gpt.fullname" . 
}}-conversation-api" From efc52c161e1d780d65e7197e61aa260dfa18fd72 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mence=20Lesn=C3=A9?= Date: Fri, 7 Jul 2023 22:02:10 +0200 Subject: [PATCH 07/16] Doc: Add optional config --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index b0f4e4fc..5b4b25c2 100644 --- a/README.md +++ b/README.md @@ -35,6 +35,9 @@ store = "cosmos" # Enum: "redis" stream = "redis" +[api] +root_path = "" + [openai] ada_deploy_id = "ada" ada_max_tokens = 2049 From 813065c79e9877f2d94850a42d3ae20c92fa0d90 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mence=20Lesn=C3=A9?= Date: Fri, 7 Jul 2023 22:15:32 +0200 Subject: [PATCH 08/16] Fix: Sort messages and conversations like Redis one --- src/conversation-api/persistence/cosmos.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/conversation-api/persistence/cosmos.py b/src/conversation-api/persistence/cosmos.py index 8438d7d6..e570f666 100644 --- a/src/conversation-api/persistence/cosmos.py +++ b/src/conversation-api/persistence/cosmos.py @@ -69,7 +69,7 @@ def conversation_set(self, conversation: StoredConversationModel) -> None: conversation_client.upsert_item(body=self._sanitize_before_insert(conversation.dict())) def conversation_list(self, user_id: UUID) -> List[StoredConversationModel]: - query = f"SELECT * FROM c WHERE c.user_id = '{user_id}'" + query = f"SELECT * FROM c WHERE c.user_id = '{user_id}' ORDER BY c.created_at DESC" items = conversation_client.query_items(query=query, enable_cross_partition_query=True) return [StoredConversationModel(**item) for item in items] @@ -104,7 +104,7 @@ def message_set(self, message: StoredMessageModel) -> None: }) def message_list(self, conversation_id: UUID) -> List[MessageModel]: - query = f"SELECT * FROM c WHERE c.conversation_id = '{conversation_id}'" + query = f"SELECT * FROM c WHERE c.conversation_id = '{conversation_id}' ORDER BY c.created_at ASC" items = 
message_client.query_items(query=query, enable_cross_partition_query=True) return [MessageModel(**item) for item in items] From 3508cc8c6fa89609b6f075590ca302b56fdb255f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mence=20Lesn=C3=A9?= Date: Fri, 7 Jul 2023 22:15:46 +0200 Subject: [PATCH 09/16] Dev: Add Cosmos DB in the Helm chart --- .../helm/private-gpt/templates/conversation-api-config.yaml | 6 +++++- cicd/helm/private-gpt/values.yaml | 5 +++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/cicd/helm/private-gpt/templates/conversation-api-config.yaml b/cicd/helm/private-gpt/templates/conversation-api-config.yaml index e3ae590a..1fd11eac 100644 --- a/cicd/helm/private-gpt/templates/conversation-api-config.yaml +++ b/cicd/helm/private-gpt/templates/conversation-api-config.yaml @@ -9,7 +9,7 @@ data: config.toml: | [persistence] search = "qdrant" - store = "redis" + store = "cosmos" stream = "redis" [api] @@ -42,3 +42,7 @@ data: [redis] db = {{ .Values.redis.db | int }} host = "{{ include "common.names.fullname" .Subcharts.redis }}-master" + + [cosmos] + url = {{ .Values.cosmos.url | quote }} + database = {{ .Values.cosmos.database | quote }} diff --git a/cicd/helm/private-gpt/values.yaml b/cicd/helm/private-gpt/values.yaml index 04ac7fb6..b2f1b3b5 100644 --- a/cicd/helm/private-gpt/values.yaml +++ b/cicd/helm/private-gpt/values.yaml @@ -42,6 +42,11 @@ api: base: null gpt_deploy_id: gpt-35-turbo +cosmos: + # https://[db].documents.azure.com + url: null + database: null + redis: auth: enabled: false From 3e32320d5ef58b9883017ad5bcf3a9047b34a84b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mence=20Lesn=C3=A9?= Date: Fri, 7 Jul 2023 22:35:14 +0200 Subject: [PATCH 10/16] Perf: Enhance background task management --- src/conversation-api/ai/openai.py | 4 +++- src/conversation-api/main.py | 5 +++-- src/conversation-api/persistence/qdrant.py | 2 +- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/conversation-api/ai/openai.py 
b/src/conversation-api/ai/openai.py index 1b812b3f..766d2294 100644 --- a/src/conversation-api/ai/openai.py +++ b/src/conversation-api/ai/openai.py @@ -16,6 +16,8 @@ ### logger = build_logger(__name__) +loop = asyncio.get_running_loop() + ### # Init OpenIA @@ -42,7 +44,7 @@ async def refresh_oai_token_background(): openai.api_type = "azure_ad" openai.api_version = "2023-05-15" logger.info(f"Using Aure private service ({openai.api_base})") -asyncio.create_task(refresh_oai_token_background()) +loop.create_task(refresh_oai_token_background()) OAI_GPT_DEPLOY_ID = get_config("openai", "gpt_deploy_id", str, required=True) OAI_GPT_MAX_TOKENS = get_config("openai", "gpt_max_tokens", int, required=True) diff --git a/src/conversation-api/main.py b/src/conversation-api/main.py index 1895e538..fe80cec2 100644 --- a/src/conversation-api/main.py +++ b/src/conversation-api/main.py @@ -37,6 +37,7 @@ ### logger = build_logger(__name__) +loop = asyncio.get_running_loop() ### # Init persistence @@ -365,10 +366,10 @@ async def message_post( messages = store.message_list(conversation.id) if conversation.title is None: - asyncio.create_task(_guess_title_background(conversation, messages, current_user)) + loop.create_task(_guess_title_background(conversation, messages, current_user)) # Execute the message completion - asyncio.create_task(_generate_completion_background(conversation, messages, current_user)) + loop.create_task(_generate_completion_background(conversation, messages, current_user)) return GetConversationModel( **conversation.dict(), diff --git a/src/conversation-api/persistence/qdrant.py b/src/conversation-api/persistence/qdrant.py index 1f1979d2..ecd2b4e5 100644 --- a/src/conversation-api/persistence/qdrant.py +++ b/src/conversation-api/persistence/qdrant.py @@ -31,7 +31,7 @@ class QdrantSearch(ISearch): def __init__(self, store: IStore): super().__init__(store) - self._loop = asyncio.new_event_loop() + self._loop = asyncio.get_running_loop() # Ensure collection exists 
try: From 43db362a4c816c6c303d3e327e599d144045cd42 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mence=20Lesn=C3=A9?= Date: Fri, 7 Jul 2023 22:41:40 +0200 Subject: [PATCH 11/16] Fix: Use UTC datetime --- src/conversation-api/main.py | 16 ++++++++-------- src/conversation-api/persistence/qdrant.py | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/conversation-api/main.py b/src/conversation-api/main.py index fe80cec2..8f6ba827 100644 --- a/src/conversation-api/main.py +++ b/src/conversation-api/main.py @@ -130,7 +130,7 @@ def get_ai_prompt() -> Dict[UUID, StoredPromptModel]: AI_PROMPTS = get_ai_prompt() AI_CONVERSATION_DEFAULT_PROMPT = f""" -Today, we are the {datetime.now()}. +Today, we are the {datetime.utcnow()}. You MUST: - Cite sources and examples as footnotes (example: [^1]) @@ -290,7 +290,7 @@ async def message_post( message = StoredMessageModel( content=content, conversation_id=conversation_id, - created_at=datetime.now(), + created_at=datetime.utcnow(), id=uuid4(), role=MessageRole.USER, secret=secret, @@ -303,7 +303,7 @@ async def message_post( usage = UsageModel( ai_model=OAI_GPT_MODEL, conversation_id=conversation_id, - created_at=datetime.now(), + created_at=datetime.utcnow(), id=uuid4(), tokens=tokens_nb, user_id=current_user.id, @@ -332,7 +332,7 @@ async def message_post( # Build conversation conversation = StoredConversationModel( - created_at=datetime.now(), + created_at=datetime.utcnow(), id=uuid4(), prompt=AI_PROMPTS[prompt_id] if prompt_id else None, user_id=current_user.id, @@ -343,7 +343,7 @@ async def message_post( usage = UsageModel( ai_model=OAI_GPT_MODEL, conversation_id=conversation.id, - created_at=datetime.now(), + created_at=datetime.utcnow(), id=uuid4(), tokens=tokens_nb, user_id=current_user.id, @@ -354,7 +354,7 @@ async def message_post( message = StoredMessageModel( content=content, conversation_id=conversation.id, - created_at=datetime.now(), + created_at=datetime.utcnow(), id=uuid4(), 
role=MessageRole.USER, secret=secret, @@ -448,7 +448,7 @@ async def _generate_completion_background( res_message = StoredMessageModel( content=content_full, conversation_id=conversation.id, - created_at=datetime.now(), + created_at=datetime.utcnow(), id=uuid4(), role=MessageRole.ASSISTANT, secret=last_message.secret, @@ -472,7 +472,7 @@ async def _message_index(message: StoredMessageModel, current_user: UserModel) - usage = UsageModel( ai_model=OAI_ADA_MODEL, conversation_id=message.conversation_id, - created_at=datetime.now(), + created_at=datetime.utcnow(), id=uuid4(), tokens=oai_tokens_nb(message.content, OAI_ADA_MODEL), user_id=current_user.id, diff --git a/src/conversation-api/persistence/qdrant.py b/src/conversation-api/persistence/qdrant.py index ecd2b4e5..77817b61 100644 --- a/src/conversation-api/persistence/qdrant.py +++ b/src/conversation-api/persistence/qdrant.py @@ -52,7 +52,7 @@ async def message_search(self, q: str, user_id: UUID) -> SearchModel[MessageMode vector = await openai.vector_from_text( textwrap.dedent( f""" - Today, we are the {datetime.now()}. {q.capitalize()} + Today, we are the {datetime.utcnow()}. {q.capitalize()} """ ), user_id, From 8410a5be6831a5b471f28a61d263c3a69cd72824 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mence=20Lesn=C3=A9?= Date: Fri, 7 Jul 2023 23:15:14 +0200 Subject: [PATCH 12/16] Doc: Add latest features to features list --- README.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 5b4b25c2..15cc2cac 100644 --- a/README.md +++ b/README.md @@ -10,11 +10,12 @@ Includes: - Deployable on any Kubernetes cluster, with its Helm chart - Manage users effortlessly with OpenID Connect - More than 150 tones and personalities (accountant, advisor, debater, excel sheet, instructor, logistician, etc.) 
to better help employees in their specific daily tasks -- Plug and play with any storage system, including [Azure Cosmos DB](https://learn.microsoft.com/en-us/azure/cosmos-db/), [Redis](https://github.com/redis/redis) and [Qdrant](https://github.com/qdrant/qdrant). +- Plug and play storage system, including [Azure Cosmos DB](https://learn.microsoft.com/en-us/azure/cosmos-db/), [Redis](https://github.com/redis/redis) and [Qdrant](https://github.com/qdrant/qdrant). - Possibility to send temporary messages, for confidentiality - Salable system based on stateless APIs, cache, progressive web app and events - Search engine for conversations, based on semantic similarity and AI embeddings -- Unlimited conversation history +- Unlimited conversation history and number of users +- Usage tracking, for better understanding of your employees' usage ![Application screenshot](docs/main.png) From 10563a66681e156de2e03558e9c404ebf26f6689 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mence=20Lesn=C3=A9?= Date: Fri, 7 Jul 2023 23:15:36 +0200 Subject: [PATCH 13/16] Chore: Delete dead code --- src/conversation-api/persistence/cosmos.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/conversation-api/persistence/cosmos.py b/src/conversation-api/persistence/cosmos.py index e570f666..7b3b3ae8 100644 --- a/src/conversation-api/persistence/cosmos.py +++ b/src/conversation-api/persistence/cosmos.py @@ -3,15 +3,14 @@ # Import misc from .istore import IStore -from azure.cosmos import CosmosClient, PartitionKey, ConsistencyLevel -from azure.cosmos.database import DatabaseProxy -from azure.cosmos.exceptions import CosmosHttpResponseError, CosmosResourceExistsError +from azure.cosmos import CosmosClient, ConsistencyLevel +from azure.cosmos.exceptions import CosmosHttpResponseError from azure.identity import DefaultAzureCredential from datetime import datetime from models.conversation import StoredConversationModel, StoredConversationModel from models.message 
import MessageModel, IndexMessageModel, StoredMessageModel -from models.user import UserModel from models.usage import UsageModel +from models.user import UserModel from typing import (Any, Dict, List, Union) from uuid import UUID @@ -111,7 +110,7 @@ def message_list(self, conversation_id: UUID) -> List[MessageModel]: def usage_set(self, usage: UsageModel) -> None: usage_client.upsert_item(body=self._sanitize_before_insert(usage.dict())) - def _sanitize_before_insert(self, item: dict) -> Dict[str, Union[str, int, float, bool]]: + def _sanitize_before_insert(self, item: dict) -> Dict[str, Any]: for key, value in item.items(): if isinstance(value, UUID): item[key] = str(value) From 4b683a90de66c17a77d9c0d61469e9364638c777 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mence=20Lesn=C3=A9?= Date: Fri, 7 Jul 2023 23:16:22 +0200 Subject: [PATCH 14/16] Fix: Documents with dict fails, like new message with prompt --- src/conversation-api/persistence/cosmos.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/conversation-api/persistence/cosmos.py b/src/conversation-api/persistence/cosmos.py index 7b3b3ae8..9c42f019 100644 --- a/src/conversation-api/persistence/cosmos.py +++ b/src/conversation-api/persistence/cosmos.py @@ -116,4 +116,8 @@ def _sanitize_before_insert(self, item: dict) -> Dict[str, Any]: item[key] = str(value) elif isinstance(value, datetime): item[key] = value.isoformat() + elif isinstance(value, dict): + item[key] = self._sanitize_before_insert(value) + elif isinstance(value, list): + item[key] = [self._sanitize_before_insert(i) for i in value] return item From a9a03a94e99be4dae406bc802bc6e2bf3e271cf7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mence=20Lesn=C3=A9?= Date: Fri, 7 Jul 2023 23:17:00 +0200 Subject: [PATCH 15/16] Feat: Add prompt name in usage --- src/conversation-api/main.py | 31 +++++++++++++++------------- src/conversation-api/models/usage.py | 2 ++ 2 files changed, 19 insertions(+), 14 deletions(-) diff --git 
a/src/conversation-api/main.py b/src/conversation-api/main.py index 8f6ba827..5cc5989c 100644 --- a/src/conversation-api/main.py +++ b/src/conversation-api/main.py @@ -299,6 +299,17 @@ async def message_post( tokens_nb = await _validate_message_length(message=message) + # Update conversation + store.message_set(message) + conversation = store.conversation_get(conversation_id, current_user.id) + if not conversation: + logger.warn("ACID error: conversation not found after testing existence") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Conversation not found", + ) + await _message_index(message, current_user, conversation.prompt) + # Build usage usage = UsageModel( ai_model=OAI_GPT_MODEL, @@ -307,19 +318,9 @@ async def message_post( id=uuid4(), tokens=tokens_nb, user_id=current_user.id, + prompt_name=conversation.prompt.name if conversation.prompt else None, ) store.usage_set(usage) - - # Update conversation - store.message_set(message) - await _message_index(message, current_user) - conversation = store.conversation_get(conversation_id, current_user.id) - if not conversation: - logger.warn("ACID error: conversation not found after testing existence") - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Conversation not found", - ) else: # Test prompt ID if provided if prompt_id and prompt_id not in AI_PROMPTS: @@ -347,6 +348,7 @@ async def message_post( id=uuid4(), tokens=tokens_nb, user_id=current_user.id, + prompt_name=conversation.prompt.name if conversation.prompt else None, ) store.usage_set(usage) @@ -361,7 +363,7 @@ async def message_post( token=uuid4(), ) store.message_set(message) - await _message_index(message, current_user) + await _message_index(message, current_user, conversation.prompt) messages = store.message_list(conversation.id) @@ -454,13 +456,13 @@ async def _generate_completion_background( secret=last_message.secret, ) store.message_set(res_message) - await _message_index(res_message, 
current_user) + await _message_index(res_message, current_user, conversation.prompt) # Then, send the end of stream message stream.push(STREAM_STOPWORD, last_message.token) -async def _message_index(message: StoredMessageModel, current_user: UserModel) -> None: +async def _message_index(message: StoredMessageModel, current_user: UserModel, prompt: Optional[StoredPromptModel]) -> None: tokens_nb = oai_tokens_nb(message.content, OAI_ADA_MODEL) if tokens_nb > OAI_ADA_MAX_TOKENS: logger.info(f"Message ({tokens_nb}) too long for indexing") @@ -476,6 +478,7 @@ async def _message_index(message: StoredMessageModel, current_user: UserModel) - id=uuid4(), tokens=oai_tokens_nb(message.content, OAI_ADA_MODEL), user_id=current_user.id, + prompt_name=prompt.name if prompt else None, ) store.usage_set(usage) await index.message_index(message, current_user.id) diff --git a/src/conversation-api/models/usage.py b/src/conversation-api/models/usage.py index 94483266..aec7bbb4 100644 --- a/src/conversation-api/models/usage.py +++ b/src/conversation-api/models/usage.py @@ -1,4 +1,5 @@ from datetime import datetime +from typing import Optional from pydantic import BaseModel from uuid import UUID @@ -8,6 +9,7 @@ class UsageModel(BaseModel): conversation_id: UUID created_at: datetime id: UUID + prompt_name: Optional[str] = None tokens: int user_id: UUID # Partition key From 258999836378942e5646635d1ca18f4ade75be99 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mence=20Lesn=C3=A9?= Date: Fri, 7 Jul 2023 23:26:33 +0200 Subject: [PATCH 16/16] Dev: Make login log more concise --- src/conversation-api/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/conversation-api/main.py b/src/conversation-api/main.py index 5cc5989c..2c1eb605 100644 --- a/src/conversation-api/main.py +++ b/src/conversation-api/main.py @@ -211,7 +211,7 @@ async def get_current_user( raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) user = store.user_get(sub) - 
logger.info(f"User logged in: {user}") + logger.info(f'User "{user.id}" logged in') logger.debug(f"JWT: {jwt}") if user: return user