From 2ec48515d8eb34084d80fc41444da6148a1c0c20 Mon Sep 17 00:00:00 2001 From: cszsolnai Date: Wed, 20 Nov 2024 08:38:30 +0100 Subject: [PATCH] Added db_utils tests --- CHANGELOG.md | 1 + swarm_copy_tests/app/database/__init__.py | 1 + .../app/database/test_db_utils.py | 97 +++++++++++++++++++ 3 files changed, 99 insertions(+) create mode 100644 swarm_copy_tests/app/database/__init__.py create mode 100644 swarm_copy_tests/app/database/test_db_utils.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 52955bc..bceac7e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Tool implementations without langchain or langgraph dependencies - CRUDs. - BlueNaas CRUD tools +- Unit tests for database ### Fixed - Migrate LLM Evaluation logic to scripts and add tests diff --git a/swarm_copy_tests/app/database/__init__.py b/swarm_copy_tests/app/database/__init__.py new file mode 100644 index 0000000..8ce3e8d --- /dev/null +++ b/swarm_copy_tests/app/database/__init__.py @@ -0,0 +1 @@ +"""Unit tests for database.""" diff --git a/swarm_copy_tests/app/database/test_db_utils.py b/swarm_copy_tests/app/database/test_db_utils.py new file mode 100644 index 0000000..57fbea5 --- /dev/null +++ b/swarm_copy_tests/app/database/test_db_utils.py @@ -0,0 +1,97 @@ +import json +from unittest.mock import AsyncMock, Mock, patch + +import pytest +from fastapi import HTTPException + +from swarm_copy.app.database.db_utils import get_thread, save_history, get_history +from swarm_copy.app.database.sql_schemas import Entity, Messages, Threads + + +@pytest.mark.asyncio +async def test_get_thread(): + user_id = "0" + thread_id = "0" + mock_thread_result = Mock() + mock_scalars_return = Mock() + mock_thread = Mock() + mock_scalars_return.one_or_none.return_value = mock_thread + mock_thread_result.scalars.return_value = mock_scalars_return + mock_session = AsyncMock() + mock_session.execute.return_value = mock_thread_result + result = await get_thread(user_id=user_id, thread_id=thread_id, session=mock_session) + assert result == mock_thread + + +@pytest.mark.asyncio +async def test_get_thread_exception(): + user_id = "0" + thread_id = "0" + mock_thread_result = Mock() + mock_scalars_return = Mock() + mock_scalars_return.one_or_none.return_value = None + mock_thread_result.scalars.return_value = mock_scalars_return + mock_session = AsyncMock() + mock_session.execute.return_value = mock_thread_result + with pytest.raises(HTTPException): + await get_thread(user_id=user_id, thread_id=thread_id, session=mock_session) + + +@pytest.mark.parametrize("message_role,expected_entity,content", [ + ('user', Entity.USER, False), + ('tool', Entity.TOOL, False), + ('assistant', Entity.AI_MESSAGE, True), + ('assistant', Entity.AI_TOOL, False) +]) +@pytest.mark.asyncio +async def test_save_history(message_role, expected_entity, content): + history = [{"role": message_role, "content": content}] + user_id, thread_id, offset = "test_user", "test_thread", 0 + + mock_session = AsyncMock() + mock_thread = AsyncMock() + + async def mock_get_thread(**kwargs): + return mock_thread + + with patch("swarm_copy.app.database.db_utils.get_thread", mock_get_thread): + await save_history(history, user_id, thread_id, offset, mock_session) + + assert mock_session.add.called + + called_with_param = mock_session.add.call_args[0][0] + assert isinstance(called_with_param, Messages) + assert called_with_param.order == 0 + assert called_with_param.thread_id == thread_id + assert called_with_param.entity == expected_entity + assert called_with_param.content == json.dumps(history[0]) + + assert mock_session.commit.called + + +@pytest.mark.asyncio +async def test_save_history_exception(): + history = [{"role": "bad role", "content": None}] + user_id, thread_id, offset = "test_user", "test_thread", 0 + + mock_session = AsyncMock() + + with pytest.raises(HTTPException): + await save_history(history, user_id, thread_id, offset, mock_session) + + +@pytest.mark.asyncio +async def test_get_history(): + msg1 = Mock() + msg1.content = json.dumps("message1") + msg2 = Mock() + msg2.content = json.dumps("message2") + mock_thread = AsyncMock() + messages = [msg1, msg2] + + async def mock_messages(): + return messages + + mock_thread.awaitable_attrs.messages = mock_messages() + results = await get_history(mock_thread) + assert results == [json.loads(msg.content) for msg in messages]