From d3284d0684447ff38a5b6150407d640df1779df1 Mon Sep 17 00:00:00 2001 From: scosman Date: Wed, 2 Oct 2024 09:10:42 -0400 Subject: [PATCH] More decomposing app. Missing file from last commit Actual tests for settings --- libs/studio/kiln_studio/server.py | 32 +---------- libs/studio/kiln_studio/settings.py | 40 +++++++++++++ libs/studio/kiln_studio/test_settings.py | 73 ++++++++++++++++++++++++ libs/studio/kiln_studio/webhost.py | 44 ++++++++++++++ 4 files changed, 159 insertions(+), 30 deletions(-) create mode 100644 libs/studio/kiln_studio/settings.py create mode 100644 libs/studio/kiln_studio/test_settings.py create mode 100644 libs/studio/kiln_studio/webhost.py diff --git a/libs/studio/kiln_studio/server.py b/libs/studio/kiln_studio/server.py index 5481334..e45035e 100644 --- a/libs/studio/kiln_studio/server.py +++ b/libs/studio/kiln_studio/server.py @@ -9,6 +9,7 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse +from .settings import connect_settings from .webhost import connect_webhost @@ -27,19 +28,6 @@ def make_app(): def ping(): return "pong" - def settings_path(create=True): - settings_dir = os.path.join(Path.home(), ".kiln_ai") - if create and not os.path.exists(settings_dir): - os.makedirs(settings_dir) - return os.path.join(settings_dir, "settings.yaml") - - def load_settings(): - if not os.path.isfile(settings_path(create=False)): - return {} - with open(settings_path(), "r") as f: - settings = yaml.safe_load(f.read()) - return settings - @app.post("/provider/ollama/connect") def connect_ollama(): # Tags is a list of Ollama models. Proves Ollama is running, and models are available. @@ -71,23 +59,7 @@ def connect_ollama(): content={"message": "Ollama not connected, or no Ollama models installed."}, ) - @app.post("/setting") - def update_settings(new_settings: dict[str, int | float | str | bool]): - settings = load_settings() - settings.update(new_settings) - with open(settings_path(), "w") as f: - f.write(yaml.dump(settings)) - return {"message": "Settings updated"} - - @app.get("/settings") - def read_settings(): - settings = load_settings() - return settings - - @app.get("/items/{item_id}") - def read_item(item_id: int, q: Union[str, None] = None): - return {"item_id": item_id, "q": q} - + connect_settings(app) connect_webhost(app) return app diff --git a/libs/studio/kiln_studio/settings.py b/libs/studio/kiln_studio/settings.py new file mode 100644 index 0000000..2e6a5f9 --- /dev/null +++ b/libs/studio/kiln_studio/settings.py @@ -0,0 +1,40 @@ +import os +from pathlib import Path +from typing import Union + +import yaml +from fastapi import FastAPI + + +def settings_path(create=True): + settings_dir = os.path.join(Path.home(), ".kiln_ai") + if create and not os.path.exists(settings_dir): + os.makedirs(settings_dir) + return os.path.join(settings_dir, "settings.yaml") + + +def load_settings(): + if not os.path.isfile(settings_path(create=False)): + return {} + with open(settings_path(), "r") as f: + settings = yaml.safe_load(f.read()) + return settings + + +def connect_settings(app: FastAPI): + @app.post("/setting") + def update_settings(new_settings: dict[str, int | float | str | bool]): + settings = load_settings() + settings.update(new_settings) + with open(settings_path(), "w") as f: + f.write(yaml.dump(settings)) + return {"message": "Settings updated"} + + @app.get("/settings") + def read_settings(): + settings = load_settings() + return settings + + @app.get("/items/{item_id}") + def read_item(item_id: int): + return {"item_id": item_id} diff --git a/libs/studio/kiln_studio/test_settings.py b/libs/studio/kiln_studio/test_settings.py new file mode 100644 index 0000000..655ba8e --- /dev/null +++ b/libs/studio/kiln_studio/test_settings.py @@ -0,0 +1,73 @@ +import os + +import pytest +import yaml +from fastapi import FastAPI +from fastapi.testclient import TestClient +from kiln_studio.settings import connect_settings, load_settings, settings_path + + +@pytest.fixture +def temp_home(tmp_path, monkeypatch): + monkeypatch.setattr(os.path, "expanduser", lambda x: str(tmp_path)) + return tmp_path + + +@pytest.fixture +def app(): + app = FastAPI() + connect_settings(app) + return app + + +@pytest.fixture +def client(app): + return TestClient(app) + + +def test_settings_path(temp_home): + expected_path = os.path.join(temp_home, ".kiln_ai", "settings.yaml") + assert settings_path() == expected_path + assert os.path.exists(os.path.dirname(expected_path)) + + +def test_load_settings_empty(temp_home): + assert load_settings() == {} + + +def test_load_settings_existing(temp_home): + settings_file = settings_path() + os.makedirs(os.path.dirname(settings_file), exist_ok=True) + test_settings = {"key": "value"} + with open(settings_file, "w") as f: + yaml.dump(test_settings, f) + + assert load_settings() == test_settings + + +def test_update_settings(client, temp_home): + new_settings = {"test_key": "test_value"} + response = client.post("/setting", json=new_settings) + assert response.status_code == 200 + assert response.json() == {"message": "Settings updated"} + + # Verify the settings were actually updated + with open(settings_path(), "r") as f: + saved_settings = yaml.safe_load(f) + assert saved_settings == new_settings + + +def test_read_settings(client, temp_home): + test_settings = {"key1": "value1", "key2": 42} + with open(settings_path(), "w") as f: + yaml.dump(test_settings, f) + + response = client.get("/settings") + assert response.status_code == 200 + assert response.json() == test_settings + + +def test_read_item(client): + response = client.get("/items/42") + assert response.status_code == 200 + assert response.json() == {"item_id": 42} diff --git a/libs/studio/kiln_studio/webhost.py b/libs/studio/kiln_studio/webhost.py new file mode 100644 index 0000000..ff136d0 --- /dev/null +++ b/libs/studio/kiln_studio/webhost.py @@ -0,0 +1,44 @@ +import os +import sys + +from fastapi import FastAPI +from fastapi.responses import FileResponse +from fastapi.staticfiles import StaticFiles + + +def studio_path(): + try: + # pyinstaller path + base_path = sys._MEIPASS # type: ignore + return os.path.join(base_path, "./web_ui/build") + except Exception: + base_path = os.path.join(os.path.dirname(__file__), "..") + return os.path.join(base_path, "../../app/web_ui/build") + + +# File server that maps /foo/bar to /foo/bar.html (Starlette StaticFiles only does index.html) +class HTMLStaticFiles(StaticFiles): + async def get_response(self, path: str, scope): + try: + response = await super().get_response(path, scope) + if response.status_code != 404: + return response + except Exception as e: + # catching HTTPException explicitly not working for some reason + if getattr(e, "status_code", None) != 404: + # Don't raise on 404, fall through to return the .html version + raise e + # Try the .html version of the file if the .html version exists, for 404s + return await super().get_response(f"{path}.html", scope) + + +def connect_webhost(app: FastAPI): + # Ensure studio_path exists (test servers don't necessarily create it) + os.makedirs(studio_path(), exist_ok=True) + # Serves the web UI at root + app.mount("/", HTMLStaticFiles(directory=studio_path(), html=True), name="studio") + + # add pretty 404s + @app.exception_handler(404) + def not_found_exception_handler(request, exc): + return FileResponse(os.path.join(studio_path(), "404.html"), status_code=404)