diff --git a/fps/app.py b/fps/app.py index 8f0f707..cf203db 100644 --- a/fps/app.py +++ b/fps/app.py @@ -282,6 +282,7 @@ def _load_middlewares(app: FastAPI) -> None: pkg_names = {get_pkg_name(p, strip_fps=False) for p in grouped_middlewares} logger.info(f"Loading middlewares from plugin package(s) {pkg_names}") + middleware_dict = {} for p, middlewares in grouped_middlewares.items(): p_name = Config.plugin_name(p) plugin_config = Config.from_name(p_name) @@ -294,7 +295,6 @@ def _load_middlewares(app: FastAPI) -> None: Config(FPSConfig).enabled_plugins and p_name not in Config(FPSConfig).enabled_plugins ) - or p_name not in Config(FPSConfig).middlewares ) if not middlewares or disabled: disabled_msg = " (disabled)" if disabled else "" @@ -303,15 +303,36 @@ def _load_middlewares(app: FastAPI) -> None: ) continue - for plugin_middleware, plugin_kwargs in middlewares: - app.add_middleware( - plugin_middleware, - **plugin_kwargs, + logger.info(f"Registered middleware(s) for plugin '{p_name}':") + for middleware in middlewares: + logger.info( + f"Middleware: {middleware.__module__}.{middleware.__qualname__}" ) - logger.info( - f"{len(middlewares)} middleware(s) added from plugin '{p_name}'" + middleware_dict.update( + { + f"{middleware.__module__}.{middleware.__qualname__}": middleware + for middleware in middlewares + } ) + + middleware_cnt = 0 + for middleware in Config(FPSConfig).middlewares: + middleware_class_path = middleware.class_path + if middleware_class_path not in middleware_dict: + logger.warning(f"Unknown middleware {middleware_class_path}") + continue + + logger.info(f"Adding middleware {middleware_class_path}") + middleware_class = middleware_dict[middleware_class_path] + app.add_middleware( + middleware_class, + **middleware.kwargs, + ) + middleware_cnt += 1 + + logger.info(f"{middleware_cnt} middleware(s) added") + else: logger.info("No plugin middleware to load") diff --git a/fps/config.py b/fps/config.py index e9745b7..f89cba4 100644 --- a/fps/config.py +++ b/fps/config.py @@ -2,7 +2,7 @@ import os from collections import OrderedDict from types import ModuleType -from typing import Dict, List, Tuple, Type +from typing import Any, Dict, List, Tuple, Type import toml from pydantic import BaseModel, create_model, validator @@ -25,6 +25,11 @@ def create_default_plugin_model(plugin_name: str): return create_model(f"{plugin_name}Model", __base__=PluginModel) +class MiddlewareModel(BaseModel): + class_path: str + kwargs: Dict[str, Any] = {} + + class FPSConfig(BaseModel): # fastapi title: str = "FPS" @@ -36,9 +41,9 @@ class FPSConfig(BaseModel): disabled_plugins: List[str] = [] # plugin middlewares - middlewares: List[str] = [] + middlewares: List[MiddlewareModel] = [] - @validator("enabled_plugins", "disabled_plugins", "middlewares") + @validator("enabled_plugins", "disabled_plugins") def plugins_format(cls, plugins): warnings = [p for p in plugins if p.startswith("[") or p.endswith("]")] if warnings: diff --git a/fps/hooks.py b/fps/hooks.py index f57dafe..5a22f13 100644 --- a/fps/hooks.py +++ b/fps/hooks.py @@ -79,13 +79,13 @@ def plugin_name_callback() -> str: @pluggy.HookspecMarker(HookType.MIDDLEWARE.value) -def middleware() -> Tuple[type, Dict[str, Any]]: +def middleware() -> type: pass -def register_middleware(m: type, **kwargs: Dict[str, Any]): - def middleware_callback() -> Tuple[type, Dict[str, Any]]: - return m, kwargs +def register_middleware(m: type): + def middleware_callback() -> type: + return m return pluggy.HookimplMarker(HookType.MIDDLEWARE.value)( function=middleware_callback, specname="middleware" diff --git a/plugins/uvicorn/fps_uvicorn/cli.py b/plugins/uvicorn/fps_uvicorn/cli.py index 5c7a54a..fa09f75 100644 --- a/plugins/uvicorn/fps_uvicorn/cli.py +++ b/plugins/uvicorn/fps_uvicorn/cli.py @@ -2,6 +2,7 @@ import os import threading import webbrowser +from pathlib import Path from typing import Any, Dict, List import toml @@ -86,7 +87,7 @@ def store_extra_options(options: Dict[str, Any]): f_name = "fps_cli_args.toml" with open(f_name, "w") as f: toml.dump(opts, f) - os.environ["FPS_CLI_CONFIG_FILE"] = f_name + os.environ["FPS_CLI_CONFIG_FILE"] = str(Path(f_name).resolve()) @app.command(