From c5d10250b21bd617af59a23d6e3540fba80e0dbd Mon Sep 17 00:00:00 2001 From: scosman Date: Tue, 1 Oct 2024 16:50:28 -0400 Subject: [PATCH] 404 support with pretty page. Start of improving my "app" testability, but mode coming next --- app/desktop/desktop_server.py | 2 +- app/desktop/test_desktop.py | 2 +- app/web_ui/svelte.config.js | 7 +- libs/studio/kiln_studio/server.py | 186 +++++++++++++------------ libs/studio/kiln_studio/test_server.py | 88 +++++++----- libs/studio/tests/test_import.py | 7 - 6 files changed, 153 insertions(+), 139 deletions(-) delete mode 100644 libs/studio/tests/test_import.py diff --git a/app/desktop/desktop_server.py b/app/desktop/desktop_server.py index 75dbe83..3c27065 100644 --- a/app/desktop/desktop_server.py +++ b/app/desktop/desktop_server.py @@ -8,7 +8,7 @@ def server_config(port=8757): return uvicorn.Config( - kiln_server.app, + kiln_server.make_app(), host="127.0.0.1", port=port, log_level="warning", diff --git a/app/desktop/test_desktop.py b/app/desktop/test_desktop.py index 74938f0..06096b1 100644 --- a/app/desktop/test_desktop.py +++ b/app/desktop/test_desktop.py @@ -11,5 +11,5 @@ def test_desktop_app_server(): config = desktop_server.server_config(port=port) uni_server = desktop_server.ThreadedServer(config=config) with uni_server.run_in_thread(): - r = requests.get("http://127.0.0.1:{}/".format(port)) + r = requests.get("http://127.0.0.1:{}/ping".format(port)) assert r.status_code == 200 diff --git a/app/web_ui/svelte.config.js b/app/web_ui/svelte.config.js index e8c5fec..4d2748f 100644 --- a/app/web_ui/svelte.config.js +++ b/app/web_ui/svelte.config.js @@ -4,7 +4,12 @@ import { vitePreprocess } from "@sveltejs/vite-plugin-svelte" /** @type {import('@sveltejs/kit').Config} */ const config = { kit: { - adapter: adapter(), + adapter: adapter({ + fallback: "404.html", + }), + }, + prerender: { + default: true, }, preprocess: vitePreprocess(), } diff --git a/libs/studio/kiln_studio/server.py b/libs/studio/kiln_studio/server.py index 859a0b5..5c2e2f7 100644 --- a/libs/studio/kiln_studio/server.py +++ b/libs/studio/kiln_studio/server.py @@ -8,13 +8,13 @@ import yaml from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import JSONResponse +from fastapi.responses import FileResponse, JSONResponse from fastapi.staticfiles import StaticFiles -# TODO would rather this get passed. This class shouldn't know about desktop def studio_path(): try: + # pyinstaller path base_path = sys._MEIPASS # type: ignore return os.path.join(base_path, "./web_ui/build") except Exception: @@ -22,108 +22,112 @@ def studio_path(): return os.path.join(base_path, "../../app/web_ui/build") -app = FastAPI() - -app.add_middleware( - CORSMiddleware, - allow_origin_regex=r"^https?://(localhost|127\.0\.0\.1)(:\d+)?$", - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - - -@app.get("/ping") -def ping(): - return "pong" - - -def settings_path(create=True): - settings_dir = os.path.join(Path.home(), ".kiln_ai") - if create and not os.path.exists(settings_dir): - os.makedirs(settings_dir) - return os.path.join(settings_dir, "settings.yaml") - - -def load_settings(): - if not os.path.isfile(settings_path(create=False)): - return {} - with open(settings_path(), "r") as f: - settings = yaml.safe_load(f.read()) - return settings +# File server that maps /foo/bar to /foo/bar.html (Starlette StaticFiles only does index.html) +class HTMLStaticFiles(StaticFiles): + async def get_response(self, path: str, scope): + try: + response = await super().get_response(path, scope) + if response.status_code != 404: + return response + except Exception as e: + # catching HTTPException explicitly not working for some reason + if getattr(e, "status_code", None) != 404: + # Don't raise on 404, fall through to return the .html version + raise e + # Try the .html version of the file if the .html version exists, for 404s + return await super().get_response(f"{path}.html", scope) + + +def make_app(): + app = FastAPI() + + app.add_middleware( + CORSMiddleware, + allow_origin_regex=r"^https?://(localhost|127\.0\.0\.1)(:\d+)?$", + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + @app.get("/ping") + def ping(): + return "pong" + + def settings_path(create=True): + settings_dir = os.path.join(Path.home(), ".kiln_ai") + if create and not os.path.exists(settings_dir): + os.makedirs(settings_dir) + return os.path.join(settings_dir, "settings.yaml") + + def load_settings(): + if not os.path.isfile(settings_path(create=False)): + return {} + with open(settings_path(), "r") as f: + settings = yaml.safe_load(f.read()) + return settings + + @app.post("/provider/ollama/connect") + def connect_ollama(): + # Tags is a list of Ollama models. Proves Ollama is running, and models are available. + try: + tags = requests.get("http://localhost:11434/api/tags").json() + except requests.exceptions.ConnectionError: + return JSONResponse( + status_code=417, + content={ + "message": "Failed to connect to Ollama. Ensure Ollama app is running." + }, + ) + except Exception as e: + return JSONResponse( + status_code=500, + content={"message": f"Failed to connect to Ollama: {e}"}, + ) + + if "models" in tags: + models = tags["models"] + if isinstance(models, list): + model_names = [model["model"] for model in models] + # TODO P2: check there's at least 1 model we support + if len(model_names) > 0: + return {"message": "Ollama connected", "models": model_names} -@app.post("/provider/ollama/connect") -def connect_ollama(): - # Tags is a list of Ollama models. Proves Ollama is running, and models are available. - try: - tags = requests.get("http://localhost:11434/api/tags").json() - except requests.exceptions.ConnectionError: return JSONResponse( status_code=417, - content={ - "message": "Failed to connect to Ollama. Ensure Ollama app is running." - }, - ) - except Exception as e: - return JSONResponse( - status_code=500, - content={"message": f"Failed to connect to Ollama: {e}"}, + content={"message": "Ollama not connected, or no Ollama models installed."}, ) - if "models" in tags: - models = tags["models"] - if isinstance(models, list): - model_names = [model["model"] for model in models] - # TODO P2: check there's at least 1 model we support - if len(model_names) > 0: - return {"message": "Ollama connected", "models": model_names} - - return JSONResponse( - status_code=417, - content={"message": "Ollama not connected, or no Ollama models installed."}, - ) - - -@app.post("/setting") -def update_settings(new_settings: dict[str, int | float | str | bool]): - settings = load_settings() - settings.update(new_settings) - with open(settings_path(), "w") as f: - f.write(yaml.dump(settings)) - return {"message": "Settings updated"} - + @app.post("/setting") + def update_settings(new_settings: dict[str, int | float | str | bool]): + settings = load_settings() + settings.update(new_settings) + with open(settings_path(), "w") as f: + f.write(yaml.dump(settings)) + return {"message": "Settings updated"} -@app.get("/settings") -def read_settings(): - settings = load_settings() - return settings + @app.get("/settings") + def read_settings(): + settings = load_settings() + return settings + @app.get("/items/{item_id}") + def read_item(item_id: int, q: Union[str, None] = None): + return {"item_id": item_id, "q": q} -@app.get("/items/{item_id}") -def read_item(item_id: int, q: Union[str, None] = None): - return {"item_id": item_id, "q": q} + # Web UI + # Ensure studio_path exists (test servers don't necessarily create it) + os.makedirs(studio_path(), exist_ok=True) + # Serves the web UI at root + app.mount("/", HTMLStaticFiles(directory=studio_path(), html=True), name="studio") -# Web UI -# File server that maps /foo/bar to /foo/bar.html (Starlette StaticFiles only does index.html) -class HTMLStaticFiles(StaticFiles): - async def get_response(self, path: str, scope): - try: - response = await super().get_response(path, scope) - return response - except Exception as e: - # catching HTTPException explicitly not working for some reason - if getattr(e, "status_code", None) == 404: - # Return the .html version of the file if the .html version exists - return await super().get_response(f"{path}.html", scope) - raise e + @app.exception_handler(404) + def not_found_exception_handler(request, exc): + return FileResponse(os.path.join(studio_path(), "404.html"), status_code=404) + return app -# Ensure studio_path exists (test servers don't necessarily create it) -os.makedirs(studio_path(), exist_ok=True) -# Serves the web UI at root -app.mount("/", HTMLStaticFiles(directory=studio_path(), html=True), name="studio") if __name__ == "__main__": + app = make_app() uvicorn.run(app, host="127.0.0.1", port=8757) diff --git a/libs/studio/kiln_studio/test_server.py b/libs/studio/kiln_studio/test_server.py index 2e196d3..a902d75 100644 --- a/libs/studio/kiln_studio/test_server.py +++ b/libs/studio/kiln_studio/test_server.py @@ -6,20 +6,32 @@ import requests from fastapi import HTTPException from fastapi.testclient import TestClient -from kiln_studio.server import HTMLStaticFiles, studio_path +from kiln_studio.server import HTMLStaticFiles -from libs.studio.kiln_studio.server import app +from libs.studio.kiln_studio.server import make_app -client = TestClient(app) +@pytest.fixture +def client(): + # a client based on a mock studio path + with tempfile.TemporaryDirectory() as temp_dir: + os.makedirs(temp_dir, exist_ok=True) + with patch("libs.studio.kiln_studio.server.studio_path", new=lambda: temp_dir): + from libs.studio.kiln_studio.server import studio_path + + assert studio_path() == temp_dir # Verify the patch is working + app = make_app() + client = TestClient(app) + yield client -def test_ping(): + +def test_ping(client): response = client.get("/ping") assert response.status_code == 200 assert response.json() == "pong" -def test_connect_ollama_success(): +def test_connect_ollama_success(client): with patch("requests.get") as mock_get: mock_get.return_value.json.return_value = { "models": [{"model": "model1"}, {"model": "model2"}] @@ -32,7 +44,7 @@ def test_connect_ollama_success(): } -def test_connect_ollama_connection_error(): +def test_connect_ollama_connection_error(client): with patch("requests.get") as mock_get: mock_get.side_effect = requests.exceptions.ConnectionError response = client.post("/provider/ollama/connect") @@ -42,7 +54,7 @@ def test_connect_ollama_connection_error(): } -def test_connect_ollama_general_exception(): +def test_connect_ollama_general_exception(client): with patch("requests.get") as mock_get: mock_get.side_effect = Exception("Test exception") response = client.post("/provider/ollama/connect") @@ -52,7 +64,7 @@ def test_connect_ollama_general_exception(): } -def test_connect_ollama_no_models(): +def test_connect_ollama_no_models(client): with patch("requests.get") as mock_get: mock_get.return_value.json.return_value = {"models": []} response = client.post("/provider/ollama/connect") @@ -73,7 +85,7 @@ def test_connect_ollama_no_models(): "https://127.0.0.1:8443", ], ) -def test_cors_allowed_origins(origin): +def test_cors_allowed_origins(client, origin): response = client.get("/ping", headers={"Origin": origin}) assert response.status_code == 200 assert response.headers["access-control-allow-origin"] == origin @@ -90,20 +102,15 @@ def test_cors_allowed_origins(origin): "http://127.0.0.2.com", ], ) -def test_cors_blocked_origins(origin): +def test_cors_blocked_origins(client, origin): response = client.get("/ping", headers={"Origin": origin}) assert response.status_code == 200 assert "access-control-allow-origin" not in response.headers -@pytest.fixture -def mock_studio_path(): - with tempfile.TemporaryDirectory() as temp_dir: - with patch("kiln_studio.server.studio_path", return_value=temp_dir): - yield temp_dir - - def create_studio_test_file(relative_path): + from libs.studio.kiln_studio.server import studio_path + full_path = os.path.join(studio_path(), relative_path) os.makedirs(os.path.dirname(full_path), exist_ok=True) with open(full_path, "w") as f: @@ -111,7 +118,7 @@ def create_studio_test_file(relative_path): return full_path -def test_cors_no_origin(mock_studio_path): +def test_cors_no_origin(client): # Create index.html in the mock studio path create_studio_test_file("index.html") @@ -171,25 +178,30 @@ async def test_get_response_not_found(self, html_static_files): with pytest.raises(HTTPException): await html_static_files.get_response("non_existing_file", {}) - @pytest.mark.asyncio - async def test_setup_route(self, mock_studio_path): - import os - # Ensure studio_path exists - os.makedirs(studio_path(), exist_ok=True) - create_studio_test_file("index.html") - create_studio_test_file("setup.html") - create_studio_test_file("setup/connect_providers/index.html") +@pytest.mark.asyncio +async def test_setup_route(client): + # Ensure studio_path exists + create_studio_test_file("index.html") + create_studio_test_file("path.html") + create_studio_test_file("nested/index.html") - # root index.html - response = client.get("/") - assert response.status_code == 200 - # setup.html - response = client.get("/setup") - assert response.status_code == 200 - # nested index.html - response = client.get("/setup/connect_providers") - assert response.status_code == 200 - # non existing file - response = client.get("/setup/non_existing_file") - assert response.status_code == 404 + # root index.html + response = client.get("/") + assert response.status_code == 200 + assert response.text == "Test" + # setup.html + response = client.get("/path") + assert response.status_code == 200 + assert response.text == "Test" + # nested index.html + response = client.get("/nested") + assert response.status_code == 200 + assert response.text == "Test" + # non existing file + + # expected 404 + with pytest.raises(Exception): + client.get("/non_existing_file") + with pytest.raises(Exception): + client.get("/nested/non_existing_file") diff --git a/libs/studio/tests/test_import.py b/libs/studio/tests/test_import.py deleted file mode 100644 index 9601dda..0000000 --- a/libs/studio/tests/test_import.py +++ /dev/null @@ -1,7 +0,0 @@ -import kiln_ai.coreadd as coreadd -import kiln_studio.server as server - - -def test_import() -> None: - assert server.studio_path() != "" - assert coreadd.add(1, 1) == 2