Skip to content

Commit

Permalink
Merge branch 'develop' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
clemlesne committed Jul 7, 2023
2 parents 0bcf84c + 0a6c339 commit f69f6f9
Show file tree
Hide file tree
Showing 20 changed files with 365 additions and 109 deletions.
15 changes: 15 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ Includes:
- Can be configured to use any Azure OpenAI completion API, including GPT-4
- Dark theme for better readability
- Dead simple interface
- Deployable on any Kubernetes cluster, with its Helm chart
- Manage users effortlessly with OpenID Connect
- More than 150 tones and personalities (accountant, advisor, debater, excel sheet, instructor, logistician, etc.) to better help employees in their specific daily tasks
- Plug and play with any storage system, including [Azure Cosmos DB](https://learn.microsoft.com/en-us/azure/cosmos-db/), [Redis](https://github.com/redis/redis) and [Qdrant](https://github.com/qdrant/qdrant).
- Possibility to send temporary messages, for confidentiality
- Scalable system based on stateless APIs, cache, progressive web app and events
- Search engine for conversations, based on semantic similarity and AI embeddings
Expand All @@ -25,6 +27,14 @@ Create a local configuration file, a file named `config.toml` at the root of the
```toml
# config.toml
# Values are for example only, you should change them
[persistence]
# Enum: "qdrant"
search = "qdrant"
# Enum: "redis", "cosmos"
store = "cosmos"
# Enum: "redis"
stream = "redis"

[openai]
ada_deploy_id = "ada"
ada_max_tokens = 2049
Expand Down Expand Up @@ -53,6 +63,11 @@ host = "localhost"
[redis]
db = 0
host = "localhost"

[cosmos]
# Containers "conversation" (/user_id), "message" (/conversation_id) and "user" (/dummy) must exist
url = "https://private-gpt.documents.azure.com:443"
database = "private-gpt"
```

Now, you can either run the application as container or with live reload. For development, it is recommended to use live reload. For demo, it is recommended to use the container.
Expand Down
7 changes: 5 additions & 2 deletions cicd/helm/private-gpt/templates/conversation-api-ingress.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ spec:
middlewares:
- name: {{ include "private-gpt.fullname" . }}-conversation-api-prefix
- name: {{ include "private-gpt.fullname" . }}-conversation-api-security
- name: {{ include "private-gpt.fullname" . }}-conversation-api-ratelimit
- name: {{ include "private-gpt.fullname" . }}-conversation-api-ratelimit-auth
- name: {{ include "private-gpt.fullname" . }}-conversation-api-compress
tls:
{{- toYaml .Values.ingress.tls | nindent 4 }}
Expand Down Expand Up @@ -49,14 +49,17 @@ spec:
apiVersion: traefik.containo.us/v1alpha1
kind: Middleware
metadata:
name: {{ include "private-gpt.fullname" . }}-conversation-api-ratelimit
name: {{ include "private-gpt.fullname" . }}-conversation-api-ratelimit-auth
labels:
{{- include "private-gpt.labels" . | nindent 4 }}
app.kubernetes.io/component: conversation-api
spec:
rateLimit:
average: 1
burst: 5
period: 1s
sourceCriterion:
requestHeaderName: Authorization
---
apiVersion: traefik.containo.us/v1alpha1
kind: Middleware
Expand Down
99 changes: 67 additions & 32 deletions src/conversation-api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,14 @@
from fastapi import FastAPI, HTTPException, status, Request, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from models.conversation import (
GetConversationModel,
ListConversationsModel,
StoredConversationModel,
)
from models.message import MessageModel, MessageRole
from models.conversation import (GetConversationModel, ListConversationsModel, StoredConversationModel)
from models.message import MessageModel, MessageRole, StoredMessageModel
from models.prompt import StoredPromptModel, ListPromptsModel
from models.search import SearchModel
from models.user import UserModel
from persistence.qdrant import QdrantSearch
from persistence.redis import RedisStore, RedisStream, STREAM_STOPWORD
from persistence.isearch import SearchImplementation
from persistence.istore import StoreImplementation
from persistence.istream import StreamImplementation
from sse_starlette.sse import EventSourceResponse
from tenacity import retry, stop_after_attempt, wait_random_exponential
from typing import Annotated, Dict, List, Optional
Expand All @@ -45,12 +42,35 @@
logger = build_logger(__name__)

###
# Init Redis
# Init persistence
###

store = RedisStore()
stream = RedisStream()
index = QdrantSearch(store)
# Select the store backend from config ([persistence] store); backends are
# imported lazily so only the configured SDK has to be installed/reachable.
store_impl = get_config("persistence", "store", StoreImplementation, required=True)
if store_impl == StoreImplementation.COSMOS:
    logger.info("Using CosmosDB store")
    from persistence.cosmos import CosmosStore
    store = CosmosStore()
elif store_impl == StoreImplementation.REDIS:
    logger.info("Using Redis store")
    from persistence.redis import RedisStore
    store = RedisStore()
else:
    # get_config already coerced to the enum, so this only triggers if a new
    # enum member is added without a matching branch here.
    raise ValueError(f"Unknown store implementation: {store_impl}")

# Select the semantic-search backend ([persistence] search). The search index
# needs the store to resolve indexed message IDs back to full messages.
search_impl = get_config("persistence", "search", SearchImplementation, required=True)
if search_impl == SearchImplementation.QDRANT:
    logger.info("Using Qdrant search")
    from persistence.qdrant import QdrantSearch
    index = QdrantSearch(store)
else:
    raise ValueError(f"Unknown search implementation: {search_impl}")

# Select the token-stream backend ([persistence] stream), used to push LLM
# completion chunks to clients. STREAM_STOPWORD (end-of-stream marker) is
# deliberately imported here so it is in module scope for the SSE handlers.
stream_impl = get_config("persistence", "stream", StreamImplementation, required=True)
if stream_impl == StreamImplementation.REDIS:
    logger.info("Using Redis stream")
    from persistence.redis import RedisStream, STREAM_STOPWORD
    stream = RedisStream()
else:
    raise ValueError(f"Unknown stream implementation: {stream_impl}")

###
# Init OpenAI
Expand Down Expand Up @@ -299,15 +319,6 @@ async def message_post(
detail="Message is moderated",
)

message = MessageModel(
content=content,
created_at=datetime.now(),
id=uuid4(),
role=MessageRole.USER,
secret=secret,
token=uuid4(),
)

if conversation_id:
# Validate API schema
if prompt_id:
Expand All @@ -326,12 +337,24 @@ async def message_post(
detail="Conversation not found",
)

# Build message
message = StoredMessageModel(
content=content,
conversation_id=conversation_id,
created_at=datetime.now(),
id=uuid4(),
role=MessageRole.USER,
secret=secret,
token=uuid4(),
)

# Validate message length
tokens_nb = oai_tokens_nb(
message.content
+ "".join([m.content for m in store.message_list(conversation_id)]),
+ "".join([m.content for m in store.message_list(message.conversation_id)]),
OAI_GPT_MODEL,
)

logger.debug(f"{tokens_nb} tokens in the conversation")
if tokens_nb > OAI_GPT_MAX_TOKENS:
logger.info(f"Message ({tokens_nb}) too long for conversation")
Expand All @@ -341,8 +364,8 @@ async def message_post(
)

# Update conversation
store.message_set(message, conversation_id)
index.message_index(message, conversation_id, current_user.id)
store.message_set(message)
index.message_index(message, current_user.id)
conversation = store.conversation_get(conversation_id, current_user.id)
if not conversation:
logger.warn("ACID error: conversation not found after testing existence")
Expand All @@ -359,7 +382,7 @@ async def message_post(
)

# Validate message length
tokens_nb = oai_tokens_nb(message.content, OAI_GPT_MODEL)
tokens_nb = oai_tokens_nb(content, OAI_GPT_MODEL)
logger.debug(f"{tokens_nb} tokens in the conversation")
if tokens_nb > OAI_GPT_MAX_TOKENS:
logger.info(f"Message ({tokens_nb}) too long for conversation")
Expand All @@ -368,16 +391,27 @@ async def message_post(
detail="Conversation history is too long",
)

# Create a new conversation
# Build conversation
conversation = StoredConversationModel(
created_at=datetime.now(),
id=uuid4(),
prompt=AI_PROMPTS[prompt_id] if prompt_id else None,
user_id=current_user.id,
)
store.conversation_set(conversation)
store.message_set(message, conversation.id)
index.message_index(message, conversation.id, current_user.id)

# Build message
message = StoredMessageModel(
content=content,
conversation_id=conversation.id,
created_at=datetime.now(),
id=uuid4(),
role=MessageRole.USER,
secret=secret,
token=uuid4(),
)
store.message_set(message)
index.message_index(message, current_user.id)

messages = store.message_list(conversation.id)

Expand Down Expand Up @@ -486,15 +520,16 @@ def completion_from_conversation(
content_full += content

# First, store the updated conversation in Redis
res_message = MessageModel(
res_message = StoredMessageModel(
content=content_full,
conversation_id=conversation.id,
created_at=datetime.now(),
id=uuid4(),
role=MessageRole.ASSISTANT,
secret=last_message.secret,
)
store.message_set(res_message, conversation.id)
index.message_index(res_message, conversation.id, current_user.id)
store.message_set(res_message)
index.message_index(res_message, current_user.id)

# Then, send the end of stream message
stream.push(STREAM_STOPWORD, last_message.token)
Expand Down
11 changes: 10 additions & 1 deletion src/conversation-api/models/message.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from datetime import datetime
from enum import Enum
from typing import List, Optional
from pydantic import BaseModel
from typing import Optional
from uuid import UUID


Expand All @@ -20,7 +20,16 @@ class MessageModel(BaseModel):
token: Optional[UUID] = None


class StoredMessageModel(MessageModel):
    """A message as persisted in the store, keyed by its conversation."""
    # Lookup/partition key: the conversation this message belongs to.
    conversation_id: UUID


class IndexMessageModel(BaseModel):
    """
    Metadata-only projection of a message, stored in a separate collection so that messages can be queried by the search index.
    The absence of content and created_at is intentional: we do not want to store PII in the index. Thus the index needs no TTL and requires less hardening than the message store itself.
    """
    conversation_id: UUID
    id: UUID
    user_id: UUID
2 changes: 1 addition & 1 deletion src/conversation-api/models/prompt.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import List
from pydantic import BaseModel
from typing import List
from uuid import UUID


Expand Down
115 changes: 115 additions & 0 deletions src/conversation-api/persistence/cosmos.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Import utils
from utils import build_logger, get_config

# Import misc
from .istore import IStore
from azure.cosmos import CosmosClient, PartitionKey, ConsistencyLevel
from azure.cosmos.database import DatabaseProxy
from azure.cosmos.exceptions import CosmosHttpResponseError, CosmosResourceExistsError
from azure.identity import DefaultAzureCredential
from datetime import datetime
from models.conversation import StoredConversationModel
from models.message import MessageModel, IndexMessageModel, StoredMessageModel
from models.user import UserModel
from typing import Any, Dict, List, Union
from uuid import UUID


logger = build_logger(__name__)
# Secret ("temporary") messages expire after one day (see CosmosStore.message_set).
SECRET_TTL_SECS = 60 * 60 * 24  # 1 day

# Configuration
CONVERSATION_PREFIX = "conversation"
DB_URL = get_config("cosmos", "url", str, required=True)
DB_NAME = get_config("cosmos", "database", str, required=True)

# Cosmos DB client, authenticated through the ambient Azure identity
# (environment variables, managed identity, CLI login, ...).
# NOTE(review): Session consistency assumed sufficient here (read-your-own-writes
# per client session) — confirm against the multi-replica deployment topology.
credential = DefaultAzureCredential()
client = CosmosClient(url=DB_URL, credential=credential, consistency_level=ConsistencyLevel.Session)
database = client.get_database_client(DB_NAME)
# Containers "conversation" (/user_id), "message" (/conversation_id) and
# "user" (/dummy) must already exist (see README).
conversation_client = database.get_container_client("conversation")
message_client = database.get_container_client("message")
user_client = database.get_container_client("user")
logger.info(f'Connected to Cosmos DB at "{DB_URL}"')


class CosmosStore(IStore):
    """IStore implementation backed by Azure Cosmos DB (SQL API).

    Containers and their partition keys: "conversation" (/user_id),
    "message" (/conversation_id), "user" (/dummy).
    """

    def user_get(self, user_external_id: str) -> Union[UserModel, None]:
        """Return the user matching the OIDC external ID, or None if unknown."""
        # Parameterized query: the external ID comes from the identity provider
        # and must never be interpolated into the SQL string (injection risk).
        items = user_client.query_items(
            query="SELECT * FROM c WHERE c.external_id = @external_id",
            parameters=[{"name": "@external_id", "value": user_external_id}],
            partition_key="dummy",  # constant partition key of the "user" container
        )
        raw = next(iter(items), None)
        return UserModel(**raw) if raw is not None else None

    def user_set(self, user: UserModel) -> None:
        """Create or update a user document."""
        user_client.upsert_item(body={
            **self._sanitize_before_insert(user.dict()),
            "dummy": "dummy",  # constant partition key value
        })

    def conversation_get(
        self, conversation_id: UUID, user_id: UUID
    ) -> Union[StoredConversationModel, None]:
        """Point-read a conversation; None when it does not exist."""
        try:
            raw = conversation_client.read_item(
                item=str(conversation_id), partition_key=str(user_id)
            )
            return StoredConversationModel(**raw)
        except CosmosHttpResponseError as e:
            # Only "not found" means absence; other HTTP errors (auth,
            # throttling after SDK retries, ...) are real failures and must surface.
            if e.status_code == 404:
                return None
            raise

    def conversation_exists(self, conversation_id: UUID, user_id: UUID) -> bool:
        """True when the conversation exists for this user."""
        return self.conversation_get(conversation_id, user_id) is not None

    def conversation_set(self, conversation: StoredConversationModel) -> None:
        """Create or update a conversation document."""
        conversation_client.upsert_item(body=self._sanitize_before_insert(conversation.dict()))

    def conversation_list(self, user_id: UUID) -> List[StoredConversationModel]:
        """List all conversations of a user."""
        # The container is partitioned on /user_id, so this is a
        # single-partition query; no cross-partition fan-out needed.
        items = conversation_client.query_items(
            query="SELECT * FROM c WHERE c.user_id = @user_id",
            parameters=[{"name": "@user_id", "value": str(user_id)}],
            partition_key=str(user_id),
        )
        return [StoredConversationModel(**item) for item in items]

    def message_get(
        self, message_id: UUID, conversation_id: UUID
    ) -> Union[MessageModel, None]:
        """Point-read a message; None when it does not exist (or has expired)."""
        try:
            raw = message_client.read_item(
                item=str(message_id), partition_key=str(conversation_id)
            )
            return MessageModel(**raw)
        except CosmosHttpResponseError as e:
            if e.status_code == 404:
                return None
            raise

    def message_get_index(
        self, message_indexs: List[IndexMessageModel]
    ) -> List[MessageModel]:
        """Resolve search-index entries to full messages, skipping missing ones."""
        messages = []
        for message_index in message_indexs:
            try:
                raw = message_client.read_item(
                    item=str(message_index.id),
                    partition_key=str(message_index.conversation_id),
                )
                messages.append(MessageModel(**raw))
            except CosmosHttpResponseError as e:
                # Indexed messages may have expired (secret TTL) or been
                # deleted; skip those, but surface any other failure.
                if e.status_code != 404:
                    raise
        return messages

    def message_set(self, message: StoredMessageModel) -> None:
        """Create or update a message; secret messages expire after SECRET_TTL_SECS."""
        expiry = SECRET_TTL_SECS if message.secret else None
        # Per-item expiry is the "ttl" property; "_ts" is Cosmos DB's read-only
        # system timestamp and is ignored on write. NOTE(review): item-level
        # "ttl" only takes effect if TTL is enabled on the container — verify
        # the container definition.
        message_client.upsert_item(body={
            **self._sanitize_before_insert(message.dict()),
            "ttl": expiry,  # TTL in seconds, None = never expires
        })

    def message_list(self, conversation_id: UUID) -> List[MessageModel]:
        """List all messages of a conversation."""
        # Single-partition query: the container is partitioned on /conversation_id.
        items = message_client.query_items(
            query="SELECT * FROM c WHERE c.conversation_id = @conversation_id",
            parameters=[{"name": "@conversation_id", "value": str(conversation_id)}],
            partition_key=str(conversation_id),
        )
        return [MessageModel(**item) for item in items]

    def _sanitize_before_insert(self, item: dict) -> Dict[str, Union[str, int, float, bool]]:
        """Convert JSON-incompatible values (UUID, datetime) to strings, in place."""
        for key, value in item.items():
            if isinstance(value, UUID):
                item[key] = str(value)
            elif isinstance(value, datetime):
                item[key] = value.isoformat()
        return item
Loading

0 comments on commit f69f6f9

Please sign in to comment.