Skip to content

Commit

Permalink
Added thread tests
Browse files Browse the repository at this point in the history
  • Loading branch information
cszsol committed Nov 19, 2024
1 parent 0001eee commit d5b711d
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 32 deletions.
35 changes: 3 additions & 32 deletions swarm_copy/app/routers/threads.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Threads CRUDs."""

import json
import logging
from typing import Annotated

Expand All @@ -12,8 +11,8 @@
from swarm_copy.app.app_utils import validate_project
from swarm_copy.app.config import Settings
from swarm_copy.app.database.db_utils import get_thread
from swarm_copy.app.database.schemas import MessagesRead, ThreadsRead, ThreadUpdate
from swarm_copy.app.database.sql_schemas import Entity, Messages, Threads
from swarm_copy.app.database.schemas import ThreadsRead, ThreadUpdate
from swarm_copy.app.database.sql_schemas import Threads
from swarm_copy.app.dependencies import (
get_httpx_client,
get_kg_token,
Expand Down Expand Up @@ -57,6 +56,7 @@ async def create_thread(
await session.commit()
await session.refresh(new_thread)

logger.error(f"thread_id: {new_thread.thread_id}")
return ThreadsRead(**new_thread.__dict__)


Expand All @@ -73,35 +73,6 @@ async def get_threads(
return [ThreadsRead(**thread.__dict__) for thread in threads]


@router.get("/{thread_id}")
async def get_messages(
session: Annotated[AsyncSession, Depends(get_session)],
_: Annotated[Threads, Depends(get_thread)], # to check if thread exist
thread_id: str,
) -> list[MessagesRead]:
"""Get all messages of the thread."""
messages_result = await session.execute(
select(Messages)
.where(
Messages.thread_id == thread_id,
Messages.entity.in_([Entity.USER, Entity.AI_MESSAGE]),
)
.order_by(Messages.order)
)
db_messages = messages_result.scalars().all()

messages = []
for msg in db_messages:
messages.append(
MessagesRead(
msg_content=json.loads(msg.content)["content"],
**msg.__dict__,
)
)

return messages


@router.patch("/{thread_id}")
async def update_thread_title(
session: Annotated[AsyncSession, Depends(get_session)],
Expand Down
122 changes: 122 additions & 0 deletions swarm_copy_tests/app/routers/test_threads.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from unittest.mock import AsyncMock, Mock, patch

import pytest

from swarm_copy.app.database.sql_schemas import utc_now
from swarm_copy.app.routers.threads import get_threads, update_thread_title, delete_thread


@pytest.mark.asyncio
async def test_create_thread(app_client, settings):
mock_validate_project = AsyncMock()
mock_session = AsyncMock()
user_id = "user_id"
token = "token"
title = "title"
thread_id = "uuid"
project_id = "project_id"
virtual_lab_id = "virtual_lab_id"
creation_date = utc_now()
update_date = utc_now()
with patch("swarm_copy.app.app_utils.validate_project", mock_validate_project):
with patch('swarm_copy.app.database.sql_schemas.Threads', autospec=True) as mock_threads:
from swarm_copy.app.routers.threads import create_thread
mock_thread_instance = Mock(user_id=user_id,
title=title,
vlab_id=virtual_lab_id,
project_id=project_id,
thread_id=thread_id,
creation_date=creation_date,
update_date=update_date)
mock_threads.return_value = mock_thread_instance
await create_thread(app_client, settings,
token,
virtual_lab_id,
project_id,
mock_session,
user_id,
title)
assert mock_session.add.called
assert mock_session.commit.called
assert mock_session.refresh.called


@pytest.mark.asyncio
async def test_get_threads():
user_id = "user_id"
title = "title"
thread_id = "uuid"
project_id = "project_id"
virtual_lab_id = "virtual_lab_id"
creation_date = utc_now()
update_date = utc_now()
mock_threads = [
Mock(user_id=user_id,
title=title,
vlab_id=virtual_lab_id,
project_id=project_id,
thread_id=thread_id,
creation_date=creation_date,
update_date=update_date)
]
mock_session = AsyncMock()
scalars_mock = Mock()
scalars_mock.all.return_value = mock_threads
mock_thread_result = Mock()
mock_thread_result.scalars.return_value = scalars_mock
mock_session.execute.return_value = mock_thread_result
thread_reads = await get_threads(mock_session, user_id)
thread_read = thread_reads[0]
assert thread_read.thread_id == thread_id
assert thread_read.user_id == user_id
assert thread_read.vlab_id == virtual_lab_id
assert thread_read.project_id == project_id
assert thread_read.title == title
assert thread_read.creation_date == creation_date
assert thread_read.update_date == update_date


@pytest.mark.asyncio
async def test_update_thread_title():
user_id = "user_id"
title = "title"
thread_id = "uuid"
project_id = "project_id"
virtual_lab_id = "virtual_lab_id"
creation_date = utc_now()
update_date = utc_now()
mock_session = AsyncMock()
mock_thread_result = Mock()
mock_session.execute.return_value = mock_thread_result
mock_update_thread = Mock()
mock_update_thread.model_dump.return_value = {
"user_id": user_id,
"title": title,
"vlab_id": virtual_lab_id,
"project_id": project_id,
"thread_id": thread_id,
"creation_date": creation_date,
"update_date": update_date
}
mock_thread = Mock()
thread_read = await update_thread_title(mock_session, mock_update_thread, mock_thread)
assert mock_session.commit.called
assert mock_session.refresh.called
assert thread_read.thread_id == thread_id
assert thread_read.user_id == user_id
assert thread_read.vlab_id == virtual_lab_id
assert thread_read.project_id == project_id
assert thread_read.title == title
assert thread_read.creation_date == creation_date
assert thread_read.update_date == update_date


@pytest.mark.asyncio
async def test_delete_thread():
mock_session = AsyncMock()
mock_thread_result = Mock()
mock_session.execute.return_value = mock_thread_result
mock_thread = Mock()
await delete_thread(mock_session, mock_thread)
assert mock_session.delete.called
assert mock_session.commit.called
21 changes: 21 additions & 0 deletions swarm_copy_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,27 @@
from neuroagent.tools import GetMorphoTool


@pytest.fixture(name="settings")
def settings():
return Settings(
tools={
"literature": {
"url": "fake_literature_url",
},
},
knowledge_graph={
"base_url": "https://fake_url/api/nexus/v1",
},
openai={
"token": "fake_token",
},
keycloak={
"username": "fake_username",
"password": "fake_password",
},
)


@pytest.fixture(name="app_client")
def client_fixture():
"""Get client and clear app dependency_overrides."""
Expand Down

0 comments on commit d5b711d

Please sign in to comment.