Skip to content

Commit

Permalink
Added db_utils tests
Browse files Browse the repository at this point in the history
  • Loading branch information
cszsol committed Nov 20, 2024
1 parent d5a2ccf commit 2ec4851
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Tool implementations without langchain or langgraph dependencies
- CRUDs.
- BlueNaas CRUD tools
- Unit tests for database

### Fixed
- Migrate LLM Evaluation logic to scripts and add tests
Expand Down
1 change: 1 addition & 0 deletions swarm_copy_tests/app/database/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Unit tests for database."""
97 changes: 97 additions & 0 deletions swarm_copy_tests/app/database/test_db_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import json
from unittest.mock import AsyncMock, Mock, patch

import pytest
from fastapi import HTTPException

from swarm_copy.app.database.db_utils import get_thread, save_history, get_history
from swarm_copy.app.database.sql_schemas import Entity, Messages, Threads


@pytest.mark.asyncio
async def test_get_thread():
user_id = "0"
thread_id = "0"
mock_thread_result = Mock()
mock_scalars_return = Mock()
mock_thread = Mock()
mock_scalars_return.one_or_none.return_value = mock_thread
mock_thread_result.scalars.return_value = mock_scalars_return
mock_session = AsyncMock()
mock_session.execute.return_value = mock_thread_result
result = await get_thread(user_id=user_id, thread_id=thread_id, session=mock_session)
assert result == mock_thread


@pytest.mark.asyncio
async def test_get_thread_exception():
user_id = "0"
thread_id = "0"
mock_thread_result = Mock()
mock_scalars_return = Mock()
mock_scalars_return.one_or_none.return_value = None
mock_thread_result.scalars.return_value = mock_scalars_return
mock_session = AsyncMock()
mock_session.execute.return_value = mock_thread_result
with pytest.raises(HTTPException):
await get_thread(user_id=user_id, thread_id=thread_id, session=mock_session)


@pytest.mark.parametrize("message_role,expected_entity,content", [
('user', Entity.USER, False),
('tool', Entity.TOOL, False),
('assistant', Entity.AI_MESSAGE, True),
('assistant', Entity.AI_TOOL, False)
])
@pytest.mark.asyncio
async def test_save_history(message_role, expected_entity, content):
history = [{"role": message_role, "content": content}]
user_id, thread_id, offset = "test_user", "test_thread", 0

mock_session = AsyncMock()
mock_thread = AsyncMock()

async def mock_get_thread(**kwargs):
return mock_thread

with patch("swarm_copy.app.database.db_utils.get_thread", mock_get_thread):
await save_history(history, user_id, thread_id, offset, mock_session)

assert mock_session.add.called

called_with_param = mock_session.add.call_args[0][0]
assert isinstance(called_with_param, Messages)
assert called_with_param.order == 0
assert called_with_param.thread_id == thread_id
assert called_with_param.entity == expected_entity
assert called_with_param.content == json.dumps(history[0])

assert mock_session.commit.called


@pytest.mark.asyncio
async def test_save_history_exception():
history = [{"role": "bad role", "content": None}]
user_id, thread_id, offset = "test_user", "test_thread", 0

mock_session = AsyncMock()

with pytest.raises(HTTPException):
await save_history(history, user_id, thread_id, offset, mock_session)


@pytest.mark.asyncio
async def test_get_history():
msg1 = Mock()
msg1.content = json.dumps("message1")
msg2 = Mock()
msg2.content = json.dumps("message2")
mock_thread = AsyncMock()
messages = [msg1, msg2]

async def mock_messages():
return messages

mock_thread.awaitable_attrs.messages = mock_messages()
results = await get_history(mock_thread)
assert results == [json.loads(msg.content) for msg in messages]

0 comments on commit 2ec4851

Please sign in to comment.