Skip to content

Commit

Permalink
Fix: Search results parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
clemlesne committed Jul 26, 2023
1 parent 759bba9 commit e197ccf
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/conversation-api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ async def loop_func() -> bool:
)
async def message_search(
q: str, current_user: Annotated[UserModel, Depends(get_current_user)]
) -> SearchModel:
) -> SearchModel[MessageModel]:
return await index.message_search(q, current_user.id, 25)


Expand Down
7 changes: 4 additions & 3 deletions src/conversation-api/models/search.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from pydantic import BaseModel
from pydantic.generics import GenericModel
from typing import List, TypeVar, Generic


T = TypeVar("T")
T = TypeVar("T", bound=BaseModel)


class SearchAnswerModel(BaseModel, Generic[T]):
class SearchAnswerModel(GenericModel, Generic[T]):
data: T
score: float

Expand All @@ -15,7 +16,7 @@ class SearchStatsModel(BaseModel):
total: int


class SearchModel(BaseModel, Generic[T]):
class SearchModel(GenericModel, Generic[T]):
answers: List[SearchAnswerModel[T]]
query: str
stats: SearchStatsModel
24 changes: 20 additions & 4 deletions src/conversation-api/persistence/cosmos.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,18 @@ async def conversation_list(self, user_id: UUID) -> List[StoredConversationModel
query = (
f"SELECT * FROM c WHERE c.user_id = '{user_id}' ORDER BY c.created_at DESC"
)
items = conversation_client.query_items(
raws = conversation_client.query_items(
query=query, enable_cross_partition_query=True
)
return [StoredConversationModel(**item) for item in items]
conversations = []
for raw in raws:
if raw is None:
continue
try:
conversations.append(StoredConversationModel(**raw))
except Exception:
logger.warn("Error parsing conversation", exc_info=True)
return conversations

async def message_get(
self, message_id: UUID, conversation_id: UUID
Expand Down Expand Up @@ -134,10 +142,18 @@ async def message_set(self, message: StoredMessageModel) -> None:

async def message_list(self, conversation_id: UUID) -> List[MessageModel]:
query = f"SELECT * FROM c WHERE c.conversation_id = '{conversation_id}' ORDER BY c.created_at ASC"
items = message_client.query_items(
raws = message_client.query_items(
query=query, enable_cross_partition_query=True
)
return [MessageModel(**item) for item in items]
items = []
for raw in raws:
if raw is None:
continue
try:
items.append(MessageModel(**raw))
except Exception:
logger.warn("Error parsing message", exc_info=True)
return items

async def usage_set(self, usage: UsageModel) -> None:
logger.debug(f'Usage set "{usage.id}"')
Expand Down

0 comments on commit e197ccf

Please sign in to comment.