From c1a5ccc1ed5f73d4f4b6bae02f311e82b0ccc3a8 Mon Sep 17 00:00:00 2001 From: scosman Date: Wed, 2 Oct 2024 16:49:50 -0400 Subject: [PATCH] Add tests, and fix issue (some exceptions bypass handlers) --- .../studio/kiln_studio/provider_management.py | 55 +++++----- .../kiln_studio/test_provider_management.py | 101 ++++++++++++++++++ 2 files changed, 128 insertions(+), 28 deletions(-) create mode 100644 libs/studio/kiln_studio/test_provider_management.py diff --git a/libs/studio/kiln_studio/provider_management.py b/libs/studio/kiln_studio/provider_management.py index b8a32bc..f392eb9 100644 --- a/libs/studio/kiln_studio/provider_management.py +++ b/libs/studio/kiln_studio/provider_management.py @@ -73,38 +73,37 @@ async def connect_api_key(payload: dict): content={"message": f"Provider {provider} not supported"}, ) - async def connect_openai(key: str): - try: - headers = { - "Authorization": f"Bearer {key}", - "Content-Type": "application/json", - } - response = requests.get("https://api.openai.com/v1/models", headers=headers) - # 401 def means invalid API key, so special case it - if response.status_code == 401: - return JSONResponse( - status_code=401, - content={ - "message": "Failed to connect to OpenAI. Invalid API key." - }, - ) +async def connect_openai(key: str): + try: + headers = { + "Authorization": f"Bearer {key}", + "Content-Type": "application/json", + } + response = requests.get("https://api.openai.com/v1/models", headers=headers) - # Any non-200 status code is an error - response.raise_for_status() - # If the request is successful, the function will continue - except requests.RequestException as e: + # 401 def means invalid API key, so special case it + if response.status_code == 401: return JSONResponse( - status_code=400, - content={ - "message": f"Failed to connect to OpenAI. Likely invalid API key. Error: {str(e)}" - }, + status_code=401, + content={"message": "Failed to connect to OpenAI. Invalid API key."}, ) - # It worked! Save the key and return success - Config.shared().open_ai_api_key = key - + # Any non-200 status code is an error + response.raise_for_status() + # If the request is successful, the function will continue + except Exception as e: return JSONResponse( - status_code=200, - content={"message": "Connected to OpenAI"}, + status_code=400, + content={ + "message": f"Failed to connect to OpenAI. Likely invalid API key. Error: {str(e)}" + }, ) + + # It worked! Save the key and return success + Config.shared().open_ai_api_key = key + + return JSONResponse( + status_code=200, + content={"message": "Connected to OpenAI"}, + ) diff --git a/libs/studio/kiln_studio/test_provider_management.py b/libs/studio/kiln_studio/test_provider_management.py new file mode 100644 index 0000000..c261163 --- /dev/null +++ b/libs/studio/kiln_studio/test_provider_management.py @@ -0,0 +1,101 @@ +from unittest.mock import MagicMock, patch + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from libs.studio.kiln_studio.provider_management import connect_provider_management + + +@pytest.fixture +def app(): + app = FastAPI() + connect_provider_management(app) + return app + + +@pytest.fixture +def client(app): + return TestClient(app) + + +def test_connect_api_key_invalid_payload(client): + response = client.post( + "/provider/connect_api_key", json={"provider": "openai", "key_data": "invalid"} + ) + assert response.status_code == 400 + assert response.json() == {"message": "Invalid key_data or provider"} + + +def test_connect_api_key_unsupported_provider(client): + response = client.post( + "/provider/connect_api_key", + json={"provider": "unsupported", "key_data": {"API Key": "test"}}, + ) + assert response.status_code == 400 + assert response.json() == {"message": "Provider unsupported not supported"} + + +@patch("libs.studio.kiln_studio.provider_management.connect_openai") +def test_connect_api_key_openai_success(mock_connect_openai, client): + mock_connect_openai.return_value = {"message": "Connected to OpenAI"} + response = client.post( + "/provider/connect_api_key", + json={"provider": "openai", "key_data": {"API Key": "test_key"}}, + ) + assert response.status_code == 200 + assert response.json() == {"message": "Connected to OpenAI"} + mock_connect_openai.assert_called_once_with("test_key") + + +@patch("libs.studio.kiln_studio.provider_management.requests.get") +@patch("libs.studio.kiln_studio.provider_management.Config.shared") +def test_connect_openai_success(mock_config_shared, mock_requests_get, client): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_requests_get.return_value = mock_response + + mock_config = MagicMock() + mock_config_shared.return_value = mock_config + + response = client.post( + "/provider/connect_api_key", + json={"provider": "openai", "key_data": {"API Key": "test_key"}}, + ) + + assert response.status_code == 200 + assert response.json() == {"message": "Connected to OpenAI"} + assert mock_config.open_ai_api_key == "test_key" + + +@patch("libs.studio.kiln_studio.provider_management.requests.get") +def test_connect_openai_invalid_key(mock_requests_get, client): + mock_response = MagicMock() + mock_response.status_code = 401 + mock_requests_get.return_value = mock_response + + response = client.post( + "/provider/connect_api_key", + json={"provider": "openai", "key_data": {"API Key": "invalid_key"}}, + ) + + assert response.status_code == 401 + assert response.json() == { + "message": "Failed to connect to OpenAI. Invalid API key." + } + + +@patch("libs.studio.kiln_studio.provider_management.requests.get") +def test_connect_openai_request_exception(mock_requests_get, client): + mock_requests_get.side_effect = Exception("Test error") + + response = client.post( + "/provider/connect_api_key", + json={"provider": "openai", "key_data": {"API Key": "test_key"}}, + ) + + assert response.status_code == 400 + assert ( + "Failed to connect to OpenAI. Likely invalid API key. Error:" + in response.json()["message"] + )