diff --git a/robyn/__init__.py b/robyn/__init__.py index 929c0971..464739a6 100644 --- a/robyn/__init__.py +++ b/robyn/__init__.py @@ -43,7 +43,7 @@ def __init__( file_object: str, config: Config = Config(), openapi_file_path: Optional[str] = None, - openapi: OpenAPI = OpenAPI(), + openapi: Optional[OpenAPI] = None, dependencies: DependencyMap = DependencyMap(), ) -> None: directory_path = os.path.dirname(os.path.abspath(file_object)) @@ -53,10 +53,7 @@ def __init__( self.dependencies = dependencies self.openapi = openapi - if openapi_file_path: - openapi.override_openapi(Path(self.directory_path).joinpath(openapi_file_path)) - elif Path(self.directory_path).joinpath("openapi.json").exists(): - openapi.override_openapi(Path(self.directory_path).joinpath("openapi.json")) + self.init_openapi(openapi_file_path) if not bool(os.environ.get("ROBYN_CLI", False)): # the env variables are already set when are running through the cli @@ -82,6 +79,20 @@ def __init__( self.event_handlers: dict = {} self.exception_handler: Optional[Callable] = None self.authentication_handler: Optional[AuthenticationHandler] = None + self.included_routers: List[Router] = [] + + def init_openapi(self, openapi_file_path: Optional[str]) -> None: + if self.config.disable_openapi: + return + + if self.openapi is None: + self.openapi = OpenAPI() + + if openapi_file_path: + self.openapi.override_openapi(Path(self.directory_path).joinpath(openapi_file_path)) + elif Path(self.directory_path).joinpath("openapi.json").exists(): + self.openapi.override_openapi(Path(self.directory_path).joinpath("openapi.json")) + # TODO! what about when the elif fails? def _handle_dev_mode(self): cli_dev_mode = self.config.dev # --dev @@ -102,6 +113,8 @@ def add_route( handler: Callable, is_const: bool = False, auth_required: bool = False, + openapi_name: str = "", + openapi_tags: Union[List[str], None] = None, ): """ Connect a URI to a handler @@ -117,6 +130,8 @@ def add_route( """ injected_dependencies = self.dependencies.get_dependency_map(self) + list_openapi_tags: List[str] = openapi_tags if openapi_tags else [] + if auth_required: self.middleware_router.add_auth_middleware(endpoint)(handler) @@ -137,6 +152,9 @@ def add_route( endpoint=endpoint, handler=handler, is_const=is_const, + auth_required=auth_required, + openapi_name=openapi_name, + openapi_tags=list_openapi_tags, exception_handler=self.exception_handler, injected_dependencies=injected_dependencies, ) @@ -240,6 +258,15 @@ def is_port_in_use(self, port: int) -> bool: raise Exception(f"Invalid port number: {port}") def _add_openapi_routes(self, auth_required: bool = False): + if self.config.disable_openapi: + return + + if self.openapi is None: + logger.error("No openAPI") + return + + self.router.prepare_routes_openapi(self.openapi, self.included_routers) + self.add_route( route_type=HttpMethod.GET, endpoint="/openapi.json", @@ -369,9 +396,7 @@ def get( """ def inner(handler): - self.openapi.add_openapi_path_obj("get", endpoint, openapi_name, openapi_tags, handler) - - return self.add_route(HttpMethod.GET, endpoint, handler, const, auth_required) + return self.add_route(HttpMethod.GET, endpoint, handler, const, auth_required, openapi_name, openapi_tags) return inner @@ -392,9 +417,7 @@ def post( """ def inner(handler): - self.openapi.add_openapi_path_obj("post", endpoint, openapi_name, openapi_tags, handler) - - return self.add_route(HttpMethod.POST, endpoint, handler, auth_required=auth_required) + return self.add_route(HttpMethod.POST, endpoint, handler, auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags) return inner @@ -415,9 +438,7 @@ def put( """ def inner(handler): - self.openapi.add_openapi_path_obj("put", endpoint, openapi_name, openapi_tags, handler) - - return self.add_route(HttpMethod.PUT, endpoint, handler, auth_required=auth_required) + return self.add_route(HttpMethod.PUT, endpoint, handler, auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags) return inner @@ -438,9 +459,7 @@ def delete( """ def inner(handler): - self.openapi.add_openapi_path_obj("delete", endpoint, openapi_name, openapi_tags, handler) - - return self.add_route(HttpMethod.DELETE, endpoint, handler, auth_required=auth_required) + return self.add_route(HttpMethod.DELETE, endpoint, handler, auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags) return inner @@ -461,9 +480,7 @@ def patch( """ def inner(handler): - self.openapi.add_openapi_path_obj("patch", endpoint, openapi_name, openapi_tags, handler) - - return self.add_route(HttpMethod.PATCH, endpoint, handler, auth_required=auth_required) + return self.add_route(HttpMethod.PATCH, endpoint, handler, auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags) return inner @@ -484,9 +501,7 @@ def head( """ def inner(handler): - self.openapi.add_openapi_path_obj("head", endpoint, openapi_name, openapi_tags, handler) - - return self.add_route(HttpMethod.HEAD, endpoint, handler, auth_required=auth_required) + return self.add_route(HttpMethod.HEAD, endpoint, handler, auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags) return inner @@ -507,9 +522,7 @@ def options( """ def inner(handler): - self.openapi.add_openapi_path_obj("options", endpoint, openapi_name, openapi_tags, handler) - - return self.add_route(HttpMethod.OPTIONS, endpoint, handler, auth_required=auth_required) + return self.add_route(HttpMethod.OPTIONS, endpoint, handler, auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags) return inner @@ -530,8 +543,7 @@ def connect( """ def inner(handler): - self.openapi.add_openapi_path_obj("connect", endpoint, openapi_name, openapi_tags, handler) - return self.add_route(HttpMethod.CONNECT, endpoint, handler, auth_required=auth_required) + return self.add_route(HttpMethod.CONNECT, endpoint, handler, auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags) return inner @@ -552,9 +564,7 @@ def trace( """ def inner(handler): - self.openapi.add_openapi_path_obj("trace", endpoint, openapi_name, openapi_tags, handler) - - return self.add_route(HttpMethod.TRACE, endpoint, handler, auth_required=auth_required) + return self.add_route(HttpMethod.TRACE, endpoint, handler, auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags) return inner @@ -564,11 +574,14 @@ def include_router(self, router): :param router Robyn: the router object to include the routes from """ + self.included_routers.append(router) + self.router.routes.extend(router.router.routes) self.middleware_router.global_middlewares.extend(router.middleware_router.global_middlewares) self.middleware_router.route_middlewares.extend(router.middleware_router.route_middlewares) - self.openapi.add_subrouter_paths(router.openapi) + if not self.config.disable_openapi and self.openapi is not None: + self.openapi.add_subrouter_paths(self.openapi) # extend the websocket routes prefix = router.prefix diff --git a/robyn/processpool.py b/robyn/processpool.py index f1b0839f..7669ec85 100644 --- a/robyn/processpool.py +++ b/robyn/processpool.py @@ -182,7 +182,7 @@ def spawn_process( server.set_response_headers_exclude_paths(excluded_response_headers_paths) for route in routes: - route_type, endpoint, function, is_const = route + route_type, endpoint, function, is_const, auth_required, openapi_name, openapi_tags = route server.add_route(route_type, endpoint, function, is_const) for middleware_type, middleware_function in global_middlewares: diff --git a/robyn/router.py b/robyn/router.py index 655aa525..774237fe 100644 --- a/robyn/router.py +++ b/robyn/router.py @@ -11,6 +11,7 @@ from robyn.authentication import AuthenticationHandler, AuthenticationNotConfiguredError from robyn.dependency_injection import DependencyMap from robyn.jsonify import jsonify +from robyn.openapi import OpenAPI from robyn.responses import FileResponse from robyn.robyn import FunctionInfo, Headers, HttpMethod, Identity, MiddlewareType, QueryParams, Request, Response, Url from robyn.types import Body, Files, FormData, IPAddress, Method, PathParams @@ -19,11 +20,18 @@ _logger = logging.getLogger(__name__) +def lower_http_method(method: HttpMethod): + return (str(method))[11:].lower() + + class Route(NamedTuple): route_type: HttpMethod route: str function: FunctionInfo is_const: bool + auth_required: bool + openapi_name: str + openapi_tags: List[str] class RouteMiddleware(NamedTuple): @@ -108,6 +116,9 @@ def add_route( # type: ignore endpoint: str, handler: Callable, is_const: bool, + auth_required: bool, + openapi_name: str, + openapi_tags: List[str], exception_handler: Optional[Callable], injected_dependencies: dict, ) -> Union[Callable, CoroutineType]: @@ -226,7 +237,7 @@ def inner_handler(*args, **kwargs): params, new_injected_dependencies, ) - self.routes.append(Route(route_type, endpoint, function, is_const)) + self.routes.append(Route(route_type, endpoint, function, is_const, auth_required, openapi_name, openapi_tags)) return async_inner_handler else: function = FunctionInfo( @@ -236,9 +247,18 @@ def inner_handler(*args, **kwargs): params, new_injected_dependencies, ) - self.routes.append(Route(route_type, endpoint, function, is_const)) + self.routes.append(Route(route_type, endpoint, function, is_const, auth_required, openapi_name, openapi_tags)) return inner_handler + def prepare_routes_openapi(self, openapi: OpenAPI, included_routers: List) -> None: + for route in self.routes: + openapi.add_openapi_path_obj(lower_http_method(route.route_type), route.route, route.openapi_name, route.openapi_tags, route.function.handler) + + # TODO! after include_routes does not immediatelly merge all the routes + # for router in included_routers: + # for route in router: + # openapi.add_openapi_path_obj(lower_http_method(route.route_type), route.route, route.openapi_name, route.openapi_tags, route.function.handler) + def get_routes(self) -> List[Route]: return self.routes