diff --git a/fps/app.py b/fps/app.py index 8f0f707..c286e92 100644 --- a/fps/app.py +++ b/fps/app.py @@ -282,6 +282,7 @@ def _load_middlewares(app: FastAPI) -> None: pkg_names = {get_pkg_name(p, strip_fps=False) for p in grouped_middlewares} logger.info(f"Loading middlewares from plugin package(s) {pkg_names}") + middleware_dict = {} for p, middlewares in grouped_middlewares.items(): p_name = Config.plugin_name(p) plugin_config = Config.from_name(p_name) @@ -294,7 +295,6 @@ def _load_middlewares(app: FastAPI) -> None: Config(FPSConfig).enabled_plugins and p_name not in Config(FPSConfig).enabled_plugins ) - or p_name not in Config(FPSConfig).middlewares ) if not middlewares or disabled: disabled_msg = " (disabled)" if disabled else "" @@ -303,15 +303,30 @@ def _load_middlewares(app: FastAPI) -> None: ) continue - for plugin_middleware, plugin_kwargs in middlewares: - app.add_middleware( - plugin_middleware, - **plugin_kwargs, - ) + middleware_dict.update( + { + f"{middleware.__module__}.{middleware.__qualname__}": middleware + for middleware in middlewares + } + ) - logger.info( - f"{len(middlewares)} middleware(s) added from plugin '{p_name}'" + middleware_cnt = 0 + for middleware in Config(FPSConfig).middlewares: + middleware_class_path = middleware.class_path + if middleware_class_path not in middleware_dict: + logger.warning(f"Unknown middleware {middleware_class_path}") + continue + + logger.info(f"Adding middleware {middleware_class_path}") + middleware_class = middleware_dict[middleware_class_path] + app.add_middleware( + middleware_class, + **middleware.kwargs, ) + middleware_cnt += 1 + + logger.info(f"{middleware_cnt} middleware(s) added") + else: logger.info("No plugin middleware to load") diff --git a/fps/config.py b/fps/config.py index e9745b7..7f3bd84 100644 --- a/fps/config.py +++ b/fps/config.py @@ -2,7 +2,7 @@ import os from collections import OrderedDict from types import ModuleType -from typing import Dict, List, Tuple, Type +from typing import Any, Dict, List, Tuple, Type import toml from pydantic import BaseModel, create_model, validator @@ -25,6 +25,11 @@ def create_default_plugin_model(plugin_name: str): return create_model(f"{plugin_name}Model", __base__=PluginModel) +class MiddlewareModel(BaseModel): + class_path: str + kwargs: Dict[str, Any] + + class FPSConfig(BaseModel): # fastapi title: str = "FPS" @@ -36,7 +41,7 @@ class FPSConfig(BaseModel): disabled_plugins: List[str] = [] # plugin middlewares - middlewares: List[str] = [] + middlewares: List[MiddlewareModel] = [] @validator("enabled_plugins", "disabled_plugins", "middlewares") def plugins_format(cls, plugins): diff --git a/fps/hooks.py b/fps/hooks.py index f57dafe..5a22f13 100644 --- a/fps/hooks.py +++ b/fps/hooks.py @@ -79,13 +79,13 @@ def plugin_name_callback() -> str: @pluggy.HookspecMarker(HookType.MIDDLEWARE.value) -def middleware() -> Tuple[type, Dict[str, Any]]: +def middleware() -> type: pass -def register_middleware(m: type, **kwargs: Dict[str, Any]): - def middleware_callback() -> Tuple[type, Dict[str, Any]]: - return m, kwargs +def register_middleware(m: type): + def middleware_callback() -> type: + return m return pluggy.HookimplMarker(HookType.MIDDLEWARE.value)( function=middleware_callback, specname="middleware"