-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Missing file from last commit Actual tests for settings
- Loading branch information
Showing
4 changed files
with
159 additions
and
30 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import os | ||
from pathlib import Path | ||
from typing import Union | ||
|
||
import yaml | ||
from fastapi import FastAPI | ||
|
||
|
||
def settings_path(create=True): | ||
settings_dir = os.path.join(Path.home(), ".kiln_ai") | ||
if create and not os.path.exists(settings_dir): | ||
os.makedirs(settings_dir) | ||
return os.path.join(settings_dir, "settings.yaml") | ||
|
||
|
||
def load_settings(): | ||
if not os.path.isfile(settings_path(create=False)): | ||
return {} | ||
with open(settings_path(), "r") as f: | ||
settings = yaml.safe_load(f.read()) | ||
return settings | ||
|
||
|
||
def connect_settings(app: FastAPI): | ||
@app.post("/setting") | ||
def update_settings(new_settings: dict[str, int | float | str | bool]): | ||
settings = load_settings() | ||
settings.update(new_settings) | ||
with open(settings_path(), "w") as f: | ||
f.write(yaml.dump(settings)) | ||
return {"message": "Settings updated"} | ||
|
||
@app.get("/settings") | ||
def read_settings(): | ||
settings = load_settings() | ||
return settings | ||
|
||
@app.get("/items/{item_id}") | ||
def read_item(item_id: int): | ||
return {"item_id": item_id} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import os | ||
|
||
import pytest | ||
import yaml | ||
from fastapi import FastAPI | ||
from fastapi.testclient import TestClient | ||
from kiln_studio.settings import connect_settings, load_settings, settings_path | ||
|
||
|
||
@pytest.fixture | ||
def temp_home(tmp_path, monkeypatch): | ||
monkeypatch.setattr(os.path, "expanduser", lambda x: str(tmp_path)) | ||
return tmp_path | ||
|
||
|
||
@pytest.fixture | ||
def app(): | ||
app = FastAPI() | ||
connect_settings(app) | ||
return app | ||
|
||
|
||
@pytest.fixture | ||
def client(app): | ||
return TestClient(app) | ||
|
||
|
||
def test_settings_path(temp_home): | ||
expected_path = os.path.join(temp_home, ".kiln_ai", "settings.yaml") | ||
assert settings_path() == expected_path | ||
assert os.path.exists(os.path.dirname(expected_path)) | ||
|
||
|
||
def test_load_settings_empty(temp_home): | ||
assert load_settings() == {} | ||
|
||
|
||
def test_load_settings_existing(temp_home): | ||
settings_file = settings_path() | ||
os.makedirs(os.path.dirname(settings_file), exist_ok=True) | ||
test_settings = {"key": "value"} | ||
with open(settings_file, "w") as f: | ||
yaml.dump(test_settings, f) | ||
|
||
assert load_settings() == test_settings | ||
|
||
|
||
def test_update_settings(client, temp_home): | ||
new_settings = {"test_key": "test_value"} | ||
response = client.post("/setting", json=new_settings) | ||
assert response.status_code == 200 | ||
assert response.json() == {"message": "Settings updated"} | ||
|
||
# Verify the settings were actually updated | ||
with open(settings_path(), "r") as f: | ||
saved_settings = yaml.safe_load(f) | ||
assert saved_settings == new_settings | ||
|
||
|
||
def test_read_settings(client, temp_home): | ||
test_settings = {"key1": "value1", "key2": 42} | ||
with open(settings_path(), "w") as f: | ||
yaml.dump(test_settings, f) | ||
|
||
response = client.get("/settings") | ||
assert response.status_code == 200 | ||
assert response.json() == test_settings | ||
|
||
|
||
def test_read_item(client): | ||
response = client.get("/items/42") | ||
assert response.status_code == 200 | ||
assert response.json() == {"item_id": 42} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import os | ||
import sys | ||
|
||
from fastapi import FastAPI | ||
from fastapi.responses import FileResponse | ||
from fastapi.staticfiles import StaticFiles | ||
|
||
|
||
def studio_path(): | ||
try: | ||
# pyinstaller path | ||
base_path = sys._MEIPASS # type: ignore | ||
return os.path.join(base_path, "./web_ui/build") | ||
except Exception: | ||
base_path = os.path.join(os.path.dirname(__file__), "..") | ||
return os.path.join(base_path, "../../app/web_ui/build") | ||
|
||
|
||
# File server that maps /foo/bar to /foo/bar.html (Starlette StaticFiles only does index.html) | ||
class HTMLStaticFiles(StaticFiles): | ||
async def get_response(self, path: str, scope): | ||
try: | ||
response = await super().get_response(path, scope) | ||
if response.status_code != 404: | ||
return response | ||
except Exception as e: | ||
# catching HTTPException explicitly not working for some reason | ||
if getattr(e, "status_code", None) != 404: | ||
# Don't raise on 404, fall through to return the .html version | ||
raise e | ||
# Try the .html version of the file if the .html version exists, for 404s | ||
return await super().get_response(f"{path}.html", scope) | ||
|
||
|
||
def connect_webhost(app: FastAPI): | ||
# Ensure studio_path exists (test servers don't necessarily create it) | ||
os.makedirs(studio_path(), exist_ok=True) | ||
# Serves the web UI at root | ||
app.mount("/", HTMLStaticFiles(directory=studio_path(), html=True), name="studio") | ||
|
||
# add pretty 404s | ||
@app.exception_handler(404) | ||
def not_found_exception_handler(request, exc): | ||
return FileResponse(os.path.join(studio_path(), "404.html"), status_code=404) |