From a2624e3dcd8b018ab4b16ee39e1aacdca58ed37c Mon Sep 17 00:00:00 2001 From: Nicolas Frank <58003267+WonderPG@users.noreply.github.com> Date: Thu, 19 Dec 2024 17:35:36 +0100 Subject: [PATCH] Add tool CRUD tests (#59) --- CHANGELOG.md | 1 + swarm_copy/app/routers/tools.py | 2 +- swarm_copy_tests/app/routers/test_tools.py | 163 +++++++++++++++++++++ 3 files changed, 165 insertions(+), 1 deletion(-) create mode 100644 swarm_copy_tests/app/routers/test_tools.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 2ab5482..bed4956 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - app unit tests - Tests of AgentsRoutine. - Unit tests for database +- Tests for tool CRUD endpoints. ### Fixed - Migrate LLM Evaluation logic to scripts and add tests diff --git a/swarm_copy/app/routers/tools.py b/swarm_copy/app/routers/tools.py index bc4e06f..98389f7 100644 --- a/swarm_copy/app/routers/tools.py +++ b/swarm_copy/app/routers/tools.py @@ -106,6 +106,6 @@ async def get_tool_returns( for msg in tool_messages: msg_content = json.loads(msg.content) if msg_content.get("tool_call_id") == tool_call_id: - tool_output.append(msg_content) + tool_output.append(msg_content["content"]) return tool_output diff --git a/swarm_copy_tests/app/routers/test_tools.py b/swarm_copy_tests/app/routers/test_tools.py new file mode 100644 index 0000000..46a7ae8 --- /dev/null +++ b/swarm_copy_tests/app/routers/test_tools.py @@ -0,0 +1,163 @@ +"""Test of the tool router.""" + +import json + +import pytest + +from swarm_copy.agent_routine import Agent, AgentsRoutine +from swarm_copy.app.config import Settings +from swarm_copy.app.database.schemas import ToolCallSchema +from swarm_copy.app.dependencies import ( + get_agents_routine, + get_context_variables, + get_settings, + get_starting_agent, +) +from swarm_copy.app.main import app +from swarm_copy_tests.mock_client import create_mock_response + + +@pytest.mark.httpx_mock(can_send_already_matched_responses=True) +@pytest.mark.asyncio +async def test_get_tool_calls( + patch_required_env, + httpx_mock, + app_client, + db_connection, + mock_openai_client, + get_weather_tool, +): + routine = AgentsRoutine(client=mock_openai_client) + + mock_openai_client.set_sequential_responses( + [ + create_mock_response( + message={"role": "assistant", "content": ""}, + function_calls=[ + {"name": "get_weather", "args": {"location": "Geneva"}} + ], + ), + create_mock_response( + {"role": "assistant", "content": "sample response content"} + ), + ] + ) + agent = Agent(tools=[get_weather_tool]) + + app.dependency_overrides[get_agents_routine] = lambda: routine + app.dependency_overrides[get_starting_agent] = lambda: agent + test_settings = Settings( + db={"prefix": db_connection}, + ) + app.dependency_overrides[get_settings] = lambda: test_settings + httpx_mock.add_response( + url=f"{test_settings.virtual_lab.get_project_url}/test_vlab/projects/test_project" + ) + + with app_client as app_client: + wrong_response = app_client.get("/tools/test/1234") + assert wrong_response.status_code == 404 + assert wrong_response.json() == {"detail": {"detail": "Thread not found."}} + + # Create a thread + create_output = app_client.post( + "/threads/?virtual_lab_id=test_vlab&project_id=test_project" + ).json() + thread_id = create_output["thread_id"] + + # Fill the thread + app_client.post( + f"/qa/chat/{thread_id}", + json={"query": "This is my query"}, + params={"thread_id": thread_id}, + headers={"x-virtual-lab-id": "test_vlab", "x-project-id": "test_project"}, + ) + + tool_calls = app_client.get(f"/tools/{thread_id}/wrong_id") + assert tool_calls.status_code == 404 + assert tool_calls.json() == {"detail": {"detail": "Message not found."}} + + # Get the messages of the thread + messages = app_client.get(f"/threads/{thread_id}").json() + message_id = messages[-1]["message_id"] + tool_calls = app_client.get(f"/tools/{thread_id}/{message_id}").json() + + assert ( + tool_calls[0] + == ToolCallSchema( + tool_call_id="mock_tc_id", + name="get_weather", + arguments={"location": "Geneva"}, + ).model_dump() + ) + + +@pytest.mark.httpx_mock(can_send_already_matched_responses=True) +@pytest.mark.asyncio +async def test_get_tool_output( + patch_required_env, + app_client, + httpx_mock, + db_connection, + mock_openai_client, + agent_handoff_tool, +): + routine = AgentsRoutine(client=mock_openai_client) + + mock_openai_client.set_sequential_responses( + [ + create_mock_response( + message={"role": "assistant", "content": ""}, + function_calls=[{"name": "agent_handoff_tool", "args": {}}], + ), + create_mock_response( + {"role": "assistant", "content": "sample response content"} + ), + ] + ) + agent_1 = Agent(name="Test agent 1", tools=[agent_handoff_tool]) + agent_2 = Agent(name="Test agent 2", tools=[]) + + app.dependency_overrides[get_agents_routine] = lambda: routine + app.dependency_overrides[get_starting_agent] = lambda: agent_1 + app.dependency_overrides[get_context_variables] = lambda: {"to_agent": agent_2} + test_settings = Settings( + db={"prefix": db_connection}, + ) + app.dependency_overrides[get_settings] = lambda: test_settings + httpx_mock.add_response( + url=f"{test_settings.virtual_lab.get_project_url}/test_vlab/projects/test_project" + ) + + with app_client as app_client: + wrong_response = app_client.get("/tools/output/test/123") + assert wrong_response.status_code == 404 + assert wrong_response.json() == {"detail": {"detail": "Thread not found."}} + + # Create a thread + create_output = app_client.post( + "/threads/?virtual_lab_id=test_vlab&project_id=test_project" + ).json() + thread_id = create_output["thread_id"] + + # Fill the thread + app_client.post( + f"/qa/chat/{thread_id}", + json={"query": "This is my query"}, + params={"thread_id": thread_id}, + headers={"x-virtual-lab-id": "test_vlab", "x-project-id": "test_project"}, + ) + + tool_output = app_client.get(f"/tools/output/{thread_id}/123") + assert tool_output.status_code == 200 + assert tool_output.json() == [] + + # Get the messages of the thread + messages = app_client.get(f"/threads/{thread_id}").json() + message_id = messages[-1]["message_id"] + tool_calls = app_client.get(f"/tools/{thread_id}/{message_id}").json() + + tool_call_id = tool_calls[0]["tool_call_id"] + tool_output = app_client.get(f"/tools/output/{thread_id}/{tool_call_id}") + + assert tool_output.json() == [json.dumps({"assistant": agent_2.name})]