Skip to content

Commit

Permalink
Add tests, and fix issue (some exceptions bypass handlers)
Browse files Browse the repository at this point in the history
  • Loading branch information
scosman committed Oct 2, 2024
1 parent 94962fe commit c1a5ccc
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 28 deletions.
55 changes: 27 additions & 28 deletions libs/studio/kiln_studio/provider_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,38 +73,37 @@ async def connect_api_key(payload: dict):
content={"message": f"Provider {provider} not supported"},
)

async def connect_openai(key: str):
try:
headers = {
"Authorization": f"Bearer {key}",
"Content-Type": "application/json",
}
response = requests.get("https://api.openai.com/v1/models", headers=headers)

# 401 def means invalid API key, so special case it
if response.status_code == 401:
return JSONResponse(
status_code=401,
content={
"message": "Failed to connect to OpenAI. Invalid API key."
},
)
async def connect_openai(key: str):
try:
headers = {
"Authorization": f"Bearer {key}",
"Content-Type": "application/json",
}
response = requests.get("https://api.openai.com/v1/models", headers=headers)

# Any non-200 status code is an error
response.raise_for_status()
# If the request is successful, the function will continue
except requests.RequestException as e:
# 401 def means invalid API key, so special case it
if response.status_code == 401:
return JSONResponse(
status_code=400,
content={
"message": f"Failed to connect to OpenAI. Likely invalid API key. Error: {str(e)}"
},
status_code=401,
content={"message": "Failed to connect to OpenAI. Invalid API key."},
)

# It worked! Save the key and return success
Config.shared().open_ai_api_key = key

# Any non-200 status code is an error
response.raise_for_status()
# If the request is successful, the function will continue
except Exception as e:
return JSONResponse(
status_code=200,
content={"message": "Connected to OpenAI"},
status_code=400,
content={
"message": f"Failed to connect to OpenAI. Likely invalid API key. Error: {str(e)}"
},
)

# It worked! Save the key and return success
Config.shared().open_ai_api_key = key

return JSONResponse(
status_code=200,
content={"message": "Connected to OpenAI"},
)
101 changes: 101 additions & 0 deletions libs/studio/kiln_studio/test_provider_management.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from unittest.mock import MagicMock, patch

import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient

from libs.studio.kiln_studio.provider_management import connect_provider_management


@pytest.fixture
def app():
app = FastAPI()
connect_provider_management(app)
return app


@pytest.fixture
def client(app):
return TestClient(app)


def test_connect_api_key_invalid_payload(client):
response = client.post(
"/provider/connect_api_key", json={"provider": "openai", "key_data": "invalid"}
)
assert response.status_code == 400
assert response.json() == {"message": "Invalid key_data or provider"}


def test_connect_api_key_unsupported_provider(client):
response = client.post(
"/provider/connect_api_key",
json={"provider": "unsupported", "key_data": {"API Key": "test"}},
)
assert response.status_code == 400
assert response.json() == {"message": "Provider unsupported not supported"}


@patch("libs.studio.kiln_studio.provider_management.connect_openai")
def test_connect_api_key_openai_success(mock_connect_openai, client):
mock_connect_openai.return_value = {"message": "Connected to OpenAI"}
response = client.post(
"/provider/connect_api_key",
json={"provider": "openai", "key_data": {"API Key": "test_key"}},
)
assert response.status_code == 200
assert response.json() == {"message": "Connected to OpenAI"}
mock_connect_openai.assert_called_once_with("test_key")


@patch("libs.studio.kiln_studio.provider_management.requests.get")
@patch("libs.studio.kiln_studio.provider_management.Config.shared")
def test_connect_openai_success(mock_config_shared, mock_requests_get, client):
mock_response = MagicMock()
mock_response.status_code = 200
mock_requests_get.return_value = mock_response

mock_config = MagicMock()
mock_config_shared.return_value = mock_config

response = client.post(
"/provider/connect_api_key",
json={"provider": "openai", "key_data": {"API Key": "test_key"}},
)

assert response.status_code == 200
assert response.json() == {"message": "Connected to OpenAI"}
assert mock_config.open_ai_api_key == "test_key"


@patch("libs.studio.kiln_studio.provider_management.requests.get")
def test_connect_openai_invalid_key(mock_requests_get, client):
mock_response = MagicMock()
mock_response.status_code = 401
mock_requests_get.return_value = mock_response

response = client.post(
"/provider/connect_api_key",
json={"provider": "openai", "key_data": {"API Key": "invalid_key"}},
)

assert response.status_code == 401
assert response.json() == {
"message": "Failed to connect to OpenAI. Invalid API key."
}


@patch("libs.studio.kiln_studio.provider_management.requests.get")
def test_connect_openai_request_exception(mock_requests_get, client):
mock_requests_get.side_effect = Exception("Test error")

response = client.post(
"/provider/connect_api_key",
json={"provider": "openai", "key_data": {"API Key": "test_key"}},
)

assert response.status_code == 400
assert (
"Failed to connect to OpenAI. Likely invalid API key. Error:"
in response.json()["message"]
)

0 comments on commit c1a5ccc

Please sign in to comment.