From 9ed7475b0733df2895c79eeb864b1f10c6c2e7bd Mon Sep 17 00:00:00 2001 From: Federico Caselli Date: Tue, 27 Aug 2024 23:32:41 +0200 Subject: [PATCH 01/12] typing: type app --- docs/api/media.rst | 8 +- docs/api/middleware.rst | 67 ++++- docs/api/websocket.rst | 14 +- e2e-tests/server/app.py | 3 +- e2e-tests/server/ping.py | 2 +- falcon/app.py | 321 +++++++++++++-------- falcon/app_helpers.py | 118 ++++++-- falcon/asgi/_asgi_helpers.py | 16 +- falcon/asgi/app.py | 372 ++++++++++++++++++------- falcon/asgi/ws.py | 19 +- falcon/asgi_spec.py | 4 +- falcon/http_status.py | 25 +- falcon/inspect.py | 8 +- falcon/media/base.py | 6 +- falcon/media/handlers.py | 7 +- falcon/media/json.py | 2 +- falcon/media/multipart.py | 15 +- falcon/response.py | 8 +- falcon/routing/compiled.py | 16 +- falcon/routing/static.py | 53 +++- falcon/testing/test_case.py | 2 +- falcon/typing.py | 106 ++++++- falcon/util/__init__.py | 2 +- falcon/util/mediatypes.py | 12 +- falcon/util/misc.py | 4 +- pyproject.toml | 4 +- tests/asgi/test_hello_asgi.py | 4 +- tests/asgi/test_response_media_asgi.py | 4 +- tests/asgi/test_ws.py | 16 +- tests/test_error_handlers.py | 3 - tests/test_hello.py | 2 +- tests/test_httperror.py | 4 +- tests/test_media_multipart.py | 2 +- tests/test_request_media.py | 2 +- tests/test_response_media.py | 2 +- tests/test_utils.py | 2 +- 36 files changed, 887 insertions(+), 368 deletions(-) diff --git a/docs/api/media.rst b/docs/api/media.rst index e48d8ac73..dbff05383 100644 --- a/docs/api/media.rst +++ b/docs/api/media.rst @@ -115,16 +115,20 @@ middleware. Here is an example of how this can be done: .. code:: python + from falcon import Request, Response + class NegotiationMiddleware: - def process_request(self, req, resp): + def process_request(self, req: Request, resp: Response) -> None: resp.content_type = req.accept .. tab:: ASGI .. 
code:: python + from falcon.asgi import Request, Response + class NegotiationMiddleware: - async def process_request(self, req, resp): + async def process_request(self, req: Request, resp: Response) -> None: resp.content_type = req.accept diff --git a/docs/api/middleware.rst b/docs/api/middleware.rst index ea6db856f..3cfbb18cd 100644 --- a/docs/api/middleware.rst +++ b/docs/api/middleware.rst @@ -26,8 +26,11 @@ defined below. .. code:: python + from typing import Any + from falcon import Request, Response + class ExampleMiddleware: - def process_request(self, req, resp): + def process_request(self, req: Request, resp: Response) -> None: """Process the request before routing it. Note: @@ -42,7 +45,13 @@ defined below. the on_* responder. """ - def process_resource(self, req, resp, resource, params): + def process_resource( + self, + req: Request, + resp: Response, + resource: object, + params: dict[str, Any], + ) -> None: """Process the request after routing. Note: @@ -62,7 +71,13 @@ defined below. method as keyword arguments. """ - def process_response(self, req, resp, resource, req_succeeded): + def process_response( + self, + req: Request, + resp: Response, + resource: object, + req_succeeded: bool + ) -> None: """Post-processing of the response (after routing). Args: @@ -90,8 +105,13 @@ defined below. .. code:: python + from typing import Any + from falcon.asgi import Request, Response, WebSocket + class ExampleMiddleware: - async def process_startup(self, scope, event): + async def process_startup( + self, scope: dict[str, Any], event: dict[str, Any] + ) -> None: """Process the ASGI lifespan startup event. Invoked when the server is ready to start up and @@ -111,7 +131,9 @@ defined below. startup event. """ - async def process_shutdown(self, scope, event): + async def process_shutdown( + self, scope: dict[str, Any], event: dict[str, Any] + ) -> None: """Process the ASGI lifespan shutdown event. 
Invoked when the server has stopped accepting @@ -130,7 +152,7 @@ defined below. shutdown event. """ - async def process_request(self, req, resp): + async def process_request(self, req: Request, resp: Response) -> None: """Process the request before routing it. Note: @@ -145,7 +167,13 @@ defined below. the on_* responder. """ - async def process_resource(self, req, resp, resource, params): + async def process_resource( + self, + req: Request, + resp: Response, + resource: object, + params: dict[str, Any], + ) -> None: """Process the request after routing. Note: @@ -165,7 +193,13 @@ defined below. method as keyword arguments. """ - async def process_response(self, req, resp, resource, req_succeeded): + async def process_response( + self, + req: Request, + resp: Response, + resource: object, + req_succeeded: bool + ) -> None: """Post-processing of the response (after routing). Args: @@ -179,7 +213,7 @@ defined below. otherwise False. """ - async def process_request_ws(self, req, ws): + async def process_request_ws(self, req: Request, ws: WebSocket) -> None: """Process a WebSocket handshake request before routing it. Note: @@ -194,7 +228,13 @@ defined below. on_websocket() after routing. """ - async def process_resource_ws(self, req, ws, resource, params): + async def process_resource_ws( + self, + req: Request, + ws: WebSocket, + resource: object, + params: dict[str, Any], + ) -> None: """Process a WebSocket handshake request after routing. Note: @@ -226,15 +266,18 @@ the following example: .. code:: python + import falcon as wsgi + from falcon import asgi + class ExampleMiddleware: - def process_request(self, req, resp): + def process_request(self, req: wsgi.Request, resp: wsgi.Response) -> None: """Process WSGI request using synchronous logic. Note that req and resp are instances of falcon.Request and falcon.Response, respectively. 
""" - async def process_request_async(self, req, resp): + async def process_request_async(self, req: asgi.Request, resp: asgi.Response) -> None: """Process ASGI request using asynchronous logic. Note that req and resp are instances of falcon.asgi.Request and diff --git a/docs/api/websocket.rst b/docs/api/websocket.rst index 66b457722..38878f964 100644 --- a/docs/api/websocket.rst +++ b/docs/api/websocket.rst @@ -35,8 +35,12 @@ middleware objects configured for the app: .. code:: python + from typing import Any + from falcon.asgi import Request, WebSocket + + class SomeMiddleware: - async def process_request_ws(self, req, ws): + async def process_request_ws(self, req: Request, ws: WebSocket) -> None: """Process a WebSocket handshake request before routing it. Note: @@ -51,7 +55,13 @@ middleware objects configured for the app: on_websocket() after routing. """ - async def process_resource_ws(self, req, ws, resource, params): + async def process_resource_ws( + self, + req: Request, + ws: WebSocket, + resource: object, + params: dict[str, Any], + ) -> None: """Process a WebSocket handshake request after routing. Note: diff --git a/e2e-tests/server/app.py b/e2e-tests/server/app.py index 46f52a90c..be9558985 100644 --- a/e2e-tests/server/app.py +++ b/e2e-tests/server/app.py @@ -13,8 +13,7 @@ def create_app() -> falcon.asgi.App: - # TODO(vytas): Type to App's constructor. - app = falcon.asgi.App() # type: ignore + app = falcon.asgi.App() hub = Hub() app.add_route('/ping', Pong()) diff --git a/e2e-tests/server/ping.py b/e2e-tests/server/ping.py index 447db6658..cc7537464 100644 --- a/e2e-tests/server/ping.py +++ b/e2e-tests/server/ping.py @@ -10,4 +10,4 @@ async def on_get(self, req: Request, resp: Response) -> None: resp.content_type = falcon.MEDIA_TEXT resp.text = 'PONG\n' # TODO(vytas): Properly type Response.status. 
- resp.status = HTTPStatus.OK # type: ignore + resp.status = HTTPStatus.OK diff --git a/falcon/app.py b/falcon/app.py index a71c6530d..88ab0e554 100644 --- a/falcon/app.py +++ b/falcon/app.py @@ -19,7 +19,25 @@ import pathlib import re import traceback -from typing import Callable, Iterable, Optional, Tuple, Type, Union +from typing import ( + Any, + Callable, + cast, + ClassVar, + Dict, + FrozenSet, + IO, + Iterable, + List, + Literal, + Optional, + overload, + Pattern, + Tuple, + Type, + TypeVar, + Union, +) import warnings from falcon import app_helpers as helpers @@ -37,9 +55,18 @@ from falcon.response import Response from falcon.response import ResponseOptions import falcon.status_codes as status +from falcon.typing import AsgiResponderCallable +from falcon.typing import AsgiResponderWsCallable +from falcon.typing import AsgiSinkCallable from falcon.typing import ErrorHandler from falcon.typing import ErrorSerializer +from falcon.typing import FindMethod +from falcon.typing import ProcessResponseMethod +from falcon.typing import ResponderCallable +from falcon.typing import SinkCallable from falcon.typing import SinkPrefix +from falcon.typing import StartResponse +from falcon.typing import WSGIEnvironment from falcon.util import deprecation from falcon.util import misc from falcon.util.misc import code_to_http_status @@ -61,10 +88,11 @@ status.HTTP_304, ] ) +_BE = TypeVar('_BE', bound=BaseException) class App: - """The main entry point into a Falcon-based WSGI app. + '''The main entry point into a Falcon-based WSGI app. Each App instance provides a callable `WSGI `_ interface @@ -90,9 +118,9 @@ class App: to implement the methods for the events you would like to handle; Falcon simply skips over any missing middleware methods:: - class ExampleComponent: - def process_request(self, req, resp): - \"\"\"Process the request before routing it. 
+ class ExampleMiddleware: + def process_request(self, req: Request, resp: Response) -> None: + """Process the request before routing it. Note: Because Falcon routes each request based on @@ -105,10 +133,16 @@ def process_request(self, req, resp): routed to an on_* responder method. resp: Response object that will be routed to the on_* responder. - \"\"\" + """ - def process_resource(self, req, resp, resource, params): - \"\"\"Process the request and resource *after* routing. + def process_resource( + self, + req: Request, + resp: Response, + resource: object, + params: dict[str, Any], + ) -> None: + """Process the request and resource *after* routing. Note: This method is only called when the request matches @@ -127,10 +161,16 @@ def process_resource(self, req, resp, resource, params): template fields, that will be passed to the resource's responder method as keyword arguments. - \"\"\" + """ - def process_response(self, req, resp, resource, req_succeeded) - \"\"\"Post-processing of the response (after routing). + def process_response( + self, + req: Request, + resp: Response, + resource: object, + req_succeeded: bool + ) -> None: + """Post-processing of the response (after routing). Args: req: Request object. @@ -141,7 +181,7 @@ def process_response(self, req, resp, resource, req_succeeded) req_succeeded: True if no exceptions were raised while the framework processed and routed the request; otherwise False. - \"\"\" + """ (See also: :ref:`Middleware `) @@ -177,46 +217,33 @@ def process_response(self, req, resp, resource, req_succeeded) sink_before_static_route (bool): Indicates if the sinks should be processed before (when ``True``) or after (when ``False``) the static routes. This has an effect only if no route was matched. (default ``True``) + ''' - Attributes: - req_options: A set of behavioral options related to incoming - requests. (See also: :class:`~.RequestOptions`) - resp_options: A set of behavioral options related to outgoing - responses. 
(See also: :class:`~.ResponseOptions`) - router_options: Configuration options for the router. If a - custom router is in use, and it does not expose any - configurable options, referencing this attribute will raise - an instance of ``AttributeError``. - - (See also: :ref:`CompiledRouterOptions `) - """ - - _META_METHODS = frozenset(constants._META_METHODS) + _META_METHODS: ClassVar[FrozenSet[str]] = frozenset(constants._META_METHODS) - _STREAM_BLOCK_SIZE = 8 * 1024 # 8 KiB + _STREAM_BLOCK_SIZE: ClassVar[int] = 8 * 1024 # 8 KiB - _STATIC_ROUTE_TYPE = routing.StaticRoute + _STATIC_ROUTE_TYPE: ClassVar[Type[routing.StaticRoute]] = routing.StaticRoute # NOTE(kgriffs): This makes it easier to tell what we are dealing with # without having to import falcon.asgi. - _ASGI = False + _ASGI: ClassVar[bool] = False # NOTE(kgriffs): We do it like this rather than just implementing the # methods directly on the class, so that we keep all the default # responders colocated in the same module. This will make it more # likely that the implementations of the async and non-async versions # of the methods are kept in sync (pun intended). - _default_responder_bad_request = responders.bad_request - _default_responder_path_not_found = responders.path_not_found + _default_responder_bad_request: ClassVar[ResponderCallable] = responders.bad_request + _default_responder_path_not_found: ClassVar[ResponderCallable] = ( + responders.path_not_found + ) __slots__ = ( '_cors_enable', '_error_handlers', '_independent_middleware', '_middleware', - # NOTE(kgriffs): WebSocket is currently only supported for - # ASGI apps, but we may add support for WSGI at some point. 
- '_middleware_ws', '_request_type', '_response_type', '_router_search', @@ -231,20 +258,57 @@ def process_response(self, req, resp, resource, req_succeeded) 'resp_options', ) + _cors_enable: bool + _error_handlers: Dict[Type[BaseException], ErrorHandler] + _independent_middleware: bool + _middleware: helpers.PreparedMiddlewareResult + _request_type: Type[Request] + _response_type: Type[Response] + _router_search: FindMethod + # NOTE(caselit): this should actually be a protocol of the methods required + # by a router, hardcoded to CompiledRouter for convenience for now. + _router: routing.CompiledRouter + _serialize_error: ErrorSerializer + _sink_and_static_routes: Tuple[ + Tuple[ + Union[Pattern[str], routing.StaticRoute], + Union[SinkCallable, AsgiSinkCallable, routing.StaticRoute], + bool, + ], + ..., + ] + _sink_before_static_route: bool + _sinks: List[ + Tuple[Pattern[str], Union[SinkCallable, AsgiSinkCallable], Literal[True]] + ] + _static_routes: List[ + Tuple[routing.StaticRoute, routing.StaticRoute, Literal[False]] + ] + _unprepared_middleware: List[object] + + # Attributes req_options: RequestOptions + """A set of behavioral options related to incoming requests. + + See also: :class:`~.RequestOptions` + """ resp_options: ResponseOptions + """A set of behavioral options related to outgoing responses. 
+ + See also: :class:`~.ResponseOptions` + """ def __init__( self, - media_type=constants.DEFAULT_MEDIA_TYPE, - request_type=Request, - response_type=Response, - middleware=None, - router=None, - independent_middleware=True, - cors_enable=False, - sink_before_static_route=True, - ): + media_type: str = constants.DEFAULT_MEDIA_TYPE, + request_type: Type[Request] = Request, + response_type: Type[Response] = Response, + middleware: Union[object, Iterable[object]] = None, + router: Optional[routing.CompiledRouter] = None, + independent_middleware: bool = True, + cors_enable: bool = False, + sink_before_static_route: bool = True, + ) -> None: self._cors_enable = cors_enable self._sink_before_static_route = sink_before_static_route self._sinks = [] @@ -261,9 +325,8 @@ def __init__( # NOTE(kgriffs): Check to see if middleware is an # iterable, and if so, append the CORSMiddleware # instance. - iter(middleware) - middleware = list(middleware) - middleware.append(cm) + middleware = list(middleware) # type: ignore[arg-type] + middleware.append(cm) # type: ignore[arg-type] except TypeError: # NOTE(kgriffs): Assume the middleware kwarg references # a single middleware component. @@ -295,7 +358,7 @@ def __init__( self.add_error_handler(HTTPStatus, self._http_status_handler) def __call__( # noqa: C901 - self, env: dict, start_response: Callable + self, env: WSGIEnvironment, start_response: StartResponse ) -> Iterable[bytes]: """WSGI `app` method. 
@@ -314,10 +377,9 @@ def __call__( # noqa: C901 req = self._request_type(env, options=self.req_options) resp = self._response_type(options=self.resp_options) resource: Optional[object] = None - responder: Optional[Callable] = None - params: dict = {} + params: Dict[str, Any] = {} - dependent_mw_resp_stack: list = [] + dependent_mw_resp_stack: List[ProcessResponseMethod] = [] mw_req_stack, mw_rsrc_stack, mw_resp_stack = self._middleware req_succeeded = False @@ -334,15 +396,15 @@ def __call__( # noqa: C901 # response middleware after request middleware succeeds. if self._independent_middleware: for process_request in mw_req_stack: - process_request(req, resp) + process_request(req, resp) # type: ignore[operator] if resp.complete: break else: - for process_request, process_response in mw_req_stack: + for process_request, process_response in mw_req_stack: # type: ignore[assignment,misc] if process_request and not resp.complete: - process_request(req, resp) + process_request(req, resp) # type: ignore[operator] if process_response: - dependent_mw_resp_stack.insert(0, process_response) + dependent_mw_resp_stack.insert(0, process_response) # type: ignore[arg-type] if not resp.complete: # NOTE(warsaw): Moved this to inside the try except @@ -352,7 +414,8 @@ def __call__( # noqa: C901 # next-hop child resource. In that case, the object # being asked to dispatch to its child will raise an # HTTP exception signalling the problem, e.g. a 404. 
- responder, params, resource, req.uri_template = self._get_responder(req) + responder: ResponderCallable + responder, params, resource, req.uri_template = self._get_responder(req) # type: ignore[assignment] except Exception as ex: if not self._handle_exception(req, resp, ex, params): raise @@ -372,7 +435,7 @@ def __call__( # noqa: C901 break if not resp.complete: - responder(req, resp, **params) # type: ignore + responder(req, resp, **params) req_succeeded = True except Exception as ex: @@ -389,8 +452,8 @@ def __call__( # noqa: C901 req_succeeded = False - body = [] - length = 0 + body: Iterable[bytes] = [] + length: Optional[int] = 0 try: body, length = self._get_body(resp, env.get('wsgi.file_wrapper')) @@ -400,8 +463,8 @@ def __call__( # noqa: C901 req_succeeded = False - resp_status = code_to_http_status(resp.status) - default_media_type = self.resp_options.default_media_type + resp_status: str = code_to_http_status(resp.status) + default_media_type: Optional[str] = self.resp_options.default_media_type if req.method == 'HEAD' or resp_status in _BODILESS_STATUS_CODES: body = [] @@ -439,17 +502,27 @@ def __call__( # noqa: C901 if length is not None: resp._headers['content-length'] = str(length) - headers = resp._wsgi_headers(default_media_type) + headers: List[Tuple[str, str]] = resp._wsgi_headers(default_media_type) # Return the response per the WSGI spec. start_response(resp_status, headers) return body + # NOTE(caselit): the return type depends on the router, hardcoded to + # CompiledRouterOptions for convenience. @property - def router_options(self): + def router_options(self) -> routing.CompiledRouterOptions: + """Configuration options for the router. + + If a custom router is in use, and it does not expose any + configurable options, referencing this attribute will raise + an instance of ``AttributeError``. + + See also: :ref:`CompiledRouterOptions `. 
+ """ return self._router.options - def add_middleware(self, middleware: Union[object, Iterable]) -> None: + def add_middleware(self, middleware: Union[object, Iterable[object]]) -> None: """Add one or more additional middleware components. Arguments: @@ -463,7 +536,7 @@ def add_middleware(self, middleware: Union[object, Iterable]) -> None: # the chance that middleware may be None. if middleware: try: - middleware = list(middleware) # type: ignore + middleware = list(middleware) # type: ignore[call-overload] except TypeError: # middleware is not iterable; assume it is just one bare component middleware = [middleware] @@ -473,7 +546,7 @@ def add_middleware(self, middleware: Union[object, Iterable]) -> None: and len( [ mc - for mc in self._unprepared_middleware + middleware + for mc in self._unprepared_middleware + middleware # type: ignore[operator] if isinstance(mc, CORSMiddleware) ] ) @@ -484,7 +557,7 @@ def add_middleware(self, middleware: Union[object, Iterable]) -> None: 'cors_enable (which already constructs one instance)' ) - self._unprepared_middleware += middleware + self._unprepared_middleware += middleware # type: ignore[arg-type] # NOTE(kgriffs): Even if middleware is None or an empty list, we still # need to make sure self._middleware is initialized if this is the @@ -494,7 +567,7 @@ def add_middleware(self, middleware: Union[object, Iterable]) -> None: independent_middleware=self._independent_middleware, ) - def add_route(self, uri_template: str, resource: object, **kwargs): + def add_route(self, uri_template: str, resource: object, **kwargs: Any) -> None: """Associate a templatized URI path with a resource. Falcon routes incoming requests to resources based on a set of @@ -606,7 +679,7 @@ def add_static_route( directory: Union[str, pathlib.Path], downloadable: bool = False, fallback_filename: Optional[str] = None, - ): + ) -> None: """Add a route to a directory of static files. Static routes provide a way to serve files directly. 
This @@ -674,7 +747,7 @@ def add_static_route( self._static_routes.insert(0, (sr, sr, False)) self._update_sink_and_static_routes() - def add_sink(self, sink: Callable, prefix: SinkPrefix = r'/') -> None: + def add_sink(self, sink: SinkCallable, prefix: SinkPrefix = r'/') -> None: """Register a sink method for the App. If no route matches a request, but the path in the requested URI @@ -720,6 +793,8 @@ def add_sink(self, sink: Callable, prefix: SinkPrefix = r'/') -> None: if not hasattr(prefix, 'match'): # Assume it is a string prefix = re.compile(prefix) + else: + prefix = cast(Pattern[str], prefix) # NOTE(kgriffs): Insert at the head of the list such that # in the case of a duplicate prefix, the last one added @@ -727,11 +802,25 @@ def add_sink(self, sink: Callable, prefix: SinkPrefix = r'/') -> None: self._sinks.insert(0, (prefix, sink, True)) self._update_sink_and_static_routes() + @overload + def add_error_handler( + self, + exception: Type[_BE], + handler: Callable[[Request, Response, _BE, Dict[str, Any]], None], + ) -> None: ... + + @overload def add_error_handler( self, exception: Union[Type[BaseException], Iterable[Type[BaseException]]], handler: Optional[ErrorHandler] = None, - ): + ) -> None: ... + + def add_error_handler( # type: ignore[misc] + self, + exception: Union[Type[BaseException], Iterable[Type[BaseException]]], + handler: Optional[ErrorHandler] = None, + ) -> None: """Register a handler for one or more exception types. Error handlers may be registered for any exception type, including @@ -810,34 +899,22 @@ def handle(req, resp, ex, params): """ - def wrap_old_handler(old_handler): - # NOTE(kgriffs): This branch *is* actually tested by - # test_error_handlers.test_handler_signature_shim_asgi() (as - # verified manually via pdb), but for some reason coverage - # tracking isn't picking it up. 
- if iscoroutinefunction(old_handler): # pragma: no cover - - @wraps(old_handler) - async def handler_async(req, resp, ex, params): - await old_handler(ex, req, resp, params) - - return handler_async - + def wrap_old_handler(old_handler: Callable[..., Any]) -> ErrorHandler: @wraps(old_handler) - def handler(req, resp, ex, params): + def handler( + req: Request, resp: Response, ex: BaseException, params: Dict[str, Any] + ) -> None: old_handler(ex, req, resp, params) return handler if handler is None: - try: - handler = exception.handle # type: ignore - except AttributeError: + handler = getattr(exception, 'handle', None) + if handler is None: raise AttributeError( - 'handler must either be specified ' - 'explicitly or defined as a static' - 'method named "handle" that is a ' - 'member of the given exception class.' + 'handler must either be specified explicitly or defined as a ' + 'static method named "handle" that is a member of the given ' + 'exception class.' ) # TODO(vytas): Remove this shimming in a future Falcon version. @@ -857,11 +934,11 @@ def handler(req, resp, ex, params): ) handler = wrap_old_handler(handler) - exception_tuple: tuple + exception_tuple: Tuple[type[BaseException], ...] try: - exception_tuple = tuple(exception) # type: ignore + exception_tuple = tuple(exception) # type: ignore[arg-type] except TypeError: - exception_tuple = (exception,) + exception_tuple = (exception,) # type: ignore[assignment] for exc in exception_tuple: if not issubclass(exc, BaseException): @@ -869,7 +946,7 @@ def handler(req, resp, ex, params): self._error_handlers[exc] = handler - def set_error_serializer(self, serializer: ErrorSerializer): + def set_error_serializer(self, serializer: ErrorSerializer) -> None: """Override the default serializer for instances of :class:`~.HTTPError`. 
When a responder raises an instance of :class:`~.HTTPError`, @@ -892,7 +969,9 @@ def set_error_serializer(self, serializer: ErrorSerializer): such as `to_json()` and `to_dict()`, that can be used from within custom serializers. For example:: - def my_serializer(req, resp, exception): + def my_serializer( + req: Request, resp: Response, exception: HTTPError + ) -> None: representation = None preferred = req.client_prefers((falcon.MEDIA_YAML, falcon.MEDIA_JSON)) @@ -921,14 +1000,21 @@ def my_serializer(req, resp, exception): # Helpers that require self # ------------------------------------------------------------------------ - def _prepare_middleware(self, middleware=None, independent_middleware=False): + def _prepare_middleware( + self, middleware: List[object], independent_middleware: bool = False + ) -> helpers.PreparedMiddlewareResult: return helpers.prepare_middleware( middleware=middleware, independent_middleware=independent_middleware ) def _get_responder( self, req: Request - ) -> Tuple[Callable, dict, object, Optional[str]]: + ) -> Tuple[ + Union[ResponderCallable, AsgiResponderCallable, AsgiResponderWsCallable], + Dict[str, Any], + object, + Optional[str], + ]: """Search routes for a matching responder. Args: @@ -964,7 +1050,7 @@ def _get_responder( # NOTE(kgriffs): Older routers may not return the # template. But for performance reasons they should at # least return None if they don't support it. - resource, method_map, params = route + resource, method_map, params = route # type: ignore[misc] else: # NOTE(kgriffs): Older routers may indicate that no route # was found by returning (None, None, None). 
Therefore, we @@ -990,7 +1076,7 @@ def _get_responder( m = matcher.match(path) if m: if is_sink: - params = m.groupdict() + params = m.groupdict() # type: ignore[union-attr] responder = obj break @@ -1028,17 +1114,23 @@ def _compose_error_response( self._serialize_error(req, resp, error) - def _http_status_handler(self, req, resp, status, params): + def _http_status_handler( + self, req: Request, resp: Response, status: HTTPStatus, params: Dict[str, Any] + ) -> None: self._compose_status_response(req, resp, status) - def _http_error_handler(self, req, resp, error, params): + def _http_error_handler( + self, req: Request, resp: Response, error: HTTPError, params: Dict[str, Any] + ) -> None: self._compose_error_response(req, resp, error) - def _python_error_handler(self, req, resp, error, params): + def _python_error_handler( + self, req: Request, resp: Response, error: BaseException, params: Dict[str, Any] + ) -> None: req.log_error(traceback.format_exc()) self._compose_error_response(req, resp, HTTPInternalServerError()) - def _find_error_handler(self, ex): + def _find_error_handler(self, ex: BaseException) -> Optional[ErrorHandler]: # NOTE(csojinb): The `__mro__` class attribute returns the method # resolution order tuple, i.e. the complete linear inheritance chain # ``(type(ex), ..., object)``. For a valid exception class, the last @@ -1053,8 +1145,11 @@ def _find_error_handler(self, ex): if handler is not None: return handler + return None - def _handle_exception(self, req, resp, ex, params): + def _handle_exception( + self, req: Request, resp: Response, ex: BaseException, params: Dict[str, Any] + ) -> bool: """Handle an exception raised from mw or a responder. Args: @@ -1093,7 +1188,11 @@ def _handle_exception(self, req, resp, ex, params): # PERF(kgriffs): Moved from api_helpers since it is slightly faster # to call using self, and this function is called for most # requests. 
- def _get_body(self, resp, wsgi_file_wrapper=None): + def _get_body( + self, + resp: Response, + wsgi_file_wrapper: Optional[Callable[[IO[bytes], int], Iterable[bytes]]] = None, + ) -> Tuple[Iterable[bytes], Optional[int]]: """Convert resp content into an iterable as required by PEP 333. Args: @@ -1116,7 +1215,7 @@ def _get_body(self, resp, wsgi_file_wrapper=None): """ - data = resp.render_body() + data: Optional[bytes] = resp.render_body() if data is not None: return [data], len(data) @@ -1143,11 +1242,11 @@ def _get_body(self, resp, wsgi_file_wrapper=None): return [], 0 - def _update_sink_and_static_routes(self): + def _update_sink_and_static_routes(self) -> None: if self._sink_before_static_route: - self._sink_and_static_routes = tuple(self._sinks + self._static_routes) + self._sink_and_static_routes = tuple(self._sinks + self._static_routes) # type: ignore[operator] else: - self._sink_and_static_routes = tuple(self._static_routes + self._sinks) + self._sink_and_static_routes = tuple(self._static_routes + self._sinks) # type: ignore[operator] # TODO(myusko): This class is a compatibility alias, and should be removed @@ -1166,5 +1265,5 @@ class API(App): @deprecation.deprecated( 'API class may be removed in a future release, use falcon.App instead.' 
) - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) diff --git a/falcon/app_helpers.py b/falcon/app_helpers.py index db6d7cc24..bca38a3bc 100644 --- a/falcon/app_helpers.py +++ b/falcon/app_helpers.py @@ -17,7 +17,7 @@ from __future__ import annotations from inspect import iscoroutinefunction -from typing import IO, Iterable, List, Tuple +from typing import IO, Iterable, List, Literal, Optional, overload, Tuple, Union from falcon import util from falcon.constants import MEDIA_JSON @@ -26,6 +26,14 @@ from falcon.errors import HTTPError from falcon.request import Request from falcon.response import Response +from falcon.typing import AsgiProcessRequestMethod as APRequest +from falcon.typing import AsgiProcessRequestWsMethod +from falcon.typing import AsgiProcessResourceMethod as APResource +from falcon.typing import AsgiProcessResourceWsMethod +from falcon.typing import AsgiProcessResponseMethod as APResponse +from falcon.typing import ProcessRequestMethod as PRequest +from falcon.typing import ProcessResourceMethod as PResource +from falcon.typing import ProcessResponseMethod as PResponse from falcon.util.sync import _wrap_non_coroutine_unsafe __all__ = ( @@ -35,10 +43,46 @@ 'CloseableStreamIterator', ) +PreparedMiddlewareResult = Tuple[ + Union[ + Tuple[PRequest, ...], Tuple[Tuple[Optional[PRequest], Optional[PResource]], ...] + ], + Tuple[PResource, ...], + Tuple[PResponse, ...], +] +AsyncPreparedMiddlewareResult = Tuple[ + Union[ + Tuple[APRequest, ...], + Tuple[Tuple[Optional[APRequest], Optional[APResource]], ...], + ], + Tuple[APResource, ...], + Tuple[APResponse, ...], +] + + +@overload +def prepare_middleware( + middleware: Iterable, independent_middleware: bool = ..., asgi: Literal[False] = ... +) -> PreparedMiddlewareResult: ... 
+ + +@overload +def prepare_middleware( + middleware: Iterable, independent_middleware: bool = ..., *, asgi: Literal[True] +) -> AsyncPreparedMiddlewareResult: ... + +@overload def prepare_middleware( - middleware: Iterable, independent_middleware: bool = False, asgi: bool = False -) -> Tuple[tuple, tuple, tuple]: + middleware: Iterable, independent_middleware: bool = ..., asgi: bool = ... +) -> Union[PreparedMiddlewareResult, AsyncPreparedMiddlewareResult]: ... + + +def prepare_middleware( + middleware: Iterable[object], + independent_middleware: bool = False, + asgi: bool = False, +) -> Union[PreparedMiddlewareResult, AsyncPreparedMiddlewareResult]: """Check middleware interfaces and prepare the methods for request handling. Note: @@ -60,9 +104,14 @@ def prepare_middleware( # PERF(kgriffs): do getattr calls once, in advance, so we don't # have to do them every time in the request path. - request_mw: List = [] - resource_mw: List = [] - response_mw: List = [] + request_mw: Union[ + List[PRequest], + List[Tuple[Optional[PRequest], Optional[PResource]]], + List[APRequest], + List[Tuple[Optional[APRequest], Optional[APResource]]], + ] = [] + resource_mw: Union[List[APResource], List[PResource]] = [] + response_mw: Union[List[APResponse], List[PResponse]] = [] for component in middleware: # NOTE(kgriffs): Middleware that supports both WSGI and ASGI can @@ -70,22 +119,25 @@ def prepare_middleware( # to distinguish the two. Otherwise, the prefix is unnecessary. 
if asgi: - process_request = util.get_bound_method( - component, 'process_request_async' - ) or _wrap_non_coroutine_unsafe( - util.get_bound_method(component, 'process_request') + process_request: Union[Optional[APRequest], Optional[PRequest]] = ( + util.get_bound_method(component, 'process_request_async') + or _wrap_non_coroutine_unsafe( + util.get_bound_method(component, 'process_request') + ) ) - process_resource = util.get_bound_method( - component, 'process_resource_async' - ) or _wrap_non_coroutine_unsafe( - util.get_bound_method(component, 'process_resource') + process_resource: Union[Optional[APResource], Optional[PResource]] = ( + util.get_bound_method(component, 'process_resource_async') + or _wrap_non_coroutine_unsafe( + util.get_bound_method(component, 'process_resource') + ) ) - process_response = util.get_bound_method( - component, 'process_response_async' - ) or _wrap_non_coroutine_unsafe( - util.get_bound_method(component, 'process_response') + process_response: Union[Optional[APResponse], Optional[PResponse]] = ( + util.get_bound_method(component, 'process_response_async') + or _wrap_non_coroutine_unsafe( + util.get_bound_method(component, 'process_response') + ) ) for m in (process_request, process_resource, process_response): @@ -143,20 +195,27 @@ def prepare_middleware( # together or separately. 
if independent_middleware: if process_request: - request_mw.append(process_request) + request_mw.append(process_request) # type: ignore[arg-type] if process_response: - response_mw.insert(0, process_response) + response_mw.insert(0, process_response) # type: ignore[arg-type] else: if process_request or process_response: - request_mw.append((process_request, process_response)) + request_mw.append((process_request, process_response)) # type: ignore[arg-type] if process_resource: - resource_mw.append(process_resource) + resource_mw.append(process_resource) # type: ignore[arg-type] + + return tuple(request_mw), tuple(resource_mw), tuple(response_mw) # type: ignore[return-value] - return (tuple(request_mw), tuple(resource_mw), tuple(response_mw)) +AsyncPreparedMiddlewareWsResult = Tuple[ + Tuple[AsgiProcessRequestWsMethod, ...], Tuple[AsgiProcessResourceWsMethod, ...] +] -def prepare_middleware_ws(middleware: Iterable) -> Tuple[list, list]: + +def prepare_middleware_ws( + middleware: Iterable[object], +) -> AsyncPreparedMiddlewareWsResult: """Check middleware interfaces and prepare WebSocket methods for request handling. Note: @@ -174,8 +233,11 @@ def prepare_middleware_ws(middleware: Iterable) -> Tuple[list, list]: # PERF(kgriffs): do getattr calls once, in advance, so we don't # have to do them every time in the request path. 
- request_mw = [] - resource_mw = [] + request_mw: List[AsgiProcessRequestWsMethod] = [] + resource_mw: List[AsgiProcessResourceWsMethod] = [] + + process_request_ws: Optional[AsgiProcessRequestWsMethod] + process_resource_ws: Optional[AsgiProcessResourceWsMethod] for component in middleware: process_request_ws = util.get_bound_method(component, 'process_request_ws') @@ -201,7 +263,7 @@ def prepare_middleware_ws(middleware: Iterable) -> Tuple[list, list]: if process_resource_ws: resource_mw.append(process_resource_ws) - return request_mw, resource_mw + return tuple(request_mw), tuple(resource_mw) def default_serialize_error(req: Request, resp: Response, exception: HTTPError) -> None: @@ -283,7 +345,7 @@ class CloseableStreamIterator: block_size (int): Number of bytes to read per iteration. """ - def __init__(self, stream: IO, block_size: int) -> None: + def __init__(self, stream: IO[bytes], block_size: int) -> None: self._stream = stream self._block_size = block_size diff --git a/falcon/asgi/_asgi_helpers.py b/falcon/asgi/_asgi_helpers.py index ce298abf2..9bbd12e88 100644 --- a/falcon/asgi/_asgi_helpers.py +++ b/falcon/asgi/_asgi_helpers.py @@ -12,15 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + import functools import inspect +from typing import Any, Callable, Optional, TypeVar from falcon.errors import UnsupportedError from falcon.errors import UnsupportedScopeError @functools.lru_cache(maxsize=16) -def _validate_asgi_scope(scope_type, spec_version, http_version): +def _validate_asgi_scope( + scope_type: str, spec_version: Optional[str], http_version: str +) -> str: if scope_type == 'http': spec_version = spec_version or '2.0' if not spec_version.startswith('2.'): @@ -60,7 +65,10 @@ def _validate_asgi_scope(scope_type, spec_version, http_version): raise UnsupportedScopeError(f'The ASGI "{scope_type}" scope type is not supported.') -def _wrap_asgi_coroutine_func(asgi_impl): +_C = TypeVar('_C', bound=Callable[..., Any]) + + +def _wrap_asgi_coroutine_func(asgi_impl: _C) -> _C: """Wrap an ASGI application in another coroutine. This utility is used to wrap the cythonized ``App.__call__`` in order to @@ -84,10 +92,10 @@ def _wrap_asgi_coroutine_func(asgi_impl): # "self" parameter. # NOTE(vytas): Intentionally not using functools.wraps as it erroneously # inherits the cythonized method's traits. 
- async def __call__(self, scope, receive, send): + async def __call__(self: Any, scope: Any, receive: Any, send: Any) -> None: await asgi_impl(self, scope, receive, send) if inspect.iscoroutinefunction(asgi_impl): return asgi_impl - return __call__ + return __call__ # type: ignore[return-value] diff --git a/falcon/asgi/app.py b/falcon/asgi/app.py index 472ded751..4fe7b7db3 100644 --- a/falcon/asgi/app.py +++ b/falcon/asgi/app.py @@ -18,11 +18,32 @@ from inspect import isasyncgenfunction from inspect import iscoroutinefunction import traceback -from typing import Awaitable, Callable, Iterable, Optional, Type, Union - +from typing import ( + Any, + Awaitable, + Callable, + ClassVar, + Dict, + Iterable, + List, + Optional, + overload, + Tuple, + Type, + TYPE_CHECKING, + TypeVar, + Union, +) + +from falcon import constants +from falcon import responders +from falcon import routing import falcon.app +from falcon.app_helpers import AsyncPreparedMiddlewareResult +from falcon.app_helpers import AsyncPreparedMiddlewareWsResult from falcon.app_helpers import prepare_middleware from falcon.app_helpers import prepare_middleware_ws +from falcon.asgi_spec import AsgiSendMsg from falcon.asgi_spec import EventType from falcon.asgi_spec import WSCloseCode from falcon.constants import _UNSET @@ -33,9 +54,14 @@ from falcon.http_error import HTTPError from falcon.http_status import HTTPStatus from falcon.media.multipart import MultipartFormHandler -import falcon.routing -from falcon.typing import ErrorHandler +from falcon.typing import AsgiErrorHandler +from falcon.typing import AsgiReceive +from falcon.typing import AsgiResponderCallable +from falcon.typing import AsgiResponderWsCallable +from falcon.typing import AsgiSend +from falcon.typing import AsgiSinkCallable from falcon.typing import SinkPrefix +from falcon.util import get_argnames from falcon.util.misc import is_python_func from falcon.util.sync import _should_wrap_non_coroutines from falcon.util.sync import 
_wrap_non_coroutine_unsafe @@ -56,19 +82,20 @@ # TODO(vytas): Clean up these foul workarounds before the 4.0 release. -MultipartFormHandler._ASGI_MULTIPART_FORM = MultipartForm # type: ignore +MultipartFormHandler._ASGI_MULTIPART_FORM = MultipartForm -_EVT_RESP_EOF = {'type': EventType.HTTP_RESPONSE_BODY} +_EVT_RESP_EOF: AsgiSendMsg = {'type': EventType.HTTP_RESPONSE_BODY} _BODILESS_STATUS_CODES = frozenset([100, 101, 204, 304]) _TYPELESS_STATUS_CODES = frozenset([204, 304]) _FALLBACK_WS_ERROR_CODE = 3011 +_BE = TypeVar('_BE', bound=BaseException) class App(falcon.app.App): - """The main entry point into a Falcon-based ASGI app. + '''The main entry point into a Falcon-based ASGI app. Each App instance provides a callable `ASGI `_ interface @@ -113,9 +140,11 @@ class App(falcon.app.App): would like to handle; Falcon simply skips over any missing middleware methods:: - class ExampleComponent: - async def process_startup(self, scope, event): - \"\"\"Process the ASGI lifespan startup event. + class ExampleMiddleware: + async def process_startup( + self, scope: dict[str, Any], event: dict[str, Any] + ) -> None: + """Process the ASGI lifespan startup event. Invoked when the server is ready to start up and receive connections, but before it has started to @@ -132,10 +161,12 @@ async def process_startup(self, scope, event): for the duration of the event loop. event (dict): The ASGI event dictionary for the startup event. - \"\"\" + """ - async def process_shutdown(self, scope, event): - \"\"\"Process the ASGI lifespan shutdown event. + async def process_shutdown( + self, scope: dict[str, Any], event: dict[str, Any] + ) -> None: + """Process the ASGI lifespan shutdown event. Invoked when the server has stopped accepting connections and closed all active connections. @@ -151,10 +182,12 @@ async def process_shutdown(self, scope, event): for the duration of the event loop. event (dict): The ASGI event dictionary for the shutdown event. 
- \"\"\" + """ - async def process_request(self, req, resp): - \"\"\"Process the request before routing it. + async def process_request( + self, req: Request, resp: Response + ) -> None: + """Process the request before routing it. Note: Because Falcon routes each request based on @@ -167,10 +200,16 @@ async def process_request(self, req, resp): routed to an on_* responder method. resp: Response object that will be routed to the on_* responder. - \"\"\" + """ - async def process_resource(self, req, resp, resource, params): - \"\"\"Process the request and resource *after* routing. + async def process_resource( + self, + req: Request, + resp: Response, + resource: object, + params: dict[str, Any], + ) -> None: + """Process the request and resource *after* routing. Note: This method is only called when the request matches @@ -189,10 +228,16 @@ async def process_resource(self, req, resp, resource, params): template fields, that will be passed to the resource's responder method as keyword arguments. - \"\"\" + """ - async def process_response(self, req, resp, resource, req_succeeded) - \"\"\"Post-processing of the response (after routing). + async def process_response( + self, + req: Request, + resp: Response, + resource: object, + req_succeeded: bool + ) -> None: + """Post-processing of the response (after routing). Args: req: Request object. @@ -203,7 +248,51 @@ async def process_response(self, req, resp, resource, req_succeeded) req_succeeded: True if no exceptions were raised while the framework processed and routed the request; otherwise False. - \"\"\" + """ + + # WebSocket methods + async def process_request_ws( + self, req: Request, ws: WebSocket + ) -> None: + """Process a WebSocket handshake request before routing it. + + Note: + Because Falcon routes each request based on req.path, a + request can be effectively re-routed by setting that + attribute to a new value from within process_request(). 
+ + Args: + req: Request object that will eventually be + passed into an on_websocket() responder method. + ws: The WebSocket object that will be passed into + on_websocket() after routing. + """ + + async def process_resource_ws( + self, + req: Request, + ws: WebSocket, + resource: object, + params: dict[str, Any], + ) -> None: + """Process a WebSocket handshake request after routing. + + Note: + This method is only called when the request matches + a route to a resource. + + Args: + req: Request object that will be passed to the + routed responder. + ws: WebSocket object that will be passed to the + routed responder. + resource: Resource object to which the request was + routed. + params: A dict-like object representing any additional + params derived from the route's URI template fields, + that will be passed to the resource's responder + method as keyword arguments. + """ (See also: :ref:`Middleware `) @@ -239,39 +328,59 @@ async def process_response(self, req, resp, resource, req_succeeded) sink_before_static_route (bool): Indicates if the sinks should be processed before (when ``True``) or after (when ``False``) the static routes. This has an effect only if no route was matched. (default ``True``) + ''' - Attributes: - req_options: A set of behavioral options related to incoming - requests. (See also: :class:`~.RequestOptions`) - resp_options: A set of behavioral options related to outgoing - responses. (See also: :class:`~.ResponseOptions`) - ws_options: A set of behavioral options related to WebSocket - connections. (See also: :class:`~.WebSocketOptions`) - router_options: Configuration options for the router. If a - custom router is in use, and it does not expose any - configurable options, referencing this attribute will raise - an instance of ``AttributeError``. 
- - (See also: :ref:`CompiledRouterOptions `) - """ - - _STATIC_ROUTE_TYPE = falcon.routing.StaticRouteAsync + _STATIC_ROUTE_TYPE = routing.StaticRouteAsync # NOTE(kgriffs): This makes it easier to tell what we are dealing with # without having to import falcon.asgi. - _ASGI = True + _ASGI: ClassVar[bool] = True - _default_responder_bad_request = falcon.responders.bad_request_async - _default_responder_path_not_found = falcon.responders.path_not_found_async + _default_responder_bad_request: ClassVar[AsgiResponderCallable] = ( + responders.bad_request_async # type: ignore[assignment] + ) + _default_responder_path_not_found: ClassVar[AsgiResponderCallable] = ( + responders.path_not_found_async # type: ignore[assignment] + ) __slots__ = ( '_standard_response_type', + '_middleware_ws', 'ws_options', ) - def __init__(self, *args, request_type=Request, response_type=Response, **kwargs): + _error_handlers: Dict[Type[BaseException], AsgiErrorHandler] # type: ignore[assignment] + _middleware: AsyncPreparedMiddlewareResult # type: ignore[assignment] + _middleware_ws: AsyncPreparedMiddlewareWsResult + _request_type: Type[Request] + _response_type: Type[Response] + + ws_options: WebSocketOptions + """A set of behavioral options related to WebSocket connections. + + See also: :class:`~.WebSocketOptions`. 
+ """ + + def __init__( + self, + media_type: str = constants.DEFAULT_MEDIA_TYPE, + request_type: Type[Request] = Request, + response_type: Type[Response] = Response, + middleware: Union[object, Iterable[object]] = None, + router: Optional[routing.CompiledRouter] = None, + independent_middleware: bool = True, + cors_enable: bool = False, + sink_before_static_route: bool = True, + ) -> None: super().__init__( - *args, request_type=request_type, response_type=response_type, **kwargs + media_type, + request_type, + response_type, + middleware, + router, + independent_middleware, + cors_enable, + sink_before_static_route, ) self.ws_options = WebSocketOptions() @@ -282,31 +391,31 @@ def __init__(self, *args, request_type=Request, response_type=Response, **kwargs ) @_wrap_asgi_coroutine_func - async def __call__( # noqa: C901 + async def __call__( # type: ignore[override] # noqa: C901 self, - scope: dict, - receive: Callable[[], Awaitable[dict]], - send: Callable[[dict], Awaitable[None]], + scope: Dict[str, Any], + receive: AsgiReceive, + send: AsgiSend, ) -> None: # NOTE(kgriffs): The ASGI spec requires the 'type' key to be present. - scope_type = scope['type'] + scope_type: str = scope['type'] # PERF(kgriffs): This should usually be present, so use a # try..except try: - asgi_info = scope['asgi'] + asgi_info: Dict[str, str] = scope['asgi'] except KeyError: # NOTE(kgriffs): According to the ASGI spec, "2.0" is # the default version. 
asgi_info = scope['asgi'] = {'version': '2.0'} try: - spec_version = asgi_info['spec_version'] + spec_version: Optional[str] = asgi_info['spec_version'] except KeyError: spec_version = None try: - http_version = scope['http_version'] + http_version: str = scope['http_version'] except KeyError: http_version = '1.1' @@ -346,9 +455,8 @@ async def __call__( # noqa: C901 ) resp = self._response_type(options=self.resp_options) - resource = None - responder: Optional[Callable] = None - params: dict = {} + resource: Optional[object] = None + params: Dict[str, Any] = {} dependent_mw_resp_stack: list = [] mw_req_stack, mw_rsrc_stack, mw_resp_stack = self._middleware @@ -367,14 +475,14 @@ async def __call__( # noqa: C901 # response middleware after request middleware succeeds. if self._independent_middleware: for process_request in mw_req_stack: - await process_request(req, resp) + await process_request(req, resp) # type: ignore[operator] if resp.complete: break else: - for process_request, process_response in mw_req_stack: + for process_request, process_response in mw_req_stack: # type: ignore[misc, assignment] if process_request and not resp.complete: - await process_request(req, resp) + await process_request(req, resp) # type: ignore[operator] if process_response: dependent_mw_resp_stack.insert(0, process_response) @@ -387,7 +495,8 @@ async def __call__( # noqa: C901 # next-hop child resource. In that case, the object # being asked to dispatch to its child will raise an # HTTP exception signaling the problem, e.g. a 404. 
- responder, params, resource, req.uri_template = self._get_responder(req) + responder: AsgiResponderCallable + responder, params, resource, req.uri_template = self._get_responder(req) # type: ignore[assignment] except Exception as ex: if not await self._handle_exception(req, resp, ex, params): @@ -410,7 +519,7 @@ async def __call__( # noqa: C901 break if not resp.complete: - await responder(req, resp, **params) # type: ignore + await responder(req, resp, **params) req_succeeded = True @@ -429,7 +538,7 @@ async def __call__( # noqa: C901 req_succeeded = False - data = b'' + data: Optional[bytes] = b'' try: # NOTE(vytas): It is only safe to inline Response.render_body() @@ -480,8 +589,8 @@ async def __call__( # noqa: C901 req_succeeded = False - resp_status = resp.status_code - default_media_type = self.resp_options.default_media_type + resp_status: int = resp.status_code + default_media_type: Optional[str] = self.resp_options.default_media_type if req.method == 'HEAD' or resp_status in _BODILESS_STATUS_CODES: # @@ -546,7 +655,7 @@ async def __call__( # noqa: C901 # NOTE(kgriffs): This must be done in a separate task because # receive() can block for some time (until the connection is # actually closed). - async def watch_disconnect(): + async def watch_disconnect() -> None: while True: received_event = await receive() if received_event['type'] == EventType.HTTP_DISCONNECT: @@ -724,16 +833,16 @@ async def watch_disconnect(): if resp._registered_callbacks: self._schedule_callbacks(resp) - def add_route(self, uri_template: str, resource: object, **kwargs): + def add_route(self, uri_template: str, resource: object, **kwargs: Any) -> None: # NOTE(kgriffs): Inject an extra kwarg so that the compiled router # will know to validate the responder methods to make sure they # are async coroutines. 
kwargs['_asgi'] = True super().add_route(uri_template, resource, **kwargs) - add_route.__doc__ = falcon.app.App.add_route.__doc__ + add_route.__doc__ = falcon.app.App.add_route.__doc__ # NOTE: not really required - def add_sink(self, sink: Callable, prefix: SinkPrefix = r'/'): + def add_sink(self, sink: AsgiSinkCallable, prefix: SinkPrefix = r'/') -> None: # type: ignore[override] if not iscoroutinefunction(sink) and is_python_func(sink): if _should_wrap_non_coroutines(): sink = wrap_sync_to_async(sink) @@ -743,15 +852,29 @@ def add_sink(self, sink: Callable, prefix: SinkPrefix = r'/'): 'in order to be used safely with an ASGI app.' ) - super().add_sink(sink, prefix=prefix) + super().add_sink(sink, prefix=prefix) # type: ignore[arg-type] + + add_sink.__doc__ = falcon.app.App.add_sink.__doc__ # NOTE: not really required - add_sink.__doc__ = falcon.app.App.add_sink.__doc__ + @overload # type: ignore[override] + def add_error_handler( + self, + exception: Type[_BE], + handler: Callable[[Request, Response, _BE, Dict[str, Any]], Awaitable[None]], + ) -> None: ... + @overload def add_error_handler( self, exception: Union[Type[BaseException], Iterable[Type[BaseException]]], - handler: Optional[ErrorHandler] = None, - ): + handler: Optional[AsgiErrorHandler] = None, + ) -> None: ... + + def add_error_handler( # type: ignore[misc] + self, + exception: Union[Type[BaseException], Iterable[Type[BaseException]]], + handler: Optional[AsgiErrorHandler] = None, + ) -> None: """Register a handler for one or more exception types. Error handlers may be registered for any exception type, including @@ -818,7 +941,6 @@ def add_error_handler( type(s), the associated handler will be called. Either a single type or an iterable of types may be specified. 
- Keyword Args: handler (callable): A coroutine function taking the form:: @@ -851,14 +973,12 @@ async def handle(req, resp, ex, params): """ if handler is None: - try: - handler = exception.handle # type: ignore - except AttributeError: + handler = getattr(exception, 'handle', None) + if handler is None: raise AttributeError( - 'handler must either be specified ' - 'explicitly or defined as a static' - 'method named "handle" that is a ' - 'member of the given exception class.' + 'handler must either be specified explicitly or defined as a ' + 'static method named "handle" that is a member of the given ' + 'exception class.' ) # NOTE(vytas): Do not shoot ourselves in the foot in case error @@ -868,7 +988,7 @@ async def handle(req, resp, ex, params): self._http_error_handler, self._python_error_handler, ): - handler = _wrap_non_coroutine_unsafe(handler) + handler = _wrap_non_coroutine_unsafe(handler) # type: ignore[assignment] # NOTE(kgriffs): iscoroutinefunction() always returns False # for cythonized functions. @@ -881,24 +1001,25 @@ async def handle(req, resp, ex, params): 'The handler must be an awaitable coroutine function in order ' 'to be used safely with an ASGI app.' ) + handler_callable: AsgiErrorHandler = handler - exception_tuple: tuple + exception_tuple: Tuple[type[BaseException], ...] 
try: - exception_tuple = tuple(exception) # type: ignore + exception_tuple = tuple(exception) # type: ignore[arg-type] except TypeError: - exception_tuple = (exception,) + exception_tuple = (exception,) # type: ignore[assignment] for exc in exception_tuple: if not issubclass(exc, BaseException): raise TypeError('"exception" must be an exception type.') - self._error_handlers[exc] = handler + self._error_handlers[exc] = handler_callable # ------------------------------------------------------------------------ # Helper methods # ------------------------------------------------------------------------ - def _schedule_callbacks(self, resp): + def _schedule_callbacks(self, resp: Response) -> None: callbacks = resp._registered_callbacks # PERF(vytas): resp._registered_callbacks is already checked directly # to shave off a function call since this is a hot/critical code path. @@ -907,13 +1028,15 @@ def _schedule_callbacks(self, resp): loop = asyncio.get_running_loop() - for cb, is_async in callbacks: + for cb, is_async in callbacks: # type: ignore[attr-defined] if is_async: loop.create_task(cb()) else: loop.run_in_executor(None, cb) - async def _call_lifespan_handlers(self, ver, scope, receive, send): + async def _call_lifespan_handlers( + self, ver: str, scope: Dict[str, Any], receive: AsgiReceive, send: AsgiSend + ) -> None: while True: event = await receive() if event['type'] == 'lifespan.startup': @@ -921,7 +1044,7 @@ async def _call_lifespan_handlers(self, ver, scope, receive, send): # startup, as opposed to repeating them every request. # NOTE(vytas): If missing, 'asgi' is populated in __call__. 
- asgi_info = scope['asgi'] + asgi_info: Dict[str, str] = scope['asgi'] version = asgi_info.get('version', '2.0 (implicit)') if not version.startswith('3.'): await send( @@ -981,7 +1104,9 @@ async def _call_lifespan_handlers(self, ver, scope, receive, send): await send({'type': EventType.LIFESPAN_SHUTDOWN_COMPLETE}) return - async def _handle_websocket(self, ver, scope, receive, send): + async def _handle_websocket( + self, ver: str, scope: Dict[str, Any], receive: AsgiReceive, send: AsgiSend + ) -> None: first_event = await receive() if first_event['type'] != EventType.WS_CONNECT: # NOTE(kgriffs): The handshake was abandoned or this is a message @@ -1007,8 +1132,7 @@ async def _handle_websocket(self, ver, scope, receive, send): self.ws_options.default_close_reasons, ) - on_websocket = None - params = {} + params: Dict[str, Any] = {} request_mw, resource_mw = self._middleware_ws @@ -1016,7 +1140,8 @@ async def _handle_websocket(self, ver, scope, receive, send): for process_request_ws in request_mw: await process_request_ws(req, web_socket) - on_websocket, params, resource, req.uri_template = self._get_responder(req) + on_websocket: AsgiResponderWsCallable + on_websocket, params, resource, req.uri_template = self._get_responder(req) # type: ignore[assignment] # NOTE(kgriffs): If the request did not match any # route, a default responder is returned and the @@ -1035,7 +1160,9 @@ async def _handle_websocket(self, ver, scope, receive, send): if not await self._handle_exception(req, None, ex, params, ws=web_socket): raise - def _prepare_middleware(self, middleware=None, independent_middleware=False): + def _prepare_middleware( # type: ignore[override] + self, middleware: List[object], independent_middleware: bool = False + ) -> AsyncPreparedMiddlewareResult: self._middleware_ws = prepare_middleware_ws(middleware) return prepare_middleware( @@ -1044,11 +1171,18 @@ def _prepare_middleware(self, middleware=None, independent_middleware=False): asgi=True, ) - async def 
_http_status_handler(self, req, resp, status, params, ws=None): + async def _http_status_handler( # type: ignore[override] + self, + req: Request, + resp: Optional[Response], + status: HTTPStatus, + params: Dict[str, Any], + ws: Optional[WebSocket] = None, + ) -> None: if resp: self._compose_status_response(req, resp, status) elif ws: - code = http_status_to_ws_code(status.status) + code = http_status_to_ws_code(status.status_code) falcon._logger.error( '[FALCON] HTTPStatus %s raised while handling WebSocket. ' 'Closing with code %s', @@ -1059,7 +1193,14 @@ async def _http_status_handler(self, req, resp, status, params, ws=None): else: raise NotImplementedError('resp or ws expected') - async def _http_error_handler(self, req, resp, error, params, ws=None): + async def _http_error_handler( # type: ignore[override] + self, + req: Request, + resp: Optional[Response], + error: HTTPError, + params: Dict[str, Any], + ws: Optional[WebSocket] = None, + ) -> None: if resp: self._compose_error_response(req, resp, error) elif ws: @@ -1074,7 +1215,14 @@ async def _http_error_handler(self, req, resp, error, params, ws=None): else: raise NotImplementedError('resp or ws expected') - async def _python_error_handler(self, req, resp, error, params, ws=None): + async def _python_error_handler( # type: ignore[override] + self, + req: Request, + resp: Optional[Response], + error: BaseException, + params: Dict[str, Any], + ws: Optional[WebSocket] = None, + ) -> None: falcon._logger.error('[FALCON] Unhandled exception in ASGI app', exc_info=error) if resp: @@ -1084,13 +1232,35 @@ async def _python_error_handler(self, req, resp, error, params, ws=None): else: raise NotImplementedError('resp or ws expected') - async def _ws_disconnected_error_handler(self, req, resp, error, params, ws): + async def _ws_disconnected_error_handler( + self, + req: Request, + resp: Optional[Response], + error: WebSocketDisconnected, + params: Dict[str, Any], + ws: Optional[WebSocket] = None, + ) -> None: + 
assert resp is None + assert ws is not None falcon._logger.debug( '[FALCON] WebSocket client disconnected with code %i', error.code ) await self._ws_cleanup_on_error(ws) - async def _handle_exception(self, req, resp, ex, params, ws=None): + if TYPE_CHECKING: + + def _find_error_handler( # type: ignore[override] + self, ex: BaseException + ) -> Optional[AsgiErrorHandler]: ... + + async def _handle_exception( # type: ignore[override] + self, + req: Request, + resp: Optional[Response], + ex: BaseException, + params: Dict[str, Any], + ws: Optional[WebSocket] = None, + ) -> bool: """Handle an exception raised from mw or a responder. Args: @@ -1121,7 +1291,7 @@ async def _handle_exception(self, req, resp, ex, params, ws=None): try: kwargs = {} - if ws and 'ws' in falcon.util.get_argnames(err_handler): + if ws and 'ws' in get_argnames(err_handler): kwargs['ws'] = ws await err_handler(req, resp, ex, params, **kwargs) @@ -1139,7 +1309,7 @@ async def _handle_exception(self, req, resp, ex, params, ws=None): # handlers. 
return False - async def _ws_cleanup_on_error(self, ws): + async def _ws_cleanup_on_error(self, ws: WebSocket) -> None: # NOTE(kgriffs): Attempt to close cleanly on our end try: await ws.close(self.ws_options.error_close_code) diff --git a/falcon/asgi/ws.py b/falcon/asgi/ws.py index 0599a25f3..7971fb868 100644 --- a/falcon/asgi/ws.py +++ b/falcon/asgi/ws.py @@ -1,10 +1,10 @@ +from __future__ import annotations + import asyncio import collections from enum import Enum from typing import ( Any, - Awaitable, - Callable, Deque, Dict, Iterable, @@ -16,9 +16,12 @@ from falcon import errors from falcon import media from falcon import status_codes +from falcon.asgi_spec import AsgiEvent from falcon.asgi_spec import EventType from falcon.asgi_spec import WSCloseCode from falcon.constants import WebSocketPayloadType +from falcon.typing import AsgiReceive +from falcon.typing import AsgiSend from falcon.util import misc _WebSocketState = Enum('_WebSocketState', 'HANDSHAKE ACCEPTED CLOSED') @@ -65,15 +68,15 @@ class WebSocket: def __init__( self, ver: str, - scope: dict, - receive: Callable[[], Awaitable[dict]], - send: Callable[[dict], Awaitable], + scope: Dict[str, Any], + receive: AsgiReceive, + send: AsgiSend, media_handlers: Mapping[ WebSocketPayloadType, Union[media.BinaryBaseHandlerWS, media.TextBaseHandlerWS], ], max_receive_queue: int, - default_close_reasons: Dict[Optional[int], str], + default_close_reasons: Dict[int, str], ): self._supports_accept_headers = ver != '2.0' self._supports_reason = _supports_reason(ver) @@ -653,13 +656,13 @@ class _BufferedReceiver: 'client_disconnected_code', ] - def __init__(self, asgi_receive: Callable[[], Awaitable[dict]], max_queue: int): + def __init__(self, asgi_receive: AsgiReceive, max_queue: int): self._asgi_receive = asgi_receive self._max_queue = max_queue self._loop = asyncio.get_running_loop() - self._messages: Deque[dict] = collections.deque() + self._messages: Deque[AsgiEvent] = collections.deque() 
self._pop_message_waiter = None self._put_message_waiter = None diff --git a/falcon/asgi_spec.py b/falcon/asgi_spec.py index 9fe12dbb9..3aedf9eda 100644 --- a/falcon/asgi_spec.py +++ b/falcon/asgi_spec.py @@ -16,7 +16,7 @@ from __future__ import annotations -from typing import Any, Mapping +from typing import Any, Dict, Mapping class EventType: @@ -65,3 +65,5 @@ class WSCloseCode: # TODO: use a typed dict for event dicts AsgiEvent = Mapping[str, Any] +# TODO: use a typed dict for send msg dicts +AsgiSendMsg = Dict[str, Any] diff --git a/falcon/http_status.py b/falcon/http_status.py index df7e0d455..1a591fcf4 100644 --- a/falcon/http_status.py +++ b/falcon/http_status.py @@ -40,19 +40,21 @@ class HTTPStatus(Exception): headers (dict): Extra headers to add to the response. text (str): String representing response content. Falcon will encode this value as UTF-8 in the response. - - Attributes: - status (Union[str,int]): The HTTP status line or integer code for - the status that this exception represents. - status_code (int): HTTP status code normalized from :attr:`status`. - headers (dict): Extra headers to add to the response. - text (str): String representing response content. Falcon will encode - this value as UTF-8 in the response. - """ __slots__ = ('status', 'headers', 'text') + status: ResponseStatus + """The HTTP status line or integer code for the status that this exception + represents. + """ + headers: Optional[HeaderList] + """Extra headers to add to the response.""" + text: Optional[str] + """String representing response content. + Falcon will encode this value as UTF-8 in the response. 
+ """ + def __init__( self, status: ResponseStatus, @@ -65,10 +67,11 @@ def __init__( @property def status_code(self) -> int: + """HTTP status code normalized from :attr:`status`.""" return http_status_to_code(self.status) - @property # type: ignore - def body(self): + @property + def body(self) -> None: raise AttributeRemovedError( 'The body attribute is no longer supported. ' 'Please use the text attribute instead.' diff --git a/falcon/inspect.py b/falcon/inspect.py index 9aac44cb0..6d221f713 100644 --- a/falcon/inspect.py +++ b/falcon/inspect.py @@ -189,7 +189,7 @@ def inspect_middleware(app: App) -> 'MiddlewareInfo': current = [] for method in stack: _, name = _get_source_info_and_name(method) - cls = type(method.__self__) + cls = type(method.__self__) # type: ignore[union-attr] _, cls_name = _get_source_info_and_name(cls) current.append(MiddlewareTreeItemInfo(name, cls_name)) type_infos.append(current) @@ -201,12 +201,12 @@ def inspect_middleware(app: App) -> 'MiddlewareInfo': fns = app_helpers.prepare_middleware([m], True, app._ASGI) class_source_info, cls_name = _get_source_info_and_name(type(m)) methods = [] - for method, name in zip(fns, names): + for method, name in zip(fns, names): # type: ignore[assignment] if method: - real_func = method[0] + real_func = method[0] # type: ignore[index] source_info = _get_source_info(real_func) assert source_info - methods.append(MiddlewareMethodInfo(real_func.__name__, source_info)) + methods.append(MiddlewareMethodInfo(real_func.__name__, source_info)) # type: ignore[union-attr] assert class_source_info m_info = MiddlewareClassInfo(cls_name, class_source_info, methods) middlewareClasses.append(m_info) diff --git a/falcon/media/base.py b/falcon/media/base.py index 320b92bd0..70ceea776 100644 --- a/falcon/media/base.py +++ b/falcon/media/base.py @@ -6,7 +6,9 @@ from falcon.constants import MEDIA_JSON from falcon.typing import AsyncReadableIO +from falcon.typing import DeserializeSync from falcon.typing import ReadableIO 
+from falcon.typing import SerializeSync class BaseHandler(metaclass=abc.ABCMeta): @@ -19,10 +21,10 @@ class BaseHandler(metaclass=abc.ABCMeta): # might make it part of the public interface for use by custom # media type handlers. - _serialize_sync = None + _serialize_sync: Optional[SerializeSync] = None """Override to provide a synchronous serialization method that takes an object.""" - _deserialize_sync = None + _deserialize_sync: Optional[DeserializeSync] = None """Override to provide a synchronous deserialization method that takes a byte string.""" diff --git a/falcon/media/handlers.py b/falcon/media/handlers.py index 7b368202d..e37d5e3b8 100644 --- a/falcon/media/handlers.py +++ b/falcon/media/handlers.py @@ -4,7 +4,6 @@ import functools from typing import ( Any, - Callable, cast, Dict, Literal, @@ -29,6 +28,8 @@ from falcon.media.multipart import MultipartFormHandler from falcon.media.multipart import MultipartParseOptions from falcon.media.urlencoded import URLEncodedFormHandler +from falcon.typing import DeserializeSync +from falcon.typing import SerializeSync from falcon.util import deprecation from falcon.util import misc from falcon.vendor import mimeparse @@ -54,9 +55,7 @@ def _raise(self, *args: Any, **kwargs: Any) -> NoReturn: _ResolverMethodReturnTuple = Tuple[ - BaseHandler, - Optional[Callable[[Any, Optional[str]], bytes]], - Optional[Callable[[bytes], Any]], + BaseHandler, Optional[SerializeSync], Optional[DeserializeSync] ] diff --git a/falcon/media/json.py b/falcon/media/json.py index f3f2cee1d..cf0111e82 100644 --- a/falcon/media/json.py +++ b/falcon/media/json.py @@ -268,4 +268,4 @@ def deserialize(self, payload: str) -> object: return self._loads(payload) -http_error._DEFAULT_JSON_HANDLER = _DEFAULT_JSON_HANDLER = JSONHandler() # type: ignore +http_error._DEFAULT_JSON_HANDLER = _DEFAULT_JSON_HANDLER = JSONHandler() diff --git a/falcon/media/multipart.py b/falcon/media/multipart.py index 901cbe67e..4e08b5306 100644 --- 
a/falcon/media/multipart.py +++ b/falcon/media/multipart.py @@ -17,7 +17,7 @@ from __future__ import annotations import re -from typing import ClassVar, TYPE_CHECKING +from typing import Any, ClassVar, Dict, Optional, Tuple, Type, TYPE_CHECKING from urllib.parse import unquote_to_bytes from falcon import errors @@ -29,6 +29,7 @@ from falcon.util.mediatypes import parse_header if TYPE_CHECKING: + from falcon.asgi.multipart import MultipartForm as AsgiMultipartForm from falcon.media import Handlers # TODO(vytas): @@ -189,11 +190,11 @@ class BodyPart: decoded_text = await part.text """ - _content_disposition = None - _data = None - _filename = None - _media = None - _name = None + _content_disposition: Optional[Tuple[str, Dict[str, str]]] = None + _data: Optional[bytes] = None + _filename: Optional[str] = None + _media: Optional[Any] = None + _name: Optional[str] = None def __init__(self, stream, headers, parse_options): self.stream = stream @@ -488,7 +489,7 @@ class MultipartFormHandler(BaseHandler): See also: :ref:`multipart_parser_conf`. """ - _ASGI_MULTIPART_FORM = None + _ASGI_MULTIPART_FORM: ClassVar[Type[AsgiMultipartForm]] def __init__(self, parse_options=None): self.parse_options = parse_options or MultipartParseOptions() diff --git a/falcon/response.py b/falcon/response.py index 2a0a3e834..bc676c31a 100644 --- a/falcon/response.py +++ b/falcon/response.py @@ -19,7 +19,7 @@ from datetime import timezone import functools import mimetypes -from typing import Dict, Optional +from typing import Dict from falcon.constants import _DEFAULT_STATIC_MEDIA_TYPES from falcon.constants import _UNSET @@ -208,14 +208,14 @@ def status_code(self) -> int: def status_code(self, value): self.status = value - @property # type: ignore + @property def body(self): raise AttributeRemovedError( 'The body attribute is no longer supported. ' 'Please use the text attribute instead.' 
) - @body.setter # type: ignore + @body.setter def body(self, value): raise AttributeRemovedError( 'The body attribute is no longer supported. ' @@ -1233,7 +1233,7 @@ class ResponseOptions: This can make testing easier by not requiring HTTPS. Note, however, that this setting can be overridden via :meth:`~.Response.set_cookie()`'s ``secure`` kwarg. """ - default_media_type: Optional[str] + default_media_type: str """The default Internet media type (RFC 2046) to use when rendering a response, when the Content-Type header is not set explicitly. diff --git a/falcon/routing/compiled.py b/falcon/routing/compiled.py index 961af5a1a..443d0d4f3 100644 --- a/falcon/routing/compiled.py +++ b/falcon/routing/compiled.py @@ -38,6 +38,7 @@ from falcon.routing import converters from falcon.routing.util import map_http_methods from falcon.routing.util import set_default_responders +from falcon.typing import MethodDict from falcon.util.misc import is_python_func from falcon.util.sync import _should_wrap_non_coroutines from falcon.util.sync import wrap_sync_to_async @@ -46,7 +47,6 @@ from falcon import Request _CxElement = Union['_CxParent', '_CxChild'] - _MethodDict = Dict[str, Callable] _TAB_STR = ' ' * 4 _FIELD_PATTERN = re.compile( @@ -135,7 +135,7 @@ def finder_src(self) -> str: self.find('/') return self._finder_src - def map_http_methods(self, resource: object, **kwargs: Any) -> _MethodDict: + def map_http_methods(self, resource: object, **kwargs: Any) -> MethodDict: """Map HTTP methods (e.g., GET, POST) to methods of a resource object. This method is called from :meth:`~.add_route` and may be overridden to @@ -309,7 +309,7 @@ def insert(nodes: List[CompiledRouterNode], path_index: int = 0): # to multiple classes, since the symbol is imported only for type check. 
def find( self, uri: str, req: Optional['Request'] = None - ) -> Optional[Tuple[object, Optional[_MethodDict], Dict[str, Any], Optional[str]]]: + ) -> Optional[Tuple[object, MethodDict, Dict[str, Any], Optional[str]]]: """Search for a route that matches the given partial URI. Args: @@ -334,7 +334,7 @@ def find( ) if node is not None: - return node.resource, node.method_map, params, node.uri_template + return node.resource, node.method_map or {}, params, node.uri_template else: return None @@ -342,7 +342,7 @@ def find( # Private # ----------------------------------------------------------------- - def _require_coroutine_responders(self, method_map: _MethodDict) -> None: + def _require_coroutine_responders(self, method_map: MethodDict) -> None: for method, responder in method_map.items(): # NOTE(kgriffs): We don't simply wrap non-async functions # since they likely perform relatively long blocking @@ -366,7 +366,7 @@ def let(responder=responder): msg = msg.format(responder) raise TypeError(msg) - def _require_non_coroutine_responders(self, method_map: _MethodDict) -> None: + def _require_non_coroutine_responders(self, method_map: MethodDict) -> None: for method, responder in method_map.items(): # NOTE(kgriffs): We don't simply wrap non-async functions # since they likely perform relatively long blocking @@ -682,7 +682,7 @@ def _compile(self) -> Callable: self._finder_src = '\n'.join(src_lines) - scope: _MethodDict = {} + scope: MethodDict = {} exec(compile(self._finder_src, '', 'exec'), scope) return scope['find'] @@ -742,7 +742,7 @@ class CompiledRouterNode: def __init__( self, raw_segment: str, - method_map: Optional[_MethodDict] = None, + method_map: Optional[MethodDict] = None, resource: Optional[object] = None, uri_template: Optional[str] = None, ): diff --git a/falcon/routing/static.py b/falcon/routing/static.py index 93a076a52..d07af4211 100644 --- a/falcon/routing/static.py +++ b/falcon/routing/static.py @@ -1,13 +1,25 @@ +from __future__ import annotations + 
import asyncio from functools import partial import io import os +from pathlib import Path import re +from typing import Any, ClassVar, IO, Optional, Pattern, Tuple, TYPE_CHECKING, Union import falcon +if TYPE_CHECKING: + from falcon import asgi + from falcon import Request + from falcon import Response +from falcon.typing import ReadableIO + -def _open_range(file_path, req_range): +def _open_range( + file_path: Union[str, Path], req_range: Optional[Tuple[int, int]] +) -> Tuple[ReadableIO, int, Optional[Tuple[int, int, int]]]: """Open a file for a ranged request. Args: @@ -68,14 +80,14 @@ class _BoundedFile: length (int): Number of bytes that may be read. """ - def __init__(self, fh, length): + def __init__(self, fh: IO[bytes], length: int) -> None: self.fh = fh self.close = fh.close self.remaining = length - def read(self, size=-1): + def read(self, size: Optional[int] = -1) -> bytes: """Read the underlying file object, within the specified bounds.""" - if size < 0: + if size is None or size < 0: size = self.remaining else: size = min(size, self.remaining) @@ -116,16 +128,27 @@ class StaticRoute: """ # NOTE(kgriffs): Don't allow control characters and reserved chars - _DISALLOWED_CHARS_PATTERN = re.compile('[\x00-\x1f\x80-\x9f\ufffd~?<>:*|\'"]') + _DISALLOWED_CHARS_PATTERN: ClassVar[Pattern[str]] = re.compile( + '[\x00-\x1f\x80-\x9f\ufffd~?<>:*|\'"]' + ) # NOTE(vytas): Match the behavior of the underlying os.path.normpath. - _DISALLOWED_NORMALIZED_PREFIXES = ('..' + os.path.sep, os.path.sep) + _DISALLOWED_NORMALIZED_PREFIXES: ClassVar[Tuple[str, ...]] = ( + '..' + os.path.sep, + os.path.sep, + ) # NOTE(kgriffs): If somehow an executable code exploit is triggerable, this # minimizes how much can be included in the payload. 
- _MAX_NON_PREFIXED_LEN = 512 - - def __init__(self, prefix, directory, downloadable=False, fallback_filename=None): + _MAX_NON_PREFIXED_LEN: ClassVar[int] = 512 + + def __init__( + self, + prefix: str, + directory: Union[str, Path], + downloadable: bool = False, + fallback_filename: Optional[str] = None, + ) -> None: if not prefix.startswith('/'): raise ValueError("prefix must start with '/'") @@ -151,15 +174,15 @@ def __init__(self, prefix, directory, downloadable=False, fallback_filename=None self._prefix = prefix self._downloadable = downloadable - def match(self, path): + def match(self, path: str) -> bool: """Check whether the given path matches this route.""" if self._fallback_filename is None: return path.startswith(self._prefix) return path.startswith(self._prefix) or path == self._prefix[:-1] - def __call__(self, req, resp): + def __call__(self, req: Request, resp: Response, **kw: Any) -> None: """Resource responder for this route.""" - + assert not kw without_prefix = req.path[len(self._prefix) :] # NOTE(kgriffs): Check surrounding whitespace and strip trailing @@ -222,8 +245,8 @@ def __call__(self, req, resp): class StaticRouteAsync(StaticRoute): """Subclass of StaticRoute with modifications to support ASGI apps.""" - async def __call__(self, req, resp): - super().__call__(req, resp) + async def __call__(self, req: asgi.Request, resp: asgi.Response, **kw: Any) -> None: # type: ignore[override] + super().__call__(req, resp, **kw) # NOTE(kgriffs): Fixup resp.stream so that it is non-blocking resp.stream = _AsyncFileReader(resp.stream) @@ -232,7 +255,7 @@ async def __call__(self, req, resp): class _AsyncFileReader: """Adapts a standard file I/O object so that reads are non-blocking.""" - def __init__(self, file): + def __init__(self, file: IO[bytes]) -> None: self._file = file self._loop = asyncio.get_running_loop() diff --git a/falcon/testing/test_case.py b/falcon/testing/test_case.py index 1b07b97f4..1cb95328c 100644 --- a/falcon/testing/test_case.py +++ 
b/falcon/testing/test_case.py @@ -21,7 +21,7 @@ try: import testtools as unittest except ImportError: # pragma: nocover - import unittest # type: ignore + import unittest import falcon import falcon.request diff --git a/falcon/typing.py b/falcon/typing.py index 4ce772602..817bf2f8d 100644 --- a/falcon/typing.py +++ b/falcon/typing.py @@ -34,9 +34,21 @@ Union, ) +try: + from wsgiref.types import StartResponse as StartResponse + from wsgiref.types import WSGIEnvironment as WSGIEnvironment +except ImportError: + if not TYPE_CHECKING: + WSGIEnvironment = Dict[str, Any] + StartResponse = Callable[[str, List[Tuple[str, str]]], Callable[[bytes], None]] + if TYPE_CHECKING: - from falcon import asgi + from falcon.asgi import Request as AsgiRequest + from falcon.asgi import Response as AsgiResponse + from falcon.asgi import WebSocket from falcon.asgi_spec import AsgiEvent + from falcon.asgi_spec import AsgiSendMsg + from falcon.http_error import HTTPError from falcon.request import Request from falcon.response import Response @@ -52,10 +64,23 @@ class _Missing(Enum): Link = Dict[str, str] # Error handlers -ErrorHandler = Callable[['Request', 'Response', BaseException, dict], Any] +ErrorHandler = Callable[['Request', 'Response', BaseException, Dict[str, Any]], None] + + +class AsgiErrorHandler(Protocol): + async def __call__( + self, + req: AsgiRequest, + resp: Optional[AsgiResponse], + error: BaseException, + params: Dict[str, Any], + *, + ws: Optional[WebSocket] = ..., + ) -> None: ... + # Error serializers -ErrorSerializer = Callable[['Request', 'Response', BaseException], Any] +ErrorSerializer = Callable[['Request', 'Response', 'HTTPError'], None] JSONSerializable = Union[ Dict[str, 'JSONSerializable'], @@ -69,7 +94,18 @@ class _Missing(Enum): ] # Sinks -SinkPrefix = Union[str, Pattern] +SinkPrefix = Union[str, Pattern[str]] + + +class SinkCallable(Protocol): + def __call__(self, req: Request, resp: Response, **kwargs: str) -> None: ... 
+ + +class AsgiSinkCallable(Protocol): + async def __call__( + self, req: AsgiRequest, resp: AsgiResponse, **kwargs: str + ) -> None: ... + # TODO(vytas): Is it possible to specify a Callable or a Protocol that defines # type hints for the two first parameters, but accepts any number of keyword @@ -93,10 +129,22 @@ def __call__( ) -> None: ... +# WSGI class ReadableIO(Protocol): def read(self, n: Optional[int] = ..., /) -> bytes: ... +ProcessRequestMethod = Callable[['Request', 'Response'], None] +ProcessResourceMethod = Callable[ + ['Request', 'Response', Resource, Dict[str, Any]], None +] +ProcessResponseMethod = Callable[['Request', 'Response', Resource, bool], None] + + +class ResponderCallable(Protocol): + def __call__(self, req: Request, resp: Response, **kwargs: Any) -> None: ... + + # ASGI class AsyncReadableIO(Protocol): async def read(self, n: Optional[int] = ..., /) -> bytes: ... @@ -106,12 +154,58 @@ class AsgiResponderMethod(Protocol): async def __call__( self, resource: Resource, - req: asgi.Request, - resp: asgi.Response, + req: AsgiRequest, + resp: AsgiResponse, **kwargs: Any, ) -> None: ... AsgiReceive = Callable[[], Awaitable['AsgiEvent']] +AsgiSend = Callable[['AsgiSendMsg'], Awaitable[None]] +AsgiProcessRequestMethod = Callable[['AsgiRequest', 'AsgiResponse'], Awaitable[None]] +AsgiProcessResourceMethod = Callable[ + ['AsgiRequest', 'AsgiResponse', Resource, Dict[str, Any]], Awaitable[None] +] +AsgiProcessResponseMethod = Callable[ + ['AsgiRequest', 'AsgiResponse', Resource, bool], Awaitable[None] +] +AsgiProcessRequestWsMethod = Callable[['AsgiRequest', 'WebSocket'], Awaitable[None]] +AsgiProcessResourceWsMethod = Callable[ + ['AsgiRequest', 'WebSocket', Resource, Dict[str, Any]], Awaitable[None] +] + + +class AsgiResponderCallable(Protocol): + async def __call__( + self, req: AsgiRequest, resp: AsgiResponse, **kwargs: Any + ) -> None: ... 
+ + +class AsgiResponderWsCallable(Protocol): + async def __call__( + self, req: AsgiRequest, ws: WebSocket, **kwargs: Any + ) -> None: ... + + +# Routing + +MethodDict = Union[ + Dict[str, ResponderCallable], + Dict[str, Union[AsgiResponderCallable, AsgiResponderWsCallable]], +] + + +class FindMethod(Protocol): + def __call__( + self, uri: str, req: Optional[Request] + ) -> Optional[Tuple[object, MethodDict, Dict[str, Any], Optional[str]]]: ... + + +# Media +class SerializeSync(Protocol): + def __call__(self, media: Any, content_type: Optional[str] = ...) -> bytes: ... + + +DeserializeSync = Callable[[bytes], Any] Responder = Union[ResponderMethod, AsgiResponderMethod] diff --git a/falcon/util/__init__.py b/falcon/util/__init__.py index dfb239ce1..03d810e9a 100644 --- a/falcon/util/__init__.py +++ b/falcon/util/__init__.py @@ -59,7 +59,7 @@ # subclass of Morsel. _reserved_cookie_attrs = http_cookies.Morsel._reserved # type: ignore if 'samesite' not in _reserved_cookie_attrs: # pragma: no cover - _reserved_cookie_attrs['samesite'] = 'SameSite' # type: ignore + _reserved_cookie_attrs['samesite'] = 'SameSite' # NOTE(m-mueller): Same for the 'partitioned' attribute that will # probably be added in Python 3.13. 
if 'partitioned' not in _reserved_cookie_attrs: # pragma: no cover diff --git a/falcon/util/mediatypes.py b/falcon/util/mediatypes.py index c7812bbeb..eebed0446 100644 --- a/falcon/util/mediatypes.py +++ b/falcon/util/mediatypes.py @@ -14,12 +14,14 @@ """Media (aka MIME) type parsing and matching utilities.""" -import typing +from __future__ import annotations + +from typing import Dict, Iterator, Tuple __all__ = ('parse_header',) -def _parse_param_old_stdlib(s): # type: ignore +def _parse_param_old_stdlib(s: str) -> Iterator[str]: while s[:1] == ';': s = s[1:] end = s.find(';') @@ -32,7 +34,7 @@ def _parse_param_old_stdlib(s): # type: ignore s = s[end:] -def _parse_header_old_stdlib(line): # type: ignore +def _parse_header_old_stdlib(line: str) -> Tuple[str, Dict[str, str]]: """Parse a Content-type like header. Return the main content-type and a dictionary of options. @@ -43,7 +45,7 @@ def _parse_header_old_stdlib(line): # type: ignore """ parts = _parse_param_old_stdlib(';' + line) key = parts.__next__() - pdict = {} + pdict: Dict[str, str] = {} for p in parts: i = p.find('=') if i >= 0: @@ -56,7 +58,7 @@ def _parse_header_old_stdlib(line): # type: ignore return key, pdict -def parse_header(line: str) -> typing.Tuple[str, dict]: +def parse_header(line: str) -> Tuple[str, Dict[str, str]]: """Parse a Content-type like header. Return the main content-type and a dictionary of options. 
diff --git a/falcon/util/misc.py b/falcon/util/misc.py index 05361f0a8..18a27b95e 100644 --- a/falcon/util/misc.py +++ b/falcon/util/misc.py @@ -101,7 +101,7 @@ def decorator(func: Callable) -> Callable: if PYPY: _lru_cache_for_simple_logic = _lru_cache_nop # pragma: nocover else: - _lru_cache_for_simple_logic = functools.lru_cache # type: ignore + _lru_cache_for_simple_logic = functools.lru_cache def is_python_func(func: Union[Callable, Any]) -> bool: @@ -300,7 +300,7 @@ def get_bound_method(obj: object, method_name: str) -> Union[None, Callable[..., return method -def get_argnames(func: Callable) -> List[str]: +def get_argnames(func: Callable[..., Any]) -> List[str]: """Introspect the arguments of a callable. Args: diff --git a/pyproject.toml b/pyproject.toml index 5d1cf5841..73e74558d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ "falcon/vendor", ] disallow_untyped_defs = true + warn_unused_ignores = true [[tool.mypy.overrides]] module = [ @@ -41,9 +42,6 @@ [[tool.mypy.overrides]] module = [ - "falcon.app", - "falcon.asgi._asgi_helpers", - "falcon.asgi.app", "falcon.asgi.multipart", "falcon.asgi.reader", "falcon.asgi.response", diff --git a/tests/asgi/test_hello_asgi.py b/tests/asgi/test_hello_asgi.py index dc1527a34..2ac20d959 100644 --- a/tests/asgi/test_hello_asgi.py +++ b/tests/asgi/test_hello_asgi.py @@ -9,9 +9,9 @@ import falcon.asgi try: - import aiofiles # type: ignore + import aiofiles except ImportError: - aiofiles = None # type: ignore + aiofiles = None # type: ignore[assignment] SIZE_1_KB = 1024 diff --git a/tests/asgi/test_response_media_asgi.py b/tests/asgi/test_response_media_asgi.py index 01236de2e..a0a64f9d9 100644 --- a/tests/asgi/test_response_media_asgi.py +++ b/tests/asgi/test_response_media_asgi.py @@ -10,9 +10,9 @@ from falcon.util.deprecation import DeprecatedWarning try: - import msgpack # type: ignore + import msgpack except ImportError: - msgpack = None # type: ignore + msgpack = None def create_client(resource, 
handlers=None): diff --git a/tests/asgi/test_ws.py b/tests/asgi/test_ws.py index 1f432ddb0..c4d2a7d1e 100644 --- a/tests/asgi/test_ws.py +++ b/tests/asgi/test_ws.py @@ -14,21 +14,21 @@ from falcon.testing.helpers import _WebSocketState as ClientWebSocketState try: - import cbor2 # type: ignore + import cbor2 except ImportError: - cbor2 = None # type: ignore + cbor2 = None # type: ignore[assignment] try: - import msgpack # type: ignore + import msgpack except ImportError: - msgpack = None # type: ignore + msgpack = None try: - import rapidjson # type: ignore + import rapidjson except ImportError: - rapidjson = None # type: ignore + rapidjson = None # type: ignore[assignment] # NOTE(kgriffs): We do not use codes defined in the framework because we @@ -1346,12 +1346,12 @@ async def process_resource_ws(self, req, ws, res, params): if handler_has_ws: - async def handle_foobar(req, resp, ex, param, ws=None): # type: ignore + async def handle_foobar(req, resp, ex, param, ws=None): raise thing(status) else: - async def handle_foobar(req, resp, ex, param): # type: ignore + async def handle_foobar(req, resp, ex, param): # type: ignore[misc] raise thing(status) conductor.app.add_route('/', Resource()) diff --git a/tests/test_error_handlers.py b/tests/test_error_handlers.py index 751323c20..5bac0842a 100644 --- a/tests/test_error_handlers.py +++ b/tests/test_error_handlers.py @@ -224,9 +224,6 @@ def legacy_handler3(err, rq, rs, prms): client.simulate_head() def test_handler_must_be_coroutine_for_asgi(self, util): - async def legacy_handler(err, rq, rs, prms): - pass - app = util.create_app(True) with util.disable_asgi_non_coroutine_wrapping(): diff --git a/tests/test_hello.py b/tests/test_hello.py index 709b844b0..bb624e7d2 100644 --- a/tests/test_hello.py +++ b/tests/test_hello.py @@ -83,7 +83,7 @@ def close(self): # sometimes bubbles up a warning about exception when trying to call it. 
class NonClosingBytesIO: # Not callable; test that CloseableStreamIterator ignores it - close = False # type: ignore + close = False def __init__(self, data=b''): self._stream = io.BytesIO(data) diff --git a/tests/test_httperror.py b/tests/test_httperror.py index d05636773..a3d68724f 100644 --- a/tests/test_httperror.py +++ b/tests/test_httperror.py @@ -11,9 +11,9 @@ from falcon.util.deprecation import DeprecatedWarning try: - import yaml # type: ignore + import yaml except ImportError: - yaml = None # type: ignore + yaml = None # type: ignore[assignment] @pytest.fixture diff --git a/tests/test_media_multipart.py b/tests/test_media_multipart.py index c600008a9..277c0a567 100644 --- a/tests/test_media_multipart.py +++ b/tests/test_media_multipart.py @@ -11,7 +11,7 @@ from falcon.util import BufferedReader try: - import msgpack # type: ignore + import msgpack except ImportError: msgpack = None diff --git a/tests/test_request_media.py b/tests/test_request_media.py index f262caeff..0edb8e4eb 100644 --- a/tests/test_request_media.py +++ b/tests/test_request_media.py @@ -10,7 +10,7 @@ import falcon.asgi try: - import msgpack # type: ignore + import msgpack except ImportError: msgpack = None diff --git a/tests/test_response_media.py b/tests/test_response_media.py index 6bf71ab92..bef786922 100644 --- a/tests/test_response_media.py +++ b/tests/test_response_media.py @@ -8,7 +8,7 @@ from falcon import testing try: - import msgpack # type: ignore + import msgpack except ImportError: msgpack = None diff --git a/tests/test_utils.py b/tests/test_utils.py index bc00b8655..159fdd9a7 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -26,7 +26,7 @@ from falcon.util.time import TimezoneGMT try: - import msgpack # type: ignore + import msgpack except ImportError: msgpack = None From da3f90f21381f2af56fbab433e8a8cd798e03d9d Mon Sep 17 00:00:00 2001 From: Federico Caselli Date: Tue, 27 Aug 2024 23:37:38 +0200 Subject: [PATCH 02/12] typing: type websocket module --- 
falcon/asgi/ws.py | 110 ++++++++++++++++++++------------------ falcon/constants.py | 8 ++- falcon/testing/helpers.py | 8 ++- pyproject.toml | 1 - 4 files changed, 72 insertions(+), 55 deletions(-) diff --git a/falcon/asgi/ws.py b/falcon/asgi/ws.py index 7971fb868..4cd9a6a76 100644 --- a/falcon/asgi/ws.py +++ b/falcon/asgi/ws.py @@ -2,52 +2,34 @@ import asyncio import collections +from enum import auto from enum import Enum -from typing import ( - Any, - Deque, - Dict, - Iterable, - Mapping, - Optional, - Union, -) +from typing import Any, Deque, Dict, Iterable, Mapping, Optional, Tuple, Union from falcon import errors from falcon import media from falcon import status_codes from falcon.asgi_spec import AsgiEvent +from falcon.asgi_spec import AsgiSendMsg from falcon.asgi_spec import EventType from falcon.asgi_spec import WSCloseCode from falcon.constants import WebSocketPayloadType from falcon.typing import AsgiReceive from falcon.typing import AsgiSend +from falcon.typing import HeaderList from falcon.util import misc -_WebSocketState = Enum('_WebSocketState', 'HANDSHAKE ACCEPTED CLOSED') +__all__ = ('WebSocket',) -__all__ = ('WebSocket',) +class _WebSocketState(Enum): + HANDSHAKE = auto() + ACCEPTED = auto() + CLOSED = auto() class WebSocket: - """Represents a single WebSocket connection with a client. - - Attributes: - ready (bool): ``True`` if the WebSocket connection has been - accepted and the client is still connected, ``False`` otherwise. - unaccepted (bool)): ``True`` if the WebSocket connection has not yet - been accepted, ``False`` otherwise. - closed (bool): ``True`` if the WebSocket connection has been closed - by the server or the client has disconnected. - subprotocols (tuple[str]): The list of subprotocol strings advertised - by the client, or an empty tuple if no subprotocols were - specified. 
- supports_accept_headers (bool): ``True`` if the ASGI server hosting - the app supports sending headers when accepting the WebSocket - connection, ``False`` otherwise. - - """ + """Represents a single WebSocket connection with a client.""" __slots__ = ( '_asgi_receive', @@ -65,6 +47,13 @@ class WebSocket: 'subprotocols', ) + _state: _WebSocketState + _close_code: Optional[int] + subprotocols: Tuple[str, ...] + """The list of subprotocol strings advertised by the client, or an empty tuple if + no subprotocols were specified. + """ + def __init__( self, ver: str, @@ -105,14 +94,20 @@ def __init__( self._close_reasons = default_close_reasons self._state = _WebSocketState.HANDSHAKE - self._close_code = None # type: Optional[int] + self._close_code = None @property def unaccepted(self) -> bool: + """``True`` if the WebSocket connection has not yet been accepted, + ``False`` otherwise. + """ # noqa: D205 return self._state == _WebSocketState.HANDSHAKE @property def closed(self) -> bool: + """``True`` if the WebSocket connection has been closed by the server or the + client has disconnected. + """ # noqa: D205 return ( self._state == _WebSocketState.CLOSED or self._buffered_receiver.client_disconnected @@ -120,6 +115,9 @@ def closed(self) -> bool: @property def ready(self) -> bool: + """``True`` if the WebSocket connection has been accepted and the client is + still connected, ``False`` otherwise. + """ # noqa: D205 return ( self._state == _WebSocketState.ACCEPTED and not self._buffered_receiver.client_disconnected @@ -127,13 +125,16 @@ def ready(self) -> bool: @property def supports_accept_headers(self) -> bool: + """``True`` if the ASGI server hosting the app supports sending headers when + accepting the WebSocket connection, ``False`` otherwise. 
+ """ # noqa: D205 return self._supports_accept_headers async def accept( self, subprotocol: Optional[str] = None, - headers: Optional[Union[Iterable[Iterable[str]], Mapping[str, str]]] = None, - ): + headers: Optional[HeaderList] = None, + ) -> None: """Accept the incoming WebSocket connection. If, after examining the connection's attributes (headers, advertised @@ -154,7 +155,7 @@ async def accept( client may choose to abandon the connection in this case, if it does not receive an explicit protocol selection. - headers (Iterable[[str, str]]): An iterable of ``[name: str, value: str]`` + headers (HeaderList): An iterable of ``(name: str, value: str)`` two-item iterables, representing a collection of HTTP headers to include in the handshake response. Both *name* and *value* must be of type ``str`` and contain only US-ASCII characters. @@ -199,13 +200,14 @@ async def accept( ) header_items = getattr(headers, 'items', None) - if callable(header_items): - headers = header_items() + headers_iterable: Iterable[tuple[str, str]] = header_items() + else: + headers_iterable = headers # type: ignore[assignment] event['headers'] = parsed_headers = [ (name.lower().encode('ascii'), value.encode('ascii')) - for name, value in headers # type: ignore + for name, value in headers_iterable ] for name, __ in parsed_headers: @@ -348,7 +350,6 @@ async def send_text(self, payload: str) -> None: """ self._require_accepted() - # NOTE(kgriffs): We have to check ourselves because some ASGI # servers are not very strict which can lead to hard-to-debug # errors. @@ -369,14 +370,13 @@ async def send_data(self, payload: Union[bytes, bytearray, memoryview]) -> None: payload (Union[bytes, bytearray, memoryview]): The binary data to send. """ + self._require_accepted() # NOTE(kgriffs): We have to check ourselves because some ASGI # servers are not very strict which can lead to hard-to-debug # errors. 
if not isinstance(payload, (bytes, bytearray, memoryview)): raise TypeError('payload must be a byte string') - self._require_accepted() - await self._send( { 'type': EventType.WS_SEND, @@ -464,7 +464,7 @@ async def receive_media(self) -> object: return self._mh_bin_deserialize(data) - async def _send(self, msg: dict): + async def _send(self, msg: AsgiSendMsg) -> None: if self._buffered_receiver.client_disconnected: self._state = _WebSocketState.CLOSED self._close_code = self._buffered_receiver.client_disconnected_code @@ -489,7 +489,7 @@ async def _send(self, msg: dict): # obscure the traceback. raise - async def _receive(self) -> dict: + async def _receive(self) -> AsgiEvent: event = await self._asgi_receive() event_type = event['type'] @@ -506,7 +506,7 @@ async def _receive(self) -> dict: return event - def _require_accepted(self): + def _require_accepted(self) -> None: if self._state == _WebSocketState.HANDSHAKE: raise errors.OperationNotAllowed( 'WebSocket connection has not yet been accepted' @@ -514,7 +514,7 @@ def _require_accepted(self): elif self._state == _WebSocketState.CLOSED: raise errors.WebSocketDisconnected(self._close_code) - def _translate_webserver_error(self, ex): + def _translate_webserver_error(self, ex: Exception) -> Optional[Exception]: s = str(ex) # NOTE(kgriffs): uvicorn or any other server using the "websockets" @@ -656,13 +656,20 @@ class _BufferedReceiver: 'client_disconnected_code', ] - def __init__(self, asgi_receive: AsgiReceive, max_queue: int): + _pop_message_waiter: Optional[asyncio.Future[None]] + _put_message_waiter: Optional[asyncio.Future[None]] + _pump_task: Optional[asyncio.Task[None]] + _messages: Deque[AsgiEvent] + client_disconnected: bool + client_disconnected_code: Optional[int] + + def __init__(self, asgi_receive: AsgiReceive, max_queue: int) -> None: self._asgi_receive = asgi_receive self._max_queue = max_queue self._loop = asyncio.get_running_loop() - self._messages: Deque[AsgiEvent] = collections.deque() + 
self._messages = collections.deque() self._pop_message_waiter = None self._put_message_waiter = None @@ -671,12 +678,12 @@ def __init__(self, asgi_receive: AsgiReceive, max_queue: int): self.client_disconnected = False self.client_disconnected_code = None - def start(self): - if not self._pump_task: + def start(self) -> None: + if self._pump_task is None: self._pump_task = asyncio.create_task(self._pump()) - async def stop(self): - if not self._pump_task: + async def stop(self) -> None: + if self._pump_task is None: return self._pump_task.cancel() @@ -687,13 +694,14 @@ async def stop(self): self._pump_task = None - async def receive(self): + async def receive(self) -> AsgiEvent: # NOTE(kgriffs): Since this class is only used internally, we # use an assertion to mitigate against framework bugs. # # receive() may not be called again while another coroutine # is already waiting for the next message. - assert not self._pop_message_waiter + assert self._pop_message_waiter is None + assert self._pump_task is not None # NOTE(kgriffs): Wait for a message if none are available. This pattern # was borrowed from the websockets.protocol module. 
@@ -737,7 +745,7 @@ async def receive(self): return message - async def _pump(self): + async def _pump(self) -> None: while not self.client_disconnected: received_event = await self._asgi_receive() if received_event['type'] == EventType.WS_DISCONNECT: diff --git a/falcon/constants.py b/falcon/constants.py index 9576f0630..dbbb94934 100644 --- a/falcon/constants.py +++ b/falcon/constants.py @@ -1,3 +1,4 @@ +from enum import auto from enum import Enum import os import sys @@ -187,5 +188,8 @@ _UNSET = object() # TODO: remove once replaced with missing -WebSocketPayloadType = Enum('WebSocketPayloadType', 'TEXT BINARY') -"""Enum representing the two possible WebSocket payload types.""" +class WebSocketPayloadType(Enum): + """Enum representing the two possible WebSocket payload types.""" + + TEXT = auto() + BINARY = auto() diff --git a/falcon/testing/helpers.py b/falcon/testing/helpers.py index f2cbcb45b..e21961125 100644 --- a/falcon/testing/helpers.py +++ b/falcon/testing/helpers.py @@ -26,6 +26,7 @@ from collections import defaultdict from collections import deque import contextlib +from enum import auto from enum import Enum import io import itertools @@ -365,7 +366,12 @@ async def collect(self, event: Dict[str, Any]): __call__ = collect -_WebSocketState = Enum('_WebSocketState', 'CONNECT HANDSHAKE ACCEPTED DENIED CLOSED') +class _WebSocketState(Enum): + CONNECT = auto() + HANDSHAKE = auto() + ACCEPTED = auto() + DENIED = auto() + CLOSED = auto() class ASGIWebSocketSimulator: diff --git a/pyproject.toml b/pyproject.toml index 73e74558d..f70b47677 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,6 @@ "falcon.asgi.reader", "falcon.asgi.response", "falcon.asgi.stream", - "falcon.asgi.ws", "falcon.media.json", "falcon.media.msgpack", "falcon.media.multipart", From 117cccd437d70743e921cb9f1a8c6c0234488a56 Mon Sep 17 00:00:00 2001 From: Federico Caselli Date: Tue, 27 Aug 2024 23:44:58 +0200 Subject: [PATCH 03/12] typing: type asgi.reader, asgi.structures, 
asgi.stream --- falcon/asgi/reader.py | 92 +++++++++++++++++++++++++-------------- falcon/asgi/stream.py | 58 +++++++++++++++--------- falcon/asgi/structures.py | 56 ++++++++++++++---------- pyproject.toml | 2 - 4 files changed, 131 insertions(+), 77 deletions(-) diff --git a/falcon/asgi/reader.py b/falcon/asgi/reader.py index 426668809..281e607c4 100644 --- a/falcon/asgi/reader.py +++ b/falcon/asgi/reader.py @@ -14,7 +14,10 @@ """Buffered ASGI stream reader.""" +from __future__ import annotations + import io +from typing import AsyncIterator, List, NoReturn, Optional, Protocol from falcon.errors import DelimiterError from falcon.errors import OperationNotAllowed @@ -45,7 +48,17 @@ class BufferedReader: '_source', ] - def __init__(self, source, chunk_size=None): + _buffer: bytes + _buffer_len: int + _buffer_pos: int + _chunk_size: int + _consumed: int + _exhausted: bool + _iteration_started: bool + _max_join_size: int + _source: AsyncIterator[bytes] + + def __init__(self, source: AsyncIterator[bytes], chunk_size: Optional[int] = None): self._source = self._iter_normalized(source) self._chunk_size = chunk_size or DEFAULT_CHUNK_SIZE self._max_join_size = self._chunk_size * _MAX_JOIN_CHUNKS @@ -57,7 +70,9 @@ def __init__(self, source, chunk_size=None): self._exhausted = False self._iteration_started = False - async def _iter_normalized(self, source): + async def _iter_normalized( + self, source: AsyncIterator[bytes] + ) -> AsyncIterator[bytes]: chunk = b'' chunk_size = self._chunk_size @@ -77,7 +92,7 @@ async def _iter_normalized(self, source): self._exhausted = True - async def _iter_with_buffer(self, size_hint=0): + async def _iter_with_buffer(self, size_hint: int = 0) -> AsyncIterator[bytes]: if self._buffer_len > self._buffer_pos: if 0 < size_hint < self._buffer_len - self._buffer_pos: buffer_pos = self._buffer_pos @@ -91,7 +106,9 @@ async def _iter_with_buffer(self, size_hint=0): async for chunk in self._source: yield chunk - async def _iter_delimited(self, 
delimiter, size_hint=0): + async def _iter_delimited( + self, delimiter: bytes, size_hint: int = 0 + ) -> AsyncIterator[bytes]: delimiter_len_1 = len(delimiter) - 1 if not 0 <= delimiter_len_1 < self._chunk_size: raise ValueError('delimiter length must be within [1, chunk_size]') @@ -152,13 +169,13 @@ async def _iter_delimited(self, delimiter, size_hint=0): yield self._buffer - async def _consume_delimiter(self, delimiter): + async def _consume_delimiter(self, delimiter: bytes) -> None: delimiter_len = len(delimiter) if await self.peek(delimiter_len) != delimiter: raise DelimiterError('expected delimiter missing') self._buffer_pos += delimiter_len - def _prepend_buffer(self, chunk): + def _prepend_buffer(self, chunk: bytes) -> None: if self._buffer_len > self._buffer_pos: self._buffer = chunk + self._buffer[self._buffer_pos :] self._buffer_len = len(self._buffer) @@ -168,17 +185,17 @@ def _prepend_buffer(self, chunk): self._buffer_pos = 0 - def _trim_buffer(self): + def _trim_buffer(self) -> None: self._buffer = self._buffer[self._buffer_pos :] self._buffer_len -= self._buffer_pos self._buffer_pos = 0 - async def _read_from(self, source, size=-1): + async def _read_from(self, source: AsyncIterator[bytes], size: int = -1) -> bytes: if size == -1 or size is None: - result = io.BytesIO() + result_bytes = io.BytesIO() async for chunk in source: - result.write(chunk) - return result.getvalue() + result_bytes.write(chunk) + return result_bytes.getvalue() if size <= 0: return b'' @@ -186,7 +203,7 @@ async def _read_from(self, source, size=-1): remaining = size if size <= self._max_join_size: - result = [] + result: List[bytes] = [] async for chunk in source: chunk_len = len(chunk) if remaining < chunk_len: @@ -203,29 +220,29 @@ async def _read_from(self, source, size=-1): return result[0] if len(result) == 1 else b''.join(result) # NOTE(vytas): size > self._max_join_size - result = io.BytesIO() + result_bytes = io.BytesIO() async for chunk in source: chunk_len = 
len(chunk) if remaining < chunk_len: - result.write(chunk[:remaining]) + result_bytes.write(chunk[:remaining]) self._prepend_buffer(chunk[remaining:]) break - result.write(chunk) + result_bytes.write(chunk) remaining -= chunk_len if remaining == 0: # pragma: no py39,py310 cover break - return result.getvalue() + return result_bytes.getvalue() - def delimit(self, delimiter): + def delimit(self, delimiter: bytes) -> BufferedReader: # TODO: should use Self return type(self)(self._iter_delimited(delimiter), chunk_size=self._chunk_size) # ------------------------------------------------------------------------- # Asynchronous IO interface. # ------------------------------------------------------------------------- - def __aiter__(self): + def __aiter__(self) -> AsyncIterator[bytes]: if self._iteration_started: raise OperationNotAllowed('This stream is already being iterated over.') @@ -236,10 +253,10 @@ def __aiter__(self): return self._iter_with_buffer() return self._source - async def exhaust(self): + async def exhaust(self) -> None: await self.pipe() - async def peek(self, size=-1): + async def peek(self, size: int = -1) -> bytes: if size < 0 or size > self._chunk_size: size = self._chunk_size @@ -255,12 +272,17 @@ async def peek(self, size=-1): return self._buffer[:size] - async def pipe(self, destination=None): + async def pipe(self, destination: Optional[AsyncWritableIO] = None) -> None: async for chunk in self._iter_with_buffer(): if destination is not None: await destination.write(chunk) - async def pipe_until(self, delimiter, destination=None, consume_delimiter=False): + async def pipe_until( + self, + delimiter: bytes, + destination: Optional[AsyncWritableIO] = None, + consume_delimiter: bool = False, + ) -> None: async for chunk in self._iter_delimited(delimiter): if destination is not None: await destination.write(chunk) @@ -268,10 +290,10 @@ async def pipe_until(self, delimiter, destination=None, consume_delimiter=False) if consume_delimiter: await
self._consume_delimiter(delimiter) - async def read(self, size=-1): + async def read(self, size: int = -1) -> bytes: return await self._read_from(self._iter_with_buffer(size_hint=size or 0), size) - async def readall(self): + async def readall(self) -> bytes: """Read and return all remaining data in the request body. Warning: @@ -286,7 +308,9 @@ async def readall(self): """ return await self._read_from(self._iter_with_buffer()) - async def read_until(self, delimiter, size=-1, consume_delimiter=False): + async def read_until( + self, delimiter: bytes, size: int = -1, consume_delimiter: bool = False + ) -> bytes: result = await self._read_from( self._iter_delimited(delimiter, size_hint=size or 0), size ) @@ -306,30 +330,34 @@ async def read_until(self, delimiter, size=-1, consume_delimiter=False): # pass @property - def eof(self): + def eof(self) -> bool: """Whether the stream is at EOF.""" return self._exhausted and self._buffer_len == self._buffer_pos - def fileno(self): + def fileno(self) -> NoReturn: """Raise an instance of OSError since a file descriptor is not used.""" raise OSError('This IO object does not use a file descriptor') - def isatty(self): + def isatty(self) -> bool: """Return ``False`` always.""" return False - def readable(self): + def readable(self) -> bool: """Return ``True`` always.""" return True - def seekable(self): + def seekable(self) -> bool: """Return ``False`` always.""" return False - def writable(self): + def writable(self) -> bool: """Return ``False`` always.""" return False - def tell(self): + def tell(self) -> int: """Return the number of bytes read from the stream so far.""" return self._consumed - (self._buffer_len - self._buffer_pos) + + +class AsyncWritableIO(Protocol): + async def write(self, data: bytes, /) -> None: ... 
diff --git a/falcon/asgi/stream.py b/falcon/asgi/stream.py index bd532feab..6213b1da1 100644 --- a/falcon/asgi/stream.py +++ b/falcon/asgi/stream.py @@ -14,7 +14,13 @@ """ASGI BoundedStream class.""" +from __future__ import annotations + +from typing import AsyncIterator, NoReturn, Optional + +from falcon.asgi_spec import AsgiEvent from falcon.errors import OperationNotAllowed +from falcon.typing import AsgiReceive __all__ = ('BoundedStream',) @@ -94,16 +100,28 @@ class BoundedStream: from the Content-Length header in the request (if available). """ - __slots__ = [ + __slots__ = ( '_buffer', '_bytes_remaining', '_closed', '_iteration_started', '_pos', '_receive', - ] - - def __init__(self, receive, first_event=None, content_length=None): + ) + + _buffer: bytes + _bytes_remaining: int + _closed: bool + _iteration_started: bool + _pos: int + _receive: AsgiReceive + + def __init__( + self, + receive: AsgiReceive, + first_event: Optional[AsgiEvent] = None, + content_length: Optional[int] = None, + ) -> None: self._closed = False self._iteration_started = False @@ -115,7 +133,7 @@ def __init__(self, receive, first_event=None, content_length=None): # object is created in other cases, use "in" here rather than # EAFP. if first_event and 'body' in first_event: - first_chunk = first_event['body'] + first_chunk: bytes = first_event['body'] else: first_chunk = b'' @@ -144,7 +162,7 @@ def __init__(self, receive, first_event=None, content_length=None): if not ('more_body' in first_event and first_event['more_body']): self._bytes_remaining = 0 - def __aiter__(self): + def __aiter__(self) -> AsyncIterator[bytes]: # NOTE(kgriffs): This returns an async generator, but that's OK because # it also implements the iterator protocol defined in PEP 492, albeit # in a more efficient way than a regular async iterator. @@ -161,41 +179,41 @@ def __aiter__(self): # readlines(), __iter__(), __next__(), flush(), seek(), # truncate(), __del__(). 
- def fileno(self): + def fileno(self) -> NoReturn: """Raise an instance of OSError since a file descriptor is not used.""" raise OSError('This IO object does not use a file descriptor') - def isatty(self): + def isatty(self) -> bool: """Return ``False`` always.""" return False - def readable(self): + def readable(self) -> bool: """Return ``True`` always.""" return True - def seekable(self): + def seekable(self) -> bool: """Return ``False`` always.""" return False - def writable(self): + def writable(self) -> bool: """Return ``False`` always.""" return False - def tell(self): + def tell(self) -> int: """Return the number of bytes read from the stream so far.""" return self._pos @property - def closed(self): + def closed(self) -> bool: return self._closed # ------------------------------------------------------------------------- @property - def eof(self): + def eof(self) -> bool: return not self._buffer and self._bytes_remaining == 0 - def close(self): + def close(self) -> None: """Clear any buffered data and close this stream. Once the stream is closed, any operation on it will @@ -211,7 +229,7 @@ def close(self): self._closed = True - async def exhaust(self): + async def exhaust(self) -> None: """Consume and immediately discard any remaining data in the stream.""" if self._closed: @@ -240,13 +258,13 @@ async def exhaust(self): self._bytes_remaining = 0 # Immediately dereference the data so it can be discarded ASAP - event = None + event = None # type: ignore[assignment] # NOTE(kgriffs): Ensure that if we read more than expected, this # value is normalized to zero. self._bytes_remaining = 0 - async def readall(self): + async def readall(self) -> bytes: """Read and return all remaining data in the request body. Warning: @@ -308,7 +326,7 @@ async def readall(self): return data - async def read(self, size=None): + async def read(self, size: Optional[int] = None) -> bytes: """Read some or all of the remaining bytes in the request body. 
Warning: @@ -401,7 +419,7 @@ async def read(self, size=None): return data - async def _iter_content(self): + async def _iter_content(self) -> AsyncIterator[bytes]: if self._closed: raise OperationNotAllowed( 'This stream is closed; no further operations on it are permitted.' diff --git a/falcon/asgi/structures.py b/falcon/asgi/structures.py index 22ebc1b7a..7e66c310b 100644 --- a/falcon/asgi/structures.py +++ b/falcon/asgi/structures.py @@ -38,31 +38,8 @@ class SSEvent: in any event that would otherwise be blank (i.e., one that does not specify any fields when initializing the `SSEvent` instance.) - Attributes: - data (bytes): Raw byte string to use as the ``data`` field for the - event message. Takes precedence over both `text` and `json`. - text (str): String to use for the ``data`` field in the message. Will - be encoded as UTF-8 in the event. Takes precedence over `json`. - json (object): JSON-serializable object to be converted to JSON and - used as the ``data`` field in the event message. - event (str): A string identifying the event type (AKA event name). - event_id (str): The event ID that the User Agent should use for - the `EventSource` object's last event ID value. - retry (int): The reconnection time to use when attempting to send the - event. This must be an integer, specifying the reconnection time - in milliseconds. - comment (str): Comment to include in the event message; this is - normally ignored by the user agent, but is useful when composing - a periodic "ping" message to keep the connection alive. Since this - is a common use case, a default "ping" comment will be included - in any event that would otherwise be blank (i.e., one that does - not specify any of these fields when initializing the - `SSEvent` instance.) - - .. 
_Server-Sent Events: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events - """ __slots__ = [ @@ -75,6 +52,39 @@ class SSEvent: 'comment', ] + data: Optional[bytes] + """Raw byte string to use as the ``data`` field for the event message. + Takes precedence over both `text` and `json`. + """ + text: Optional[str] + """String to use for the ``data`` field in the message. + Will be encoded as UTF-8 in the event. Takes precedence over `json`. + """ + json: JSONSerializable + """JSON-serializable object to be converted to JSON and used as the ``data`` + field in the event message. + """ + event: Optional[str] + """A string identifying the event type (AKA event name).""" + event_id: Optional[str] + """The event ID that the User Agent should use for the `EventSource` object's + last event ID value. + """ + retry: Optional[int] + """The reconnection time to use when attempting to send the event. + + This must be an integer, specifying the reconnection time in milliseconds. + """ + comment: Optional[str] + """Comment to include in the event message. + + This is normally ignored by the user agent, but is useful when composing a periodic + "ping" message to keep the connection alive. Since this is a common use case, a + default "ping" comment will be included in any event that would otherwise be blank + (i.e., one that does not specify any of the fields when initializing the + :class:`SSEvent` instance.) 
+ """ + def __init__( self, data: Optional[bytes] = None, diff --git a/pyproject.toml b/pyproject.toml index f70b47677..c69c0f04a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,9 +43,7 @@ [[tool.mypy.overrides]] module = [ "falcon.asgi.multipart", - "falcon.asgi.reader", "falcon.asgi.response", - "falcon.asgi.stream", "falcon.media.json", "falcon.media.msgpack", "falcon.media.multipart", From ef236f67adc5418c784ce6230d53b0dc13301fee Mon Sep 17 00:00:00 2001 From: Federico Caselli Date: Tue, 27 Aug 2024 23:48:16 +0200 Subject: [PATCH 04/12] typing: type most of media --- falcon/media/json.py | 61 +++++++++++++++++++++++++++----------- falcon/media/msgpack.py | 48 ++++++++++++++++++++++-------- falcon/media/urlencoded.py | 31 ++++++++++++++----- pyproject.toml | 3 -- 4 files changed, 103 insertions(+), 40 deletions(-) diff --git a/falcon/media/json.py b/falcon/media/json.py index cf0111e82..502be0126 100644 --- a/falcon/media/json.py +++ b/falcon/media/json.py @@ -1,10 +1,15 @@ +from __future__ import annotations + from functools import partial import json +from typing import Any, Callable, Optional, Union from falcon import errors from falcon import http_error from falcon.media.base import BaseHandler from falcon.media.base import TextBaseHandlerWS +from falcon.typing import AsyncReadableIO +from falcon.typing import ReadableIO class JSONHandler(BaseHandler): @@ -148,7 +153,11 @@ def default(self, obj): loads (func): Function to use when deserializing JSON requests. """ - def __init__(self, dumps=None, loads=None): + def __init__( + self, + dumps: Optional[Callable[[Any], Union[str, bytes]]] = None, + loads: Optional[Callable[[str], Any]] = None, + ) -> None: self._dumps = dumps or partial(json.dumps, ensure_ascii=False) self._loads = loads or json.loads @@ -156,11 +165,11 @@ def __init__(self, dumps=None, loads=None): # proper serialize implementation. 
result = self._dumps({'message': 'Hello World'}) if isinstance(result, str): - self.serialize = self._serialize_s - self.serialize_async = self._serialize_async_s + self.serialize = self._serialize_s # type: ignore[method-assign] + self.serialize_async = self._serialize_async_s # type: ignore[method-assign] else: - self.serialize = self._serialize_b - self.serialize_async = self._serialize_async_b + self.serialize = self._serialize_b # type: ignore[method-assign] + self.serialize_async = self._serialize_async_b # type: ignore[method-assign] # NOTE(kgriffs): To be safe, only enable the optimized protocol when # not subclassed. @@ -168,7 +177,7 @@ def __init__(self, dumps=None, loads=None): self._serialize_sync = self.serialize self._deserialize_sync = self._deserialize - def _deserialize(self, data): + def _deserialize(self, data: bytes) -> Any: if not data: raise errors.MediaNotFoundError('JSON') try: @@ -176,27 +185,41 @@ def _deserialize(self, data): except ValueError as err: raise errors.MediaMalformedError('JSON') from err - def deserialize(self, stream, content_type, content_length): + def deserialize( + self, + stream: ReadableIO, + content_type: Optional[str], + content_length: Optional[int], + ) -> Any: return self._deserialize(stream.read()) - async def deserialize_async(self, stream, content_type, content_length): + async def deserialize_async( + self, + stream: AsyncReadableIO, + content_type: Optional[str], + content_length: Optional[int], + ) -> Any: return self._deserialize(await stream.read()) # NOTE(kgriffs): Make content_type a kwarg to support the # Request.render_body() shortcut optimization. 
- def _serialize_s(self, media, content_type=None) -> bytes: - return self._dumps(media).encode() + def _serialize_s(self, media: Any, content_type: Optional[str] = None) -> bytes: + return self._dumps(media).encode() # type: ignore[union-attr] - async def _serialize_async_s(self, media, content_type) -> bytes: - return self._dumps(media).encode() + async def _serialize_async_s( + self, media: Any, content_type: Optional[str] + ) -> bytes: + return self._dumps(media).encode() # type: ignore[union-attr] # NOTE(kgriffs): Make content_type a kwarg to support the # Request.render_body() shortcut optimization. - def _serialize_b(self, media, content_type=None) -> bytes: - return self._dumps(media) + def _serialize_b(self, media: Any, content_type: Optional[str] = None) -> bytes: + return self._dumps(media) # type: ignore[return-value] - async def _serialize_async_b(self, media, content_type) -> bytes: - return self._dumps(media) + async def _serialize_async_b( + self, media: Any, content_type: Optional[str] + ) -> bytes: + return self._dumps(media) # type: ignore[return-value] class JSONHandlerWS(TextBaseHandlerWS): @@ -257,7 +280,11 @@ class JSONHandlerWS(TextBaseHandlerWS): __slots__ = ['dumps', 'loads'] - def __init__(self, dumps=None, loads=None): + def __init__( + self, + dumps: Optional[Callable[[Any], str]] = None, + loads: Optional[Callable[[str], Any]] = None, + ) -> None: self._dumps = dumps or partial(json.dumps, ensure_ascii=False) self._loads = loads or json.loads diff --git a/falcon/media/msgpack.py b/falcon/media/msgpack.py index 0267e2511..5b8c587c9 100644 --- a/falcon/media/msgpack.py +++ b/falcon/media/msgpack.py @@ -1,10 +1,12 @@ -from __future__ import absolute_import # NOTE(kgriffs): Work around a Cython bug +from __future__ import annotations -from typing import Union +from typing import Any, Callable, Optional, Protocol from falcon import errors from falcon.media.base import BaseHandler from falcon.media.base import BinaryBaseHandlerWS +from 
falcon.typing import AsyncReadableIO +from falcon.typing import ReadableIO class MessagePackHandler(BaseHandler): @@ -28,7 +30,10 @@ class MessagePackHandler(BaseHandler): $ pip install msgpack """ - def __init__(self): + _pack: Callable[[Any], bytes] + _unpackb: UnpackMethod + + def __init__(self) -> None: import msgpack packer = msgpack.Packer(autoreset=True, use_bin_type=True) @@ -38,10 +43,10 @@ def __init__(self): # NOTE(kgriffs): To be safe, only enable the optimized protocol when # not subclassed. if type(self) is MessagePackHandler: - self._serialize_sync = self._pack + self._serialize_sync = self._pack # type: ignore[assignment] self._deserialize_sync = self._deserialize - def _deserialize(self, data): + def _deserialize(self, data: bytes) -> Any: if not data: raise errors.MediaNotFoundError('MessagePack') try: @@ -51,16 +56,26 @@ def _deserialize(self, data): except ValueError as err: raise errors.MediaMalformedError('MessagePack') from err - def deserialize(self, stream, content_type, content_length): + def deserialize( + self, + stream: ReadableIO, + content_type: Optional[str], + content_length: Optional[int], + ) -> Any: return self._deserialize(stream.read()) - async def deserialize_async(self, stream, content_type, content_length): + async def deserialize_async( + self, + stream: AsyncReadableIO, + content_type: Optional[str], + content_length: Optional[int], + ) -> Any: return self._deserialize(await stream.read()) - def serialize(self, media, content_type) -> bytes: + def serialize(self, media: Any, content_type: Optional[str]) -> bytes: return self._pack(media) - async def serialize_async(self, media, content_type) -> bytes: + async def serialize_async(self, media: Any, content_type: Optional[str]) -> bytes: return self._pack(media) @@ -81,19 +96,26 @@ class MessagePackHandlerWS(BinaryBaseHandlerWS): $ pip install msgpack """ - __slots__ = ['msgpack', 'packer'] + __slots__ = ('msgpack', 'packer') + + _pack: Callable[[Any], bytes] + _unpackb: 
UnpackMethod - def __init__(self): + def __init__(self) -> None: import msgpack packer = msgpack.Packer(autoreset=True, use_bin_type=True) self._pack = packer.pack self._unpackb = msgpack.unpackb - def serialize(self, media: object) -> Union[bytes, bytearray, memoryview]: + def serialize(self, media: object) -> bytes: return self._pack(media) - def deserialize(self, payload: bytes) -> object: + def deserialize(self, payload: bytes) -> Any: # NOTE(jmvrbanac): Using unpackb since we would need to manage # a buffer for Unpacker() which wouldn't gain us much. return self._unpackb(payload, raw=False) + + +class UnpackMethod(Protocol): + def __call__(self, data: bytes, raw: bool = ...) -> Any: ... diff --git a/falcon/media/urlencoded.py b/falcon/media/urlencoded.py index 17f73dd65..1d7f6cb04 100644 --- a/falcon/media/urlencoded.py +++ b/falcon/media/urlencoded.py @@ -1,7 +1,12 @@ +from __future__ import annotations + +from typing import Any, Optional from urllib.parse import urlencode from falcon import errors from falcon.media.base import BaseHandler +from falcon.typing import AsyncReadableIO +from falcon.typing import ReadableIO from falcon.util.uri import parse_query_string @@ -28,7 +33,7 @@ class URLEncodedFormHandler(BaseHandler): when deserializing. """ - def __init__(self, keep_blank=True, csv=False): + def __init__(self, keep_blank: bool = True, csv: bool = False) -> None: self._keep_blank = keep_blank self._csv = csv @@ -40,23 +45,35 @@ def __init__(self, keep_blank=True, csv=False): # NOTE(kgriffs): Make content_type a kwarg to support the # Request.render_body() shortcut optimization. - def serialize(self, media, content_type=None) -> bytes: + def serialize(self, media: Any, content_type: Optional[str] = None) -> bytes: # NOTE(vytas): Setting doseq to True to mirror the parse_query_string # behaviour. 
return urlencode(media, doseq=True).encode() - def _deserialize(self, body): + def _deserialize(self, body: bytes) -> Any: try: # NOTE(kgriffs): According to http://goo.gl/6rlcux the # body should be US-ASCII. Enforcing this also helps # catch malicious input. - body = body.decode('ascii') - return parse_query_string(body, keep_blank=self._keep_blank, csv=self._csv) + body_str = body.decode('ascii') + return parse_query_string( + body_str, keep_blank=self._keep_blank, csv=self._csv + ) except Exception as err: raise errors.MediaMalformedError('URL-encoded') from err - def deserialize(self, stream, content_type, content_length): + def deserialize( + self, + stream: ReadableIO, + content_type: Optional[str], + content_length: Optional[int], + ) -> Any: return self._deserialize(stream.read()) - async def deserialize_async(self, stream, content_type, content_length): + async def deserialize_async( + self, + stream: AsyncReadableIO, + content_type: Optional[str], + content_length: Optional[int], + ) -> Any: return self._deserialize(await stream.read()) diff --git a/pyproject.toml b/pyproject.toml index c69c0f04a..43eed3128 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,10 +44,7 @@ module = [ "falcon.asgi.multipart", "falcon.asgi.response", - "falcon.media.json", - "falcon.media.msgpack", "falcon.media.multipart", - "falcon.media.urlencoded", "falcon.media.validators.*", "falcon.responders", "falcon.response_helpers", From 4da19535a6ab33258ad5f8c7e6dd8aeab07c0556 Mon Sep 17 00:00:00 2001 From: Federico Caselli Date: Tue, 27 Aug 2024 23:50:36 +0200 Subject: [PATCH 05/12] typing: type multipart --- docs/api/multipart.rst | 24 ++- falcon/asgi/multipart.py | 133 +++++++++++- falcon/asgi/reader.py | 17 +- falcon/media/multipart.py | 370 +++++++++++++++++----------------- falcon/request.py | 2 +- falcon/typing.py | 2 + falcon/util/misc.py | 8 +- falcon/util/reader.py | 2 +- pyproject.toml | 2 - tests/test_media_multipart.py | 2 +- 10 files changed, 354 insertions(+), 
208 deletions(-) diff --git a/docs/api/multipart.rst b/docs/api/multipart.rst index 8eed74f61..cf2620414 100644 --- a/docs/api/multipart.rst +++ b/docs/api/multipart.rst @@ -69,12 +69,26 @@ default, allowing you to use ``req.get_media()`` to iterate over the `. Falcon offers straightforward support for all of these scenarios. -Body Part Type --------------- +Multipart Form and Body Part Types +---------------------------------- -.. autoclass:: falcon.media.multipart.BodyPart - :members: - :exclude-members: data, media, text +.. tabs:: + + .. group-tab:: WSGI + + .. autoclass:: falcon.media.multipart.MultipartForm + :members: + + .. autoclass:: falcon.media.multipart.BodyPart + :members: + + .. group-tab:: ASGI + + .. autoclass:: falcon.asgi.multipart.MultipartForm + :members: + + .. autoclass:: falcon.asgi.multipart.BodyPart + :members: .. _multipart_parser_conf: diff --git a/falcon/asgi/multipart.py b/falcon/asgi/multipart.py index 52cb0505e..403028b26 100644 --- a/falcon/asgi/multipart.py +++ b/falcon/asgi/multipart.py @@ -14,11 +14,27 @@ """ASGI multipart form media handler components.""" +from __future__ import annotations + +from typing import ( + Any, + AsyncIterator, + Awaitable, + Dict, + Optional, + TYPE_CHECKING, +) + from falcon.asgi.reader import BufferedReader from falcon.errors import DelimiterError from falcon.media import multipart +from falcon.typing import AsyncReadableIO +from falcon.typing import MISSING from falcon.util.mediatypes import parse_header +if TYPE_CHECKING: + from falcon.media.multipart import MultipartParseOptions + _ALLOWED_CONTENT_HEADERS = multipart._ALLOWED_CONTENT_HEADERS _CRLF = multipart._CRLF _CRLF_CRLF = multipart._CRLF_CRLF @@ -27,7 +43,43 @@ class BodyPart(multipart.BodyPart): - async def get_data(self): + """Represents a body part in a multipart form in an ASGI application. + + Note: + :class:`BodyPart` is meant to be instantiated directly only by the + :class:`MultipartFormHandler` parser.
+ """ + + if TYPE_CHECKING: + + def __init__( + self, + stream: BufferedReader, + headers: Dict[bytes, bytes], + parse_options: MultipartParseOptions, + ): ... + + stream: BufferedReader # type: ignore[assignment] + """File-like input object for reading the body part of the + multipart form request, if any. This object provides direct access + to the server's data stream and is non-seekable. The stream is + automatically delimited according to the multipart stream boundary. + + With the exception of being buffered to keep track of the boundary, + the wrapped body part stream interface and behavior mimic + :attr:`Request.stream <falcon.asgi.Request.stream>`. + + Similarly to :attr:`BoundedStream <falcon.asgi.BoundedStream>`, + the most efficient way to read the body part content is asynchronous + iteration over part data chunks: + + .. code:: python + + async for data_chunk in part.stream: + pass + """ + + async def get_data(self) -> bytes: # type: ignore[override] if self._data is None: max_size = self._parse_options.max_body_part_buffer_size + 1 self._data = await self.stream.read(max_size) @@ -36,8 +88,21 @@ async def get_data(self): return self._data - async def get_media(self): - if self._media is None: + async def get_media(self) -> Any: + """Return a deserialized form of the multipart body part. + + When called, this method will attempt to deserialize the body part + stream using the Content-Type header as well as the media-type handlers + configured via :class:`~falcon.media.multipart.MultipartParseOptions`. + + The result will be cached and returned in subsequent calls:: + + deserialized_media = await part.get_media() + + Returns: + object: The deserialized media representation.
+ """ + if self._media is MISSING: handler, _, _ = self._parse_options.media_handlers._resolve( self.content_type, 'text/plain' ) @@ -52,7 +117,7 @@ async def get_media(self): return self._media - async def get_text(self): + async def get_text(self) -> Optional[str]: # type: ignore[override] content_type, options = parse_header(self.content_type) if content_type != 'text/plain': return None @@ -65,13 +130,61 @@ async def get_text(self): description='invalid text or charset: {}'.format(charset) ) from err - data = property(get_data) - media = property(get_media) - text = property(get_text) + data: Awaitable[bytes] = property(get_data) # type: ignore[assignment] + """Property that acts as a convenience alias for :meth:`~.get_data`. + + The ``await`` keyword must still be added when referencing + the property:: + + # Equivalent to: content = await part.get_data() + content = await part.data + """ + media: Awaitable[Any] = property(get_media) # type: ignore[assignment] + """Property that acts as a convenience alias for :meth:`~.get_media`. + + The ``await`` keyword must still be added when referencing + the property:: + + # Equivalent to: deserialized_media = await part.get_media() + deserialized_media = await part.media + """ + text: Awaitable[bytes] = property(get_text) # type: ignore[assignment] + """Property that acts as a convenience alias for :meth:`~.get_text`. + + The ``await`` keyword must still be added when referencing + the property:: + + # Equivalent to: decoded_text = await part.get_text() + decoded_text = await part.text + """ class MultipartForm: - def __init__(self, stream, boundary, content_length, parse_options): + """Iterable object that returns each form part as :class:`BodyPart` instances. + + Typical usage illustrated below:: + + async def on_post(self, req: Request, resp: Response) -> None: + form: MultipartForm = await req.get_media() + + async for part in form: + if part.name == 'foo': + ... + else: + ... 
+ + Note: + :class:`MultipartForm` is meant to be instantiated directly only by the + :class:`MultipartFormHandler` parser. + """ + + def __init__( + self, + stream: AsyncReadableIO, + boundary: bytes, + content_length: Optional[int], + parse_options: MultipartParseOptions, + ) -> None: self._stream = ( stream if isinstance(stream, BufferedReader) else BufferedReader(stream) ) @@ -83,10 +196,10 @@ def __init__(self, stream, boundary, content_length, parse_options): self._dash_boundary = b'--' + boundary self._parse_options = parse_options - def __aiter__(self): + def __aiter__(self) -> AsyncIterator[BodyPart]: return self._iterate_parts() - async def _iterate_parts(self): + async def _iterate_parts(self) -> AsyncIterator[BodyPart]: prologue = True delimiter = self._dash_boundary stream = self._stream diff --git a/falcon/asgi/reader.py b/falcon/asgi/reader.py index 281e607c4..c2adda791 100644 --- a/falcon/asgi/reader.py +++ b/falcon/asgi/reader.py @@ -17,10 +17,11 @@ from __future__ import annotations import io -from typing import AsyncIterator, List, NoReturn, Optional, Protocol +from typing import AsyncIterator, List, NoReturn, Optional, Protocol, Union from falcon.errors import DelimiterError from falcon.errors import OperationNotAllowed +from falcon.typing import AsyncReadableIO DEFAULT_CHUNK_SIZE = 8192 """Default minimum chunk size for :class:`BufferedReader` (8 KiB).""" @@ -58,7 +59,11 @@ class BufferedReader: _max_join_size: int _source: AsyncIterator[bytes] - def __init__(self, source: AsyncIterator[bytes], chunk_size: Optional[int] = None): + def __init__( + self, + source: Union[AsyncReadableIO, AsyncIterator[bytes]], + chunk_size: Optional[int] = None, + ): self._source = self._iter_normalized(source) self._chunk_size = chunk_size or DEFAULT_CHUNK_SIZE self._max_join_size = self._chunk_size * _MAX_JOIN_CHUNKS @@ -71,7 +76,7 @@ def __init__(self, source: AsyncIterator[bytes], chunk_size: Optional[int] = Non self._iteration_started = False async def 
_iter_normalized( - self, source: AsyncIterator[bytes] + self, source: Union[AsyncReadableIO, AsyncIterator[bytes]] ) -> AsyncIterator[bytes]: chunk = b'' chunk_size = self._chunk_size @@ -190,7 +195,9 @@ def _trim_buffer(self) -> None: self._buffer_len -= self._buffer_pos self._buffer_pos = 0 - async def _read_from(self, source: AsyncIterator[bytes], size: int = -1) -> bytes: + async def _read_from( + self, source: AsyncIterator[bytes], size: Optional[int] = -1 + ) -> bytes: if size == -1 or size is None: result_bytes = io.BytesIO() async for chunk in source: @@ -290,7 +297,7 @@ async def pipe_until( if consume_delimiter: await self._consume_delimiter(delimiter) - async def read(self, size: int = -1) -> bytes: + async def read(self, size: Optional[int] = -1) -> bytes: return await self._read_from(self._iter_with_buffer(size_hint=size or 0), size) async def readall(self) -> bytes: diff --git a/falcon/media/multipart.py b/falcon/media/multipart.py index 4e08b5306..19add56e8 100644 --- a/falcon/media/multipart.py +++ b/falcon/media/multipart.py @@ -17,13 +17,29 @@ from __future__ import annotations import re -from typing import Any, ClassVar, Dict, Optional, Tuple, Type, TYPE_CHECKING +from typing import ( + Any, + ClassVar, + Dict, + Iterator, + NoReturn, + Optional, + overload, + Tuple, + Type, + TYPE_CHECKING, + Union, +) from urllib.parse import unquote_to_bytes from falcon import errors from falcon.errors import MultipartParseError from falcon.media.base import BaseHandler from falcon.stream import BoundedStream +from falcon.typing import AsyncReadableIO +from falcon.typing import MISSING +from falcon.typing import MissingOr +from falcon.typing import ReadableIO from falcon.util import BufferedReader from falcon.util import misc from falcon.util.mediatypes import parse_header @@ -31,6 +47,7 @@ if TYPE_CHECKING: from falcon.asgi.multipart import MultipartForm as AsgiMultipartForm from falcon.media import Handlers + from falcon.util.reader import BufferedReader as 
PyBufferedReader # TODO(vytas): # * Better support for form-wide charset setting @@ -54,158 +71,57 @@ # TODO(vytas): Consider supporting -charset- stuff. # Does anyone use that (?) class BodyPart: - """Represents a body part in a multipart form. + """Represents a body part in a multipart form in a WSGI application. Note: :class:`BodyPart` is meant to be instantiated directly only by the :class:`MultipartFormHandler` parser. + """ - Attributes: - content_type (str): Value of the Content-Type header, or the multipart - form default ``text/plain`` if the header is missing. - - data (bytes): Property that acts as a convenience alias for - :meth:`~.get_data`. - - - .. tabs:: - - .. tab:: WSGI - - .. code:: python - - # Equivalent to: content = part.get_data() - content = part.data - - .. tab:: ASGI - - The ``await`` keyword must still be added when referencing - the property:: - - # Equivalent to: content = await part.get_data() - content = await part.data - - name(str): The name parameter of the Content-Disposition header. - The value of the "name" parameter is the original field name from - the submitted HTML form. - - .. note:: - According to `RFC 7578, section 4.2 - `__, each part - MUST include a Content-Disposition header field of type - "form-data", where the name parameter is mandatory. - - However, Falcon will not raise any error if this parameter is - missing; the property value will be ``None`` in that case. - - filename (str): File name if the body part is an attached file, and - ``None`` otherwise. - - secure_filename (str): The sanitized version of `filename` using only - the most common ASCII characters for maximum portability and safety - wrt using this name as a filename on a regular file system. - - If `filename` is empty or unset when referencing this property, an - instance of :class:`.MultipartParseError` will be raised. 
- - See also: :func:`~.secure_filename` - - stream: File-like input object for reading the body part of the - multipart form request, if any. This object provides direct access - to the server's data stream and is non-seekable. The stream is - automatically delimited according to the multipart stream boundary. - - With the exception of being buffered to keep track of the boundary, - the wrapped body part stream interface and behavior mimic - :attr:`Request.bounded_stream ` - (WSGI) and :attr:`Request.stream ` - (ASGI), respectively: - - .. tabs:: - - .. tab:: WSGI - - Reading the whole part content: - - .. code:: python - - data = part.stream.read() - - This is also safe: - - .. code:: python - - doc = yaml.safe_load(part.stream) - - .. tab:: ASGI - - Similarly to - :attr:`BoundedStream `, the most - efficient way to read the body part content is asynchronous - iteration over part data chunks: - - .. code:: python - - async for data_chunk in part.stream: - pass - - media (object): Property that acts as a convenience alias for - :meth:`~.get_media`. - - .. tabs:: - - .. tab:: WSGI - - .. code:: python - - # Equivalent to: deserialized_media = part.get_media() - deserialized_media = req.media - - .. tab:: ASGI - - The ``await`` keyword must still be added when referencing - the property:: - - # Equivalent to: deserialized_media = await part.get_media() - deserialized_media = await part.media + _content_disposition: Optional[Tuple[str, Dict[str, str]]] = None + _data: Optional[bytes] = None + _filename: MissingOr[Optional[str]] = MISSING + _media: MissingOr[Any] = MISSING + _name: MissingOr[Optional[str]] = MISSING - text (str): Property that acts as a convenience alias for - :meth:`~.get_text`. + stream: PyBufferedReader + """File-like input object for reading the body part of the + multipart form request, if any. This object provides direct access + to the server's data stream and is non-seekable. 
The stream is + automatically delimited according to the multipart stream boundary. - .. tabs:: + With the exception of being buffered to keep track of the boundary, + the wrapped body part stream interface and behavior mimic + :attr:`Request.bounded_stream <falcon.Request.bounded_stream>`. - .. tab:: WSGI + Reading the whole part content: - .. code:: python + .. code:: python - # Equivalent to: decoded_text = part.get_text() - decoded_text = part.text + data = part.stream.read() - .. tab:: ASGI + This is also safe: - The ``await`` keyword must still be added when referencing - the property:: + .. code:: python - # Equivalent to: decoded_text = await part.get_text() - decoded_text = await part.text + doc = yaml.safe_load(part.stream) """ - _content_disposition: Optional[Tuple[str, Dict[str, str]]] = None - _data: Optional[bytes] = None - _filename: Optional[str] = None - _media: Optional[Any] = None - _name: Optional[str] = None - - def __init__(self, stream, headers, parse_options): + def __init__( + self, + stream: PyBufferedReader, + headers: Dict[bytes, bytes], + parse_options: MultipartParseOptions, + ): self.stream = stream self._headers = headers self._parse_options = parse_options - def get_data(self): + def get_data(self) -> bytes: """Return the body part content bytes. The maximum number of bytes that may be read is configurable via - :class:`MultipartParseOptions`, and a :class:`.MultipartParseError` is + :class:`.MultipartParseOptions`, and a :class:`.MultipartParseError` is raised if the body part is larger that this size. The size limit guards against reading unexpectedly large amount of data @@ -230,7 +146,7 @@ def get_data(self): return self._data - def get_text(self): + def get_text(self) -> Optional[str]: """Return the body part content decoded as a text string. Text is decoded from the part content (as returned by @@ -268,7 +184,11 @@ def get_text(self): ) from err @property - def content_type(self): + def content_type(self) -> str: + """Value of the Content-Type header. 
+ + When the header is missing returns the multipart form default ``text/plain``. + """ # NOTE(vytas): RFC 7578, section 4.4. # Each part MAY have an (optional) "Content-Type" header field, which # defaults to "text/plain". @@ -276,8 +196,9 @@ def content_type(self): return value.decode('ascii') @property - def filename(self): - if self._filename is None: + def filename(self) -> Optional[str]: + """File name if the body part is an attached file, and ``None`` otherwise.""" + if self._filename is MISSING: if self._content_disposition is None: value = self._headers.get(b'content-disposition', b'') self._content_disposition = parse_header(value.decode()) @@ -288,31 +209,51 @@ def filename(self): # been spotted in the wild, even though RFC 7578 forbids it. match = _FILENAME_STAR_RFC5987.match(params.get('filename*', '')) if match: - charset, value = match.groups() + charset, filename_raw = match.groups() try: - self._filename = unquote_to_bytes(value).decode(charset) + self._filename = unquote_to_bytes(filename_raw).decode(charset) except (ValueError, LookupError) as err: raise MultipartParseError( description='invalid text or charset: {}'.format(charset) ) from err else: - value = params.get('filename') - if value is None: - return None - self._filename = value + self._filename = params.get('filename') return self._filename @property - def secure_filename(self): + def secure_filename(self) -> str: + """The sanitized version of `filename` using only the most common ASCII + characters for maximum portability and safety wrt using this name as a + filename on a regular file system. + + If `filename` is empty or unset when referencing this property, an + instance of :class:`.MultipartParseError` will be raised. 
+ + See also: :func:`~.secure_filename` + """ # noqa: D205 try: return misc.secure_filename(self.filename) except ValueError as ex: raise MultipartParseError(description=str(ex)) from ex @property - def name(self): - if self._name is None: + def name(self) -> Optional[str]: + """The name parameter of the Content-Disposition header. + + The value of the "name" parameter is the original field name from + the submitted HTML form. + + .. note:: + According to `RFC 7578, section 4.2 + `__, each part + MUST include a Content-Disposition header field of type + "form-data", where the name parameter is mandatory. + + However, Falcon will not raise any error if this parameter is + missing; the property value will be ``None`` in that case. + """ + if self._name is MISSING: if self._content_disposition is None: value = self._headers.get(b'content-disposition', b'') self._content_disposition = parse_header(value.decode()) @@ -322,31 +263,21 @@ def name(self): return self._name - def get_media(self): + def get_media(self) -> Any: """Return a deserialized form of the multipart body part. When called, this method will attempt to deserialize the body part stream using the Content-Type header as well as the media-type handlers configured via :class:`MultipartParseOptions`. - .. tabs:: + The result will be cached and returned in subsequent calls:: - .. tab:: WSGI - - The result will be cached and returned in subsequent calls:: - - deserialized_media = part.get_media() - - .. tab:: ASGI - - The result will be cached and returned in subsequent calls:: - - deserialized_media = await part.get_media() + deserialized_media = part.get_media() Returns: object: The deserialized media representation. 
""" - if self._media is None: + if self._media is MISSING: handler, _, _ = self._parse_options.media_handlers._resolve( self.content_type, 'text/plain' ) @@ -359,24 +290,70 @@ def get_media(self): return self._media - data = property(get_data) - media = property(get_media) - text = property(get_text) + data: bytes = property(get_data) # type: ignore[assignment] + """Property that acts as a convenience alias for :meth:`~.get_data`. + + .. code:: python + + # Equivalent to: content = part.get_data() + content = part.data + """ + media: Any = property(get_media) + """Property that acts as a convenience alias for :meth:`~.get_media`. + + .. code:: python + + # Equivalent to: deserialized_media = part.get_media() + deserialized_media = part.media + """ + text: str = property(get_text) # type: ignore[assignment] + """Property that acts as a convenience alias for :meth:`~.get_text`. + + .. code:: python + + # Equivalent to: decoded_text = part.get_text() + decoded_text = part.text + """ class MultipartForm: - def __init__(self, stream, boundary, content_length, parse_options): + """Iterable object that returns each form part as :class:`BodyPart` instances. + + Typical usage illustrated below:: + + def on_post(self, req: Request, resp: Response) -> None: + form: MultipartForm = req.get_media() + + for part in form: + if part.name == 'foo': + ... + else: + ... + + Note: + :class:`MultipartForm` is meant to be instantiated directly only by the + :class:`MultipartFormHandler` parser. + """ + + def __init__( + self, + stream: ReadableIO, + boundary: bytes, + content_length: Optional[int], + parse_options: MultipartParseOptions, + ) -> None: # NOTE(vytas): More lenient check whether the provided stream is not # already an instance of BufferedReader. # This approach makes testing both the Cythonized and pure-Python # streams easier within the same test/benchmark suite. 
if not hasattr(stream, 'read_until'): + assert content_length is not None if isinstance(stream, BoundedStream): stream = BufferedReader(stream.stream.read, content_length) else: stream = BufferedReader(stream.read, content_length) - self._stream = stream + self._stream: PyBufferedReader = stream # type: ignore[assignment] self._boundary = boundary # NOTE(vytas): Here self._dash_boundary is not prepended with CRLF # (yet) for parsing the prologue. The CRLF will be prepended later to @@ -385,7 +362,7 @@ def __init__(self, stream, boundary, content_length, parse_options): self._dash_boundary = b'--' + boundary self._parse_options = parse_options - def __iter__(self): + def __iter__(self) -> Iterator[BodyPart]: prologue = True delimiter = self._dash_boundary stream = self._stream @@ -419,7 +396,7 @@ def __iter__(self): description='unexpected form structure' ) from err - headers = {} + headers: Dict[bytes, bytes] = {} try: headers_block = stream.read_until( _CRLF_CRLF, max_headers_size, consume_delimiter=True @@ -480,23 +457,46 @@ class MultipartFormHandler(BaseHandler): over the media object. For examples on parsing the request form, see also: :ref:`multipart`. - - Attributes: - parse_options (MultipartParseOptions): - Configuration options for the multipart form parser and instances - of :class:`~falcon.media.multipart.BodyPart` it yields. - - See also: :ref:`multipart_parser_conf`. """ _ASGI_MULTIPART_FORM: ClassVar[Type[AsgiMultipartForm]] - def __init__(self, parse_options=None): + parse_options: MultipartParseOptions + """Configuration options for the multipart form parser and instances of + :class:`~falcon.media.multipart.BodyPart` it yields. + + See also: :ref:`multipart_parser_conf`. 
+ """ + + def __init__(self, parse_options: Optional[MultipartParseOptions] = None) -> None: self.parse_options = parse_options or MultipartParseOptions() + @overload def _deserialize_form( - self, stream, content_type, content_length, form_cls=MultipartForm - ): + self, + stream: ReadableIO, + content_type: Optional[str], + content_length: Optional[int], + form_cls: Type[MultipartForm] = ..., + ) -> MultipartForm: ... + + @overload + def _deserialize_form( + self, + stream: AsyncReadableIO, + content_type: Optional[str], + content_length: Optional[int], + form_cls: Type[AsgiMultipartForm] = ..., + ) -> AsgiMultipartForm: ... + + def _deserialize_form( + self, + stream: Union[ReadableIO, AsyncReadableIO], + content_type: Optional[str], + content_length: Optional[int], + form_cls: Type[Union[MultipartForm, AsgiMultipartForm]] = MultipartForm, + ) -> Union[MultipartForm, AsgiMultipartForm]: + assert content_type is not None _, options = parse_header(content_type) try: boundary = options['boundary'] @@ -522,17 +522,27 @@ def _deserialize_form( 'Content-Type', ) - return form_cls(stream, boundary.encode(), content_length, self.parse_options) + return form_cls(stream, boundary.encode(), content_length, self.parse_options) # type: ignore[arg-type] - def deserialize(self, stream, content_type, content_length): + def deserialize( + self, + stream: ReadableIO, + content_type: Optional[str], + content_length: Optional[int], + ) -> MultipartForm: return self._deserialize_form(stream, content_type, content_length) - async def deserialize_async(self, stream, content_type, content_length): + async def deserialize_async( + self, + stream: AsyncReadableIO, + content_type: Optional[str], + content_length: Optional[int], + ) -> AsgiMultipartForm: return self._deserialize_form( stream, content_type, content_length, form_cls=self._ASGI_MULTIPART_FORM ) - def serialize(self, media, content_type): + def serialize(self, media: object, content_type: str) -> NoReturn: raise 
NotImplementedError('multipart form serialization unsupported') @@ -595,7 +605,7 @@ class MultipartParseOptions: 'media_handlers', ) - def __init__(self): + def __init__(self) -> None: self.default_charset = 'utf-8' self.max_body_part_buffer_size = 1024 * 1024 self.max_body_part_count = 64 diff --git a/falcon/request.py b/falcon/request.py index 1641a295f..db89c9470 100644 --- a/falcon/request.py +++ b/falcon/request.py @@ -312,7 +312,7 @@ def __init__( self._cached_uri: Optional[str] = None try: - self.content_type: Union[str, None] = self.env['CONTENT_TYPE'] + self.content_type = self.env['CONTENT_TYPE'] except KeyError: self.content_type = None diff --git a/falcon/typing.py b/falcon/typing.py index 817bf2f8d..e83329a84 100644 --- a/falcon/typing.py +++ b/falcon/typing.py @@ -20,6 +20,7 @@ import http from typing import ( Any, + AsyncIterator, Awaitable, Callable, Dict, @@ -148,6 +149,7 @@ def __call__(self, req: Request, resp: Response, **kwargs: Any) -> None: ... # ASGI class AsyncReadableIO(Protocol): async def read(self, n: Optional[int] = ..., /) -> bytes: ... + def __aiter__(self) -> AsyncIterator[bytes]: ... class AsgiResponderMethod(Protocol): diff --git a/falcon/util/misc.py b/falcon/util/misc.py index 18a27b95e..1fbe09ef3 100644 --- a/falcon/util/misc.py +++ b/falcon/util/misc.py @@ -23,12 +23,14 @@ now = falcon.http_now() """ +from __future__ import annotations + import datetime import functools import http import inspect import re -from typing import Any, Callable, Dict, List, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import unicodedata from falcon import status_codes @@ -214,7 +216,7 @@ def http_date_to_dt(http_date: str, obs_date: bool = False) -> datetime.datetime def to_query_str( - params: dict, comma_delimited_lists: bool = True, prefix: bool = True + params: Dict[str, Any], comma_delimited_lists: bool = True, prefix: bool = True ) -> str: """Convert a dictionary of parameters to a query string. 
@@ -370,7 +372,7 @@ def get_http_status( return str(code) + ' ' + default_reason -def secure_filename(filename: str) -> str: +def secure_filename(filename: Optional[str]) -> str: """Sanitize the provided `filename` to contain only ASCII characters. Only ASCII alphanumerals, ``'.'``, ``'-'`` and ``'_'`` are allowed for diff --git a/falcon/util/reader.py b/falcon/util/reader.py index 96adc4b5f..645b35520 100644 --- a/falcon/util/reader.py +++ b/falcon/util/reader.py @@ -120,7 +120,7 @@ def _normalize_size(self, size: Optional[int]) -> int: return max_size return size - def read(self, size: int = -1) -> bytes: + def read(self, size: Optional[int] = -1) -> bytes: return self._read(self._normalize_size(size)) def _read(self, size: int) -> bytes: diff --git a/pyproject.toml b/pyproject.toml index 43eed3128..ef17db4f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,9 +42,7 @@ [[tool.mypy.overrides]] module = [ - "falcon.asgi.multipart", "falcon.asgi.response", - "falcon.media.multipart", "falcon.media.validators.*", "falcon.responders", "falcon.response_helpers", diff --git a/tests/test_media_multipart.py b/tests/test_media_multipart.py index 277c0a567..78e87abca 100644 --- a/tests/test_media_multipart.py +++ b/tests/test_media_multipart.py @@ -291,7 +291,7 @@ def test_body_part_properties(): if part.content_type == 'application/json': assert part.name == part.name == 'document' elif part.name == 'file1': - assert part.filename == part.filename == 'test.txt' + assert part.filename == 'test.txt' assert part.secure_filename == part.filename From 9a4734ad8859725ffd5af33fa15fc10dd4ba6fa5 Mon Sep 17 00:00:00 2001 From: Federico Caselli Date: Tue, 27 Aug 2024 23:57:02 +0200 Subject: [PATCH 06/12] typing: type response --- falcon/app.py | 14 +- falcon/asgi/app.py | 12 +- falcon/asgi/response.py | 253 ++++++---------- falcon/constants.py | 4 - falcon/media/urlencoded.py | 4 +- falcon/middleware.py | 6 +- falcon/request.py | 6 +- falcon/responders.py | 40 ++- 
falcon/response.py | 538 ++++++++++++++++++++++------------ falcon/response_helpers.py | 38 ++- falcon/routing/static.py | 6 +- falcon/typing.py | 7 + pyproject.toml | 4 - tests/test_cors_middleware.py | 22 ++ tests/test_headers.py | 3 +- 15 files changed, 563 insertions(+), 394 deletions(-) diff --git a/falcon/app.py b/falcon/app.py index 88ab0e554..b66247051 100644 --- a/falcon/app.py +++ b/falcon/app.py @@ -26,7 +26,6 @@ ClassVar, Dict, FrozenSet, - IO, Iterable, List, Literal, @@ -62,6 +61,7 @@ from falcon.typing import ErrorSerializer from falcon.typing import FindMethod from falcon.typing import ProcessResponseMethod +from falcon.typing import ReadableIO from falcon.typing import ResponderCallable from falcon.typing import SinkCallable from falcon.typing import SinkPrefix @@ -1191,7 +1191,9 @@ def _handle_exception( def _get_body( self, resp: Response, - wsgi_file_wrapper: Optional[Callable[[IO[bytes], int], Iterable[bytes]]] = None, + wsgi_file_wrapper: Optional[ + Callable[[ReadableIO, int], Iterable[bytes]] + ] = None, ) -> Tuple[Iterable[bytes], Optional[int]]: """Convert resp content into an iterable as required by PEP 333. @@ -1229,11 +1231,13 @@ def _get_body( # TODO(kgriffs): Make block size configurable at the # global level, pending experimentation to see how # useful that would be. 
See also the discussion on - # this GitHub PR: http://goo.gl/XGrtDz - iterable = wsgi_file_wrapper(stream, self._STREAM_BLOCK_SIZE) + # this GitHub PR: + # https://github.com/falconry/falcon/pull/249#discussion_r11269730 + iterable = wsgi_file_wrapper(stream, self._STREAM_BLOCK_SIZE) # type: ignore[arg-type] else: iterable = helpers.CloseableStreamIterator( - stream, self._STREAM_BLOCK_SIZE + stream, # type: ignore[arg-type] + self._STREAM_BLOCK_SIZE, ) else: iterable = stream diff --git a/falcon/asgi/app.py b/falcon/asgi/app.py index 4fe7b7db3..4478728c2 100644 --- a/falcon/asgi/app.py +++ b/falcon/asgi/app.py @@ -46,7 +46,6 @@ from falcon.asgi_spec import AsgiSendMsg from falcon.asgi_spec import EventType from falcon.asgi_spec import WSCloseCode -from falcon.constants import _UNSET from falcon.constants import MEDIA_JSON from falcon.errors import CompatibilityError from falcon.errors import HTTPBadRequest @@ -60,6 +59,7 @@ from falcon.typing import AsgiResponderWsCallable from falcon.typing import AsgiSend from falcon.typing import AsgiSinkCallable +from falcon.typing import MISSING from falcon.typing import SinkPrefix from falcon.util import get_argnames from falcon.util.misc import is_python_func @@ -552,9 +552,9 @@ async def __call__( # type: ignore[override] # noqa: C901 data = resp._data if data is None and resp._media is not None: - # NOTE(kgriffs): We use a special _UNSET singleton since + # NOTE(kgriffs): We use a special MISSING singleton since # None is ambiguous (the media handler might return None). - if resp._media_rendered is _UNSET: + if resp._media_rendered is MISSING: opt = resp.options if not resp.content_type: resp.content_type = opt.default_media_type @@ -577,7 +577,7 @@ async def __call__( # type: ignore[override] # noqa: C901 data = text.encode() except AttributeError: # NOTE(kgriffs): Assume it was a bytes object already - data = text + data = text # type: ignore[assignment] else: # NOTE(vytas): Custom response type. 
@@ -1028,9 +1028,9 @@ def _schedule_callbacks(self, resp: Response) -> None: loop = asyncio.get_running_loop() - for cb, is_async in callbacks: # type: ignore[attr-defined] + for cb, is_async in callbacks or (): if is_async: - loop.create_task(cb()) + loop.create_task(cb()) # type: ignore[arg-type] else: loop.run_in_executor(None, cb) diff --git a/falcon/asgi/response.py b/falcon/asgi/response.py index b545b5201..204410b47 100644 --- a/falcon/asgi/response.py +++ b/falcon/asgi/response.py @@ -14,11 +14,18 @@ """ASGI Response class.""" +from __future__ import annotations + from inspect import iscoroutine from inspect import iscoroutinefunction +from typing import Awaitable, Callable, List, Literal, Optional, Tuple, Union from falcon import response -from falcon.constants import _UNSET +from falcon.typing import AsyncIterator +from falcon.typing import AsyncReadableIO +from falcon.typing import MISSING +from falcon.typing import ResponseCallbacks +from falcon.typing import SseEmitter from falcon.util.misc import _encode_items_to_latin1 from falcon.util.misc import is_python_func @@ -33,184 +40,109 @@ class Response(response.Response): Keyword Arguments: options (dict): Set of global options passed from the App handler. + """ - Attributes: - status (Union[str,int]): HTTP status code or line (e.g., ``'200 OK'``). - This may be set to a member of :class:`http.HTTPStatus`, an HTTP - status line string or byte string (e.g., ``'200 OK'``), or an - ``int``. - - Note: - The Falcon framework itself provides a number of constants for - common status codes. They all start with the ``HTTP_`` prefix, - as in: ``falcon.HTTP_204``. (See also: :ref:`status`.) - - status_code (int): HTTP status code normalized from :attr:`status`. - When a code is assigned to this property, :attr:`status` is updated, - and vice-versa. 
The status code can be useful when needing to check - in middleware for codes that fall into a certain class, e.g.:: - - if resp.status_code >= 400: - log.warning(f'returning error response: {resp.status_code}') - - media (object): A serializable object supported by the media handlers - configured via :class:`falcon.RequestOptions`. - - Note: - See also :ref:`media` for more information regarding media - handling. - - text (str): String representing response content. - - Note: - Falcon will encode the given text as UTF-8 - in the response. If the content is already a byte string, - use the :attr:`data` attribute instead (it's faster). - - data (bytes): Byte string representing response content. - - Use this attribute in lieu of `text` when your content is - already a byte string (of type ``bytes``). - - Warning: - Always use the `text` attribute for text, or encode it - first to ``bytes`` when using the `data` attribute, to - ensure Unicode characters are properly encoded in the - HTTP response. - - stream: An async iterator or generator that yields a series of - byte strings that will be streamed to the ASGI server as a - series of "http.response.body" events. Falcon will assume the - body is complete when the iterable is exhausted or as soon as it - yields ``None`` rather than an instance of ``bytes``:: - - async def producer(): - while True: - data_chunk = await read_data() - if not data_chunk: - break - - yield data_chunk - - resp.stream = producer - - Alternatively, a file-like object may be used as long as it - implements an awaitable ``read()`` method:: - - resp.stream = await aiofiles.open('resp_data.bin', 'rb') - - If the object assigned to :attr:`~.stream` holds any resources - (such as a file handle) that must be explicitly released, the - object must implement a ``close()`` method. The ``close()`` method - will be called after exhausting the iterable or file-like object. 
- - Note: - In order to be compatible with Python 3.7+ and PEP 479, - async iterators must return ``None`` instead of raising - :class:`StopIteration`. This requirement does not - apply to async generators (PEP 525). - - Note: - If the stream length is known in advance, you may wish to - also set the Content-Length header on the response. - - sse (coroutine): A Server-Sent Event (SSE) emitter, implemented as - an async iterator or generator that yields a series of - of :class:`falcon.asgi.SSEvent` instances. Each event will be - serialized and sent to the client as HTML5 Server-Sent Events:: + # PERF(kgriffs): These will be shadowed when set on an instance; let's + # us avoid having to implement __init__ and incur the overhead of + # an additional function call. + _sse: Optional[SseEmitter] = None + _registered_callbacks: Optional[List[ResponseCallbacks]] = None - async def emitter(): - while True: - some_event = await get_next_event() + stream: Union[AsyncReadableIO, AsyncIterator[bytes], None] # type: ignore[assignment] + """An async iterator or generator that yields a series of + byte strings that will be streamed to the ASGI server as a + series of "http.response.body" events. Falcon will assume the + body is complete when the iterable is exhausted or as soon as it + yields ``None`` rather than an instance of ``bytes``:: - if not some_event: - # Send an event consisting of a single "ping" - # comment to keep the connection alive. - yield SSEvent() + async def producer(): + while True: + data_chunk = await read_data() + if not data_chunk: + break - # Alternatively, one can simply yield None and - # a "ping" will also be sent as above. 
+ yield data_chunk - # yield + resp.stream = producer - continue + Alternatively, a file-like object may be used as long as it + implements an awaitable ``read()`` method:: - yield SSEvent(json=some_event, retry=5000) + resp.stream = await aiofiles.open('resp_data.bin', 'rb') - # ...or + If the object assigned to :attr:`~.stream` holds any resources + (such as a file handle) that must be explicitly released, the + object must implement a ``close()`` method. The ``close()`` method + will be called after exhausting the iterable or file-like object. - yield SSEvent(data=b'something', event_id=some_id) + Note: + In order to be compatible with Python 3.7+ and PEP 479, + async iterators must return ``None`` instead of raising + :class:`StopIteration`. This requirement does not + apply to async generators (PEP 525). - # Alternatively, you may yield anything that implements - # a serialize() method that returns a byte string - # conforming to the SSE event stream format. + Note: + If the stream length is known in advance, you may wish to + also set the Content-Length header on the response. + """ - # yield some_event + @property + def sse(self) -> Optional[SseEmitter]: + """A Server-Sent Event (SSE) emitter, implemented as + an async iterator or generator that yields a series of + :class:`falcon.asgi.SSEvent` instances. Each event will be + serialized and sent to the client as HTML5 Server-Sent Events:: - resp.sse = emitter() + async def emitter(): + while True: + some_event = await get_next_event() - Note: - When the `sse` property is set, it supersedes both the - `text` and `data` properties. + if not some_event: + # Send an event consisting of a single "ping" + # comment to keep the connection alive. + yield SSEvent() - Note: - When hosting an app that emits Server-Sent Events, the web - server should be set with a relatively long keep-alive TTL to - minimize the overhead of connection renegotiations. 
+ # Alternatively, one can simply yield None and + # a "ping" will also be sent as above. - context (object): Empty object to hold any data (in its attributes) - about the response which is specific to your app (e.g. session - object). Falcon itself will not interact with this attribute after - it has been initialized. + # yield - Note: - The preferred way to pass response-specific data, when using the - default context type, is to set attributes directly on the - `context` object. For example:: + continue - resp.context.cache_strategy = 'lru' + yield SSEvent(json=some_event, retry=5000) - context_type (class): Class variable that determines the factory or - type to use for initializing the `context` attribute. By default, - the framework will instantiate bare objects (instances of the bare - :class:`falcon.Context` class). However, you may override this - behavior by creating a custom child class of - :class:`falcon.asgi.Response`, and then passing that new class - to ``falcon.App()`` by way of the latter's `response_type` - parameter. + # ...or - Note: - When overriding `context_type` with a factory function (as - opposed to a class), the function is called like a method of - the current Response instance. Therefore the first argument is - the Response instance itself (self). + yield SSEvent(data=b'something', event_id=some_id) - options (dict): Set of global options passed in from the App handler. + # Alternatively, you may yield anything that implements + # a serialize() method that returns a byte string + # conforming to the SSE event stream format. - headers (dict): Copy of all headers set for the response, - sans cookies. Note that a new copy is created and returned each - time this property is referenced. + # yield some_event - complete (bool): Set to ``True`` from within a middleware method to - signal to the framework that request processing should be - short-circuited (see also :ref:`Middleware `). 
- """ + resp.sse = emitter() - # PERF(kgriffs): These will be shadowed when set on an instance; let's - # us avoid having to implement __init__ and incur the overhead of - # an additional function call. - _sse = None - _registered_callbacks = None + Note: + When the `sse` property is set, it supersedes both the + `text` and `data` properties. - @property - def sse(self): + Note: + When hosting an app that emits Server-Sent Events, the web + server should be set with a relatively long keep-alive TTL to + minimize the overhead of connection renegotiations. + """ return self._sse @sse.setter - def sse(self, value): + def sse(self, value: Optional[SseEmitter]) -> None: self._sse = value - def set_stream(self, stream, content_length): + def set_stream( + self, + stream: Union[AsyncReadableIO, AsyncIterator[bytes]], # type: ignore[override] + content_length: int, + ) -> None: """Set both `stream` and `content_length`. Although the :attr:`~falcon.asgi.Response.stream` and @@ -241,7 +173,7 @@ def set_stream(self, stream, content_length): # the self.content_length property. self._headers['content-length'] = str(content_length) - async def render_body(self): + async def render_body(self) -> Optional[bytes]: # type: ignore[override] """Get the raw bytestring content for the response body. This coroutine can be awaited to get the raw data for the @@ -261,14 +193,15 @@ async def render_body(self): # NOTE(vytas): The code below is also inlined in asgi.App.__call__. + data: Optional[bytes] text = self.text if text is None: data = self._data if data is None and self._media is not None: - # NOTE(kgriffs): We use a special _UNSET singleton since + # NOTE(kgriffs): We use a special MISSING singleton since # None is ambiguous (the media handler might return None). 
- if self._media_rendered is _UNSET: + if self._media_rendered is MISSING: if not self.content_type: self.content_type = self.options.default_media_type @@ -290,11 +223,11 @@ async def render_body(self): data = text.encode() except AttributeError: # NOTE(kgriffs): Assume it was a bytes object already - data = text + data = text # type: ignore[assignment] return data - def schedule(self, callback): + def schedule(self, callback: Callable[[], Awaitable[None]]) -> None: """Schedule an async callback to run soon after sending the HTTP response. This method can be used to execute a background job after the response @@ -341,14 +274,14 @@ def schedule(self, callback): # by tests running in a Cython environment, but we can't # detect it with the coverage tool. - rc = (callback, True) + rc: Tuple[Callable[[], Awaitable[None]], Literal[True]] = (callback, True) if not self._registered_callbacks: self._registered_callbacks = [rc] else: self._registered_callbacks.append(rc) - def schedule_sync(self, callback): + def schedule_sync(self, callback: Callable[[], None]) -> None: """Schedule a synchronous callback to run soon after sending the HTTP response. This method can be used to execute a background job after the @@ -387,7 +320,7 @@ def schedule_sync(self, callback): callable. The callback will be called without arguments. """ - rc = (callback, False) + rc: Tuple[Callable[[], None], Literal[False]] = (callback, False) if not self._registered_callbacks: self._registered_callbacks = [rc] @@ -398,7 +331,9 @@ def schedule_sync(self, callback): # Helper methods # ------------------------------------------------------------------------ - def _asgi_headers(self, media_type=None): + def _asgi_headers( + self, media_type: Optional[str] = None + ) -> List[Tuple[bytes, bytes]]: """Convert headers into the format expected by ASGI servers. 
Header names must be lowercased and both name and value must be diff --git a/falcon/constants.py b/falcon/constants.py index dbbb94934..b1df391d8 100644 --- a/falcon/constants.py +++ b/falcon/constants.py @@ -183,10 +183,6 @@ ) ) -# NOTE(kgriffs): Special singleton to be used internally whenever using -# None would be ambiguous. -_UNSET = object() # TODO: remove once replaced with missing - class WebSocketPayloadType(Enum): """Enum representing the two possible WebSocket payload types.""" diff --git a/falcon/media/urlencoded.py b/falcon/media/urlencoded.py index 1d7f6cb04..ee38391db 100644 --- a/falcon/media/urlencoded.py +++ b/falcon/media/urlencoded.py @@ -52,7 +52,9 @@ def serialize(self, media: Any, content_type: Optional[str] = None) -> bytes: def _deserialize(self, body: bytes) -> Any: try: - # NOTE(kgriffs): According to http://goo.gl/6rlcux the + # NOTE(kgriffs): According to + # https://html.spec.whatwg.org/multipage/form-control-infrastructure.html#application%2Fx-www-form-urlencoded-encoding-algorithm + # the # body should be US-ASCII. Enforcing this also helps # catch malicious input. body_str = body.decode('ascii') diff --git a/falcon/middleware.py b/falcon/middleware.py index 5772e16c7..3ea81a81d 100644 --- a/falcon/middleware.py +++ b/falcon/middleware.py @@ -120,7 +120,11 @@ def process_response( 'Access-Control-Request-Headers', default='*' ) - resp.set_header('Access-Control-Allow-Methods', allow) + if allow is not None: + # NOTE: not sure if it's more appropriate to raise an exception here + # This can happen only when a responder class defines a custom + # on_options responder method and does not set the 'Allow' header.
+ resp.set_header('Access-Control-Allow-Methods', allow) resp.set_header('Access-Control-Allow-Headers', allow_headers) resp.set_header('Access-Control-Max-Age', '86400') # 24 hours diff --git a/falcon/request.py b/falcon/request.py index db89c9470..db7ad6ea4 100644 --- a/falcon/request.py +++ b/falcon/request.py @@ -81,7 +81,7 @@ class Request: also PEP-3333. Keyword Arguments: - options (dict): Set of global options passed from the App handler. + options (RequestOptions): Set of global options passed from the App handler. """ __slots__ = ( @@ -2368,7 +2368,9 @@ def _parse_form_urlencoded(self) -> None: body_bytes = self.stream.read(content_length) - # NOTE(kgriffs): According to http://goo.gl/6rlcux the + # NOTE(kgriffs): According to + # https://html.spec.whatwg.org/multipage/form-control-infrastructure.html#application%2Fx-www-form-urlencoded-encoding-algorithm + # the # body should be US-ASCII. Enforcing this also helps # catch malicious input. try: diff --git a/falcon/responders.py b/falcon/responders.py index b3ee73763..b28277b47 100644 --- a/falcon/responders.py +++ b/falcon/responders.py @@ -14,33 +14,47 @@ """Default responder implementations.""" +from __future__ import annotations + +from typing import Any, Iterable, NoReturn, TYPE_CHECKING, Union + from falcon.errors import HTTPBadRequest from falcon.errors import HTTPMethodNotAllowed from falcon.errors import HTTPRouteNotFound from falcon.status_codes import HTTP_200 +from falcon.typing import AsgiResponderCallable +from falcon.typing import ResponderCallable + +if TYPE_CHECKING: + from falcon import Request + from falcon import Response + from falcon.asgi import Request as AsgiRequest + from falcon.asgi import Response as AsgiResponse -def path_not_found(req, resp, **kwargs): +def path_not_found(req: Request, resp: Response, **kwargs: Any) -> NoReturn: """Raise 404 HTTPRouteNotFound error.""" raise HTTPRouteNotFound() -async def path_not_found_async(req, resp, **kwargs): +async def 
path_not_found_async(req: Request, resp: Response, **kwargs: Any) -> NoReturn: """Raise 404 HTTPRouteNotFound error.""" raise HTTPRouteNotFound() -def bad_request(req, resp, **kwargs): +def bad_request(req: Request, resp: Response, **kwargs: Any) -> NoReturn: """Raise 400 HTTPBadRequest error.""" raise HTTPBadRequest(title='Bad request', description='Invalid HTTP method') -async def bad_request_async(req, resp, **kwargs): +async def bad_request_async(req: Request, resp: Response, **kwargs: Any) -> NoReturn: """Raise 400 HTTPBadRequest error.""" raise HTTPBadRequest(title='Bad request', description='Invalid HTTP method') -def create_method_not_allowed(allowed_methods, asgi=False): +def create_method_not_allowed( + allowed_methods: Iterable[str], asgi: bool = False +) -> Union[ResponderCallable, AsgiResponderCallable]: """Create a responder for "405 Method Not Allowed". Args: @@ -52,18 +66,22 @@ def create_method_not_allowed(allowed_methods, asgi=False): if asgi: - async def method_not_allowed_responder_async(req, resp, **kwargs): + async def method_not_allowed_responder_async( + req: AsgiRequest, resp: AsgiResponse, **kwargs: Any + ) -> NoReturn: raise HTTPMethodNotAllowed(allowed_methods) return method_not_allowed_responder_async - def method_not_allowed(req, resp, **kwargs): + def method_not_allowed(req: Request, resp: Response, **kwargs: Any) -> NoReturn: raise HTTPMethodNotAllowed(allowed_methods) return method_not_allowed -def create_default_options(allowed_methods, asgi=False): +def create_default_options( + allowed_methods: Iterable[str], asgi: bool = False +) -> Union[ResponderCallable, AsgiResponderCallable]: """Create a default responder for the OPTIONS method. 
Args: @@ -76,14 +94,16 @@ def create_default_options(allowed_methods, asgi=False): if asgi: - async def options_responder_async(req, resp, **kwargs): + async def options_responder_async( + req: AsgiRequest, resp: AsgiResponse, **kwargs: Any + ) -> None: resp.status = HTTP_200 resp.set_header('Allow', allowed) resp.set_header('Content-Length', '0') return options_responder_async - def options_responder(req, resp, **kwargs): + def options_responder(req: Request, resp: Response, **kwargs: Any) -> None: resp.status = HTTP_200 resp.set_header('Allow', allowed) resp.set_header('Content-Length', '0') diff --git a/falcon/response.py b/falcon/response.py index bc676c31a..03f15d901 100644 --- a/falcon/response.py +++ b/falcon/response.py @@ -16,13 +16,27 @@ from __future__ import annotations +from datetime import datetime from datetime import timezone import functools import mimetypes -from typing import Dict +from typing import ( + Any, + ClassVar, + Dict, + Iterable, + List, + Mapping, + NoReturn, + Optional, + overload, + Tuple, + Type, + TYPE_CHECKING, + Union, +) from falcon.constants import _DEFAULT_STATIC_MEDIA_TYPES -from falcon.constants import _UNSET from falcon.constants import DEFAULT_MEDIA_TYPE from falcon.errors import HeaderNotSupported from falcon.media import Handlers @@ -32,6 +46,11 @@ from falcon.response_helpers import format_range from falcon.response_helpers import header_property from falcon.response_helpers import is_ascii_encodable +from falcon.typing import Headers +from falcon.typing import MISSING +from falcon.typing import MissingOr +from falcon.typing import RangeSetHeader +from falcon.typing import ReadableIO from falcon.util import dt_to_http from falcon.util import http_cookies from falcon.util import http_status_to_code @@ -41,6 +60,10 @@ from falcon.util.uri import encode_check_escaped as uri_encode from falcon.util.uri import encode_value_check_escaped as uri_encode_value +if TYPE_CHECKING: + import http + + _STREAM_LEN_REMOVED_MSG = ( 
'The deprecated stream_len property was removed in Falcon 3.0. ' 'Please use Response.set_stream() or Response.content_length instead.' @@ -58,100 +81,7 @@ class Response: ``Response`` is not meant to be instantiated directly by responders. Keyword Arguments: - options (dict): Set of global options passed from the App handler. - - Attributes: - status (Union[str,int]): HTTP status code or line (e.g., ``'200 OK'``). - This may be set to a member of :class:`http.HTTPStatus`, an HTTP - status line string or byte string (e.g., ``'200 OK'``), or an - ``int``. - - Note: - The Falcon framework itself provides a number of constants for - common status codes. They all start with the ``HTTP_`` prefix, - as in: ``falcon.HTTP_204``. (See also: :ref:`status`.) - - status_code (int): HTTP status code normalized from :attr:`status`. - When a code is assigned to this property, :attr:`status` is updated, - and vice-versa. The status code can be useful when needing to check - in middleware for codes that fall into a certain class, e.g.:: - - if resp.status_code >= 400: - log.warning(f'returning error response: {resp.status_code}') - - media (object): A serializable object supported by the media handlers - configured via :class:`falcon.RequestOptions`. - - Note: - See also :ref:`media` for more information regarding media - handling. - - text (str): String representing response content. - - Note: - Falcon will encode the given text as UTF-8 - in the response. If the content is already a byte string, - use the :attr:`data` attribute instead (it's faster). - - data (bytes): Byte string representing response content. - - Use this attribute in lieu of `text` when your content is - already a byte string (of type ``bytes``). See also the note below. - - Warning: - Always use the `text` attribute for text, or encode it - first to ``bytes`` when using the `data` attribute, to - ensure Unicode characters are properly encoded in the - HTTP response. 
- - stream: Either a file-like object with a `read()` method that takes - an optional size argument and returns a block of bytes, or an - iterable object, representing response content, and yielding - blocks as byte strings. Falcon will use *wsgi.file_wrapper*, if - provided by the WSGI server, in order to efficiently serve - file-like objects. - - Note: - If the stream is set to an iterable object that requires - resource cleanup, it can implement a close() method to do so. - The close() method will be called upon completion of the request. - - context (object): Empty object to hold any data (in its attributes) - about the response which is specific to your app (e.g. session - object). Falcon itself will not interact with this attribute after - it has been initialized. - - Note: - **New in 2.0:** The default `context_type` (see below) was - changed from :class:`dict` to a bare class; the preferred way to - pass response-specific data is now to set attributes directly - on the `context` object. For example:: - - resp.context.cache_strategy = 'lru' - - context_type (class): Class variable that determines the factory or - type to use for initializing the `context` attribute. By default, - the framework will instantiate bare objects (instances of the bare - :class:`falcon.Context` class). However, you may override this - behavior by creating a custom child class of - :class:`falcon.Response`, and then passing that new class to - ``falcon.App()`` by way of the latter's `response_type` parameter. - - Note: - When overriding `context_type` with a factory function (as - opposed to a class), the function is called like a method of - the current Response instance. Therefore the first argument is - the Response instance itself (self). - - options (dict): Set of global options passed from the App handler. - - headers (dict): Copy of all headers set for the response, - sans cookies. Note that a new copy is created and returned each - time this property is referenced. 
- - complete (bool): Set to ``True`` from within a middleware method to - signal to the framework that request processing should be - short-circuited (see also :ref:`Middleware <middleware>`). + options (ResponseOptions): Set of global options passed from the App handler. """ __slots__ = ( @@ -169,12 +99,81 @@ class Response: '__dict__', ) - complete = False + _cookies: Optional[http_cookies.SimpleCookie] + _data: Optional[bytes] + _extra_headers: Optional[List[Tuple[str, str]]] + _headers: Headers + _media: Optional[Any] + _media_rendered: MissingOr[bytes] # Child classes may override this - context_type = structures.Context + context_type: ClassVar[Type[structures.Context]] = structures.Context + """Class variable that determines the factory or + type to use for initializing the `context` attribute. By default, + the framework will instantiate bare objects (instances of the bare + :class:`falcon.Context` class). However, you may override this + behavior by creating a custom child class of + :class:`falcon.Response`, and then passing that new class to + ``falcon.App()`` by way of the latter's `response_type` parameter. + + Note: + When overriding `context_type` with a factory function (as + opposed to a class), the function is called like a method of + the current Response instance. Therefore the first argument is + the Response instance itself (self). + """ + + # Attribute declaration + complete: bool = False + """Set to ``True`` from within a middleware method to signal to the framework that + request processing should be short-circuited (see also + :ref:`Middleware <middleware>`). + """ + status: Union[str, int, http.HTTPStatus] + """HTTP status code or line (e.g., ``'200 OK'``). + + This may be set to a member of :class:`http.HTTPStatus`, an HTTP status line + string (e.g., ``'200 OK'``), or an ``int``. + + Note: + The Falcon framework itself provides a number of constants for + common status codes. They all start with the ``HTTP_`` prefix, + as in: ``falcon.HTTP_204``.
(See also: :ref:`status`.) + """ + text: Optional[str] + """String representing response content. + + Note: + Falcon will encode the given text as UTF-8 in the response. If the content + is already a byte string, use the :attr:`data` attribute instead (it's faster). + """ + stream: Union[ReadableIO, Iterable[bytes], None] + """Either a file-like object with a `read()` method that takes an optional size + argument and returns a block of bytes, or an iterable object, representing response + content, and yielding blocks as byte strings. Falcon will use *wsgi.file_wrapper*, + if provided by the WSGI server, in order to efficiently serve file-like objects. + + Note: + If the stream is set to an iterable object that requires + resource cleanup, it can implement a close() method to do so. + The close() method will be called upon completion of the request. + """ + context: structures.Context + """Empty object to hold any data (in its attributes) about the response which is + specific to your app (e.g. session object). + Falcon itself will not interact with this attribute after it has been initialized. + + Note: + The preferred way to pass response-specific data, when using the + default context type, is to set attributes directly on the + `context` object. For example:: + + resp.context.cache_strategy = 'lru' + """ + options: ResponseOptions + """Set of global options passed in from the App handler.""" - def __init__(self, options=None): + def __init__(self, options: Optional[ResponseOptions] = None) -> None: self.status = '200 OK' self._headers = {} @@ -196,66 +195,98 @@ def __init__(self, options=None): self.stream = None self._data = None self._media = None - self._media_rendered = _UNSET + self._media_rendered = MISSING self.context = self.context_type() @property def status_code(self) -> int: + """HTTP status code normalized from :attr:`status`. + + When a code is assigned to this property, :attr:`status` is updated, + and vice-versa. 
The status code can be useful when needing to check + in middleware for codes that fall into a certain class, e.g.:: + + if resp.status_code >= 400: + log.warning(f'returning error response: {resp.status_code}') + """ return http_status_to_code(self.status) @status_code.setter - def status_code(self, value): + def status_code(self, value: int) -> None: self.status = value @property - def body(self): + def body(self) -> NoReturn: raise AttributeRemovedError( 'The body attribute is no longer supported. ' 'Please use the text attribute instead.' ) @body.setter - def body(self, value): + def body(self, value: Any) -> NoReturn: raise AttributeRemovedError( 'The body attribute is no longer supported. ' 'Please use the text attribute instead.' ) @property - def data(self): + def data(self) -> Optional[bytes]: + """Byte string representing response content. + + Use this attribute in lieu of `text` when your content is + already a byte string (of type ``bytes``). See also the note below. + + Warning: + Always use the `text` attribute for text, or encode it + first to ``bytes`` when using the `data` attribute, to + ensure Unicode characters are properly encoded in the + HTTP response. + """ return self._data @data.setter - def data(self, value): + def data(self, value: Optional[bytes]) -> None: self._data = value @property - def headers(self): + def headers(self) -> Headers: + """Copy of all headers set for the response, without cookies. + + Note that a new copy is created and returned each time this property is + referenced. + """ return self._headers.copy() @property - def media(self): + def media(self) -> Any: + """A serializable object supported by the media handlers configured via + :class:`falcon.RequestOptions`. + + Note: + See also :ref:`media` for more information regarding media + handling. 
+ """ return self._media @media.setter - def media(self, value): + def media(self, value: Any) -> None: self._media = value - self._media_rendered = _UNSET + self._media_rendered = MISSING @property - def stream_len(self): + def stream_len(self) -> NoReturn: # NOTE(kgriffs): Provide some additional information by raising the # error explicitly. raise AttributeError(_STREAM_LEN_REMOVED_MSG) @stream_len.setter - def stream_len(self, value): + def stream_len(self, value: Any) -> NoReturn: # NOTE(kgriffs): We explicitly disallow setting the deprecated attribute # so that apps relying on it do not fail silently. raise AttributeError(_STREAM_LEN_REMOVED_MSG) - def render_body(self): + def render_body(self) -> Optional[bytes]: """Get the raw bytestring content for the response body. This method returns the raw data for the HTTP response body, taking @@ -272,15 +303,15 @@ def render_body(self): finally the serialized value of the `media` attribute. If none of these attributes are set, ``None`` is returned. """ - + data: Optional[bytes] text = self.text if text is None: data = self._data if data is None and self._media is not None: - # NOTE(kgriffs): We use a special _UNSET singleton since + # NOTE(kgriffs): We use a special MISSING singleton since # None is ambiguous (the media handler might return None). 
- if self._media_rendered is _UNSET: + if self._media_rendered is MISSING: if not self.content_type: self.content_type = self.options.default_media_type @@ -299,14 +330,16 @@ def render_body(self): data = text.encode() except AttributeError: # NOTE(kgriffs): Assume it was a bytes object already - data = text + data = text # type: ignore[assignment] return data - def __repr__(self): - return '<%s: %s>' % (self.__class__.__name__, self.status) + def __repr__(self) -> str: + return f'<{self.__class__.__name__}: {self.status}>' - def set_stream(self, stream, content_length): + def set_stream( + self, stream: Union[ReadableIO, Iterable[bytes]], content_length: int + ) -> None: """Set both `stream` and `content_length`. Although the :attr:`~falcon.Response.stream` and @@ -338,17 +371,17 @@ def set_stream(self, stream, content_length): def set_cookie( # noqa: C901 self, - name, - value, - expires=None, - max_age=None, - domain=None, - path=None, - secure=None, - http_only=True, - same_site=None, - partitioned=False, - ): + name: str, + value: str, + expires: Optional[datetime] = None, + max_age: Optional[int] = None, + domain: Optional[str] = None, + path: Optional[str] = None, + secure: Optional[bool] = None, + http_only: bool = True, + same_site: Optional[str] = None, + partitioned: bool = False, + ) -> None: """Set a response cookie. Note: @@ -543,7 +576,13 @@ def set_cookie( # noqa: C901 if partitioned: self._cookies[name]['partitioned'] = True - def unset_cookie(self, name, samesite='Lax', domain=None, path=None): + def unset_cookie( + self, + name: str, + samesite: str = 'Lax', + domain: Optional[str] = None, + path: Optional[str] = None, + ) -> None: """Unset a cookie in the response. 
Clears the contents of the cookie, and instructs the user @@ -619,7 +658,13 @@ def unset_cookie(self, name, samesite='Lax', domain=None, path=None): if path: self._cookies[name]['path'] = path - def get_header(self, name, default=None): + @overload + def get_header(self, name: str, default: str) -> str: ... + + @overload + def get_header(self, name: str, default: Optional[str] = ...) -> Optional[str]: ... + + def get_header(self, name: str, default: Optional[str] = None) -> Optional[str]: """Retrieve the raw string value for the given header. Normally, when a header has multiple values, they will be @@ -651,7 +696,7 @@ def get_header(self, name, default=None): return self._headers.get(name, default) - def set_header(self, name, value): + def set_header(self, name: str, value: str) -> None: """Set a header for this response to a given value. Warning: @@ -687,7 +732,7 @@ def set_header(self, name, value): self._headers[name] = value - def delete_header(self, name): + def delete_header(self, name: str) -> None: """Delete a header that was previously set for this response. If the header was not previously set, nothing is done (no error is @@ -721,7 +766,7 @@ def delete_header(self, name): self._headers.pop(name, None) - def append_header(self, name, value): + def append_header(self, name: str, value: str) -> None: """Set or append a header for this response. If the header already exists, the new value will normally be appended @@ -761,7 +806,9 @@ def append_header(self, name, value): self._headers[name] = value - def set_headers(self, headers): + def set_headers( + self, headers: Union[Mapping[str, str], Iterable[Tuple[str, str]]] + ) -> None: """Set several headers at once. This method can be used to set a collection of raw header names and @@ -802,7 +849,7 @@ def set_headers(self, headers): # normalize the header names. 
_headers = self._headers - for name, value in headers: + for name, value in headers: # type: ignore[misc] # NOTE(kgriffs): uwsgi fails with a TypeError if any header # is not a str, so do the conversion here. It's actually # faster to not do an isinstance check. str() will encode @@ -817,16 +864,16 @@ def set_headers(self, headers): def append_link( self, - target, - rel, - title=None, - title_star=None, - anchor=None, - hreflang=None, - type_hint=None, - crossorigin=None, - link_extension=None, - ): + target: str, + rel: str, + title: Optional[str] = None, + title_star: Optional[Tuple[str, str]] = None, + anchor: Optional[str] = None, + hreflang: Optional[Union[str, Iterable[str]]] = None, + type_hint: Optional[str] = None, + crossorigin: Optional[str] = None, + link_extension: Optional[Iterable[Tuple[str, str]]] = None, + ) -> None: """Append a link header to the response. (See also: RFC 5988, Section 1) @@ -851,7 +898,7 @@ def append_link( characters, you will need to use `title_star` instead, or provide both a US-ASCII version using `title` and a Unicode version using `title_star`. - title_star (tuple of str): Localized title describing the + title_star (tuple[str, str]): Localized title describing the destination of the link (default ``None``). 
The value must be a two-member tuple in the form of (*language-tag*, *text*), where *language-tag* is a standard language identifier as @@ -908,40 +955,34 @@ def append_link( if ' ' in rel: rel = '"' + ' '.join([uri_encode(r) for r in rel.split()]) + '"' else: - rel = '"' + uri_encode(rel) + '"' + rel = f'"{uri_encode(rel)}"' value = '<' + uri_encode(target) + '>; rel=' + rel if title is not None: - value += '; title="' + title + '"' + value += f'; title="{title}"' if title_star is not None: - value += ( - "; title*=UTF-8'" - + title_star[0] - + "'" - + uri_encode_value(title_star[1]) - ) + value += f"; title*=UTF-8'{title_star[0]}'{uri_encode_value(title_star[1])}" if type_hint is not None: - value += '; type="' + type_hint + '"' + value += f'; type="{type_hint}"' if hreflang is not None: if isinstance(hreflang, str): - value += '; hreflang=' + hreflang + value += f'; hreflang={hreflang}' else: value += '; ' value += '; '.join(['hreflang=' + lang for lang in hreflang]) if anchor is not None: - value += '; anchor="' + uri_encode(anchor) + '"' + value += f'; anchor="{uri_encode(anchor)}"' if crossorigin is not None: crossorigin = crossorigin.lower() if crossorigin not in _RESERVED_CROSSORIGIN_VALUES: raise ValueError( - 'crossorigin must be set to either ' - "'anonymous' or 'use-credentials'" + "crossorigin must be set to either 'anonymous' or 'use-credentials'" ) if crossorigin == 'anonymous': value += '; crossorigin' @@ -952,11 +993,11 @@ def append_link( if link_extension is not None: value += '; ' - value += '; '.join([p + '=' + v for p, v in link_extension]) + value += '; '.join([f'{p}={v}' for p, v in link_extension]) _headers = self._headers if 'link' in _headers: - _headers['link'] += ', ' + value + _headers['link'] += f', {value}' else: _headers['link'] = value @@ -965,19 +1006,24 @@ def append_link( append_link ) - cache_control = header_property( + cache_control: Union[str, Iterable[str], None] = header_property( 'Cache-Control', """Set the Cache-Control 
header. Used to set a list of cache directives to use as the value of the Cache-Control header. The list will be joined with ", " to produce the value for the header. - """, format_header_value_list, ) + """Set the Cache-Control header. - content_location = header_property( + Used to set a list of cache directives to use as the value of the + Cache-Control header. The list will be joined with ", " to produce + the value for the header. + """ + + content_location: Optional[str] = header_property( 'Content-Location', """Set the Content-Location header. @@ -987,8 +1033,14 @@ def append_link( """, uri_encode, ) + """Set the Content-Location header. - content_length = header_property( + This value will be URI encoded per RFC 3986. If the value that is + being set is already URI encoded it should be decoded first or the + header should be set manually using the set_header method. + """ + + content_length: Union[str, int, None] = header_property( 'Content-Length', """Set the Content-Length header. @@ -1008,8 +1060,25 @@ def append_link( """, ) + """Set the Content-Length header. + + This property can be used for responding to HEAD requests when you + aren't actually providing the response body, or when streaming the + response. If either the `text` property or the `data` property is set + on the response, the framework will force Content-Length to be the + length of the given text bytes. Therefore, it is only necessary to + manually set the content length when those properties are not used. + + Note: + In cases where the response content is a stream (readable + file-like object), Falcon will not supply a Content-Length header + to the server unless `content_length` is explicitly set. + Consequently, the server may choose to use chunked encoding in this + case. + + """ - content_range = header_property( + content_range: Union[str, RangeSetHeader, None] = header_property( 'Content-Range', """A tuple to use in constructing a value for the Content-Range header. 
@@ -1029,8 +1098,24 @@ def append_link( """, format_range, ) + """A tuple to use in constructing a value for the Content-Range header. + + The tuple has the form (*start*, *end*, *length*, [*unit*]), where *start* and + *end* designate the range (inclusive), and *length* is the + total length, or '\\*' if unknown. You may pass ``int``'s for + these numbers (no need to convert to ``str`` beforehand). The optional value + *unit* describes the range unit and defaults to 'bytes' + + Note: + You only need to use the alternate form, 'bytes \\*/1234', for + responses that use the status '416 Range Not Satisfiable'. In this + case, raising ``falcon.HTTPRangeNotSatisfiable`` will do the right + thing. + + (See also: RFC 7233, Section 4.2) + """ - content_type = header_property( + content_type: Optional[str] = header_property( 'Content-Type', """Sets the Content-Type header. @@ -1043,8 +1128,18 @@ def append_link( and ``falcon.MEDIA_GIF``. """, ) + """Sets the Content-Type header. + + The ``falcon`` module provides a number of constants for + common media types, including ``falcon.MEDIA_JSON``, + ``falcon.MEDIA_MSGPACK``, ``falcon.MEDIA_YAML``, + ``falcon.MEDIA_XML``, ``falcon.MEDIA_HTML``, + ``falcon.MEDIA_JS``, ``falcon.MEDIA_TEXT``, + ``falcon.MEDIA_JPEG``, ``falcon.MEDIA_PNG``, + and ``falcon.MEDIA_GIF``. + """ - downloadable_as = header_property( + downloadable_as: Optional[str] = header_property( 'Content-Disposition', """Set the Content-Disposition header using the given filename. @@ -1059,8 +1154,19 @@ def append_link( """, functools.partial(format_content_disposition, disposition_type='attachment'), ) + """Set the Content-Disposition header using the given filename. + + The value will be used for the ``filename`` directive. For example, + given ``'report.pdf'``, the Content-Disposition header would be set + to: ``'attachment; filename="report.pdf"'``. 
- viewable_as = header_property( + As per `RFC 6266 `_ + recommendations, non-ASCII filenames will be encoded using the + ``filename*`` directive, whereas ``filename`` will contain the US + ASCII fallback. + """ + + viewable_as: Optional[str] = header_property( 'Content-Disposition', """Set an inline Content-Disposition header using the given filename. @@ -1077,8 +1183,21 @@ def append_link( """, functools.partial(format_content_disposition, disposition_type='inline'), ) + """Set an inline Content-Disposition header using the given filename. + + The value will be used for the ``filename`` directive. For example, + given ``'report.pdf'``, the Content-Disposition header would be set + to: ``'inline; filename="report.pdf"'``. - etag = header_property( + As per `RFC 6266 `_ + recommendations, non-ASCII filenames will be encoded using the + ``filename*`` directive, whereas ``filename`` will contain the US + ASCII fallback. + + .. versionadded:: 3.1 + """ + + etag: Optional[str] = header_property( 'ETag', """Set the ETag header. @@ -1087,8 +1206,13 @@ def append_link( """, format_etag_header, ) + """Set the ETag header. + + The ETag header will be wrapped with double quotes ``"value"`` in case + the user didn't pass it. + """ - expires = header_property( + expires: Union[str, datetime, None] = header_property( 'Expires', """Set the Expires header. Set to a ``datetime`` (UTC) instance. @@ -1097,8 +1221,13 @@ def append_link( """, dt_to_http, ) + """Set the Expires header. Set to a ``datetime`` (UTC) instance. - last_modified = header_property( + Note: + Falcon will format the ``datetime`` as an HTTP date string. + """ + + last_modified: Union[str, datetime, None] = header_property( 'Last-Modified', """Set the Last-Modified header. Set to a ``datetime`` (UTC) instance. @@ -1107,8 +1236,13 @@ def append_link( """, dt_to_http, ) + """Set the Last-Modified header. Set to a ``datetime`` (UTC) instance. + + Note: + Falcon will format the ``datetime`` as an HTTP date string. 
+ """ - location = header_property( + location: Optional[str] = header_property( 'Location', """Set the Location header. @@ -1118,18 +1252,28 @@ def append_link( """, uri_encode, ) + """Set the Location header. + + This value will be URI encoded per RFC 3986. If the value that is + being set is already URI encoded it should be decoded first or the + header should be set manually using the set_header method. + """ - retry_after = header_property( + retry_after: Union[int, str, None] = header_property( 'Retry-After', """Set the Retry-After header. The expected value is an integral number of seconds to use as the value for the header. The HTTP-date syntax is not supported. """, - str, ) + """Set the Retry-After header. + + The expected value is an integral number of seconds to use as the + value for the header. The HTTP-date syntax is not supported. + """ - vary = header_property( + vary: Union[str, Iterable[str], None] = header_property( 'Vary', """Value to use for the Vary header. @@ -1148,8 +1292,23 @@ def append_link( """, format_header_value_list, ) + """Value to use for the Vary header. - accept_ranges = header_property( + Set this property to an iterable of header names. For a single + asterisk or field value, simply pass a single-element ``list`` + or ``tuple``. + + The "Vary" header field in a response describes what parts of + a request message, aside from the method, Host header field, + and request target, might influence the origin server's + process for selecting and representing this response. The + value consists of either a single asterisk ("*") or a list of + header field names (case-insensitive). + + (See also: RFC 7231, Section 7.1.4) + """ + + accept_ranges: Optional[str] = header_property( 'Accept-Ranges', """Set the Accept-Ranges header. @@ -1167,8 +1326,23 @@ def append_link( """, ) + """Set the Accept-Ranges header. + + The Accept-Ranges header field indicates to the client which + range units are supported (e.g. 
"bytes") for the target + resource. + + If range requests are not supported for the target resource, + the header may be set to "none" to advise the client not to + attempt any such requests. + + Note: + "none" is the literal string, not Python's built-in ``None`` + type. + + """ - def _set_media_type(self, media_type=None): + def _set_media_type(self, media_type: Optional[str] = None) -> None: """Set a content-type; wrapper around set_header. Args: @@ -1183,7 +1357,7 @@ def _set_media_type(self, media_type=None): if media_type is not None and 'content-type' not in self._headers: self._headers['content-type'] = media_type - def _wsgi_headers(self, media_type=None): + def _wsgi_headers(self, media_type: Optional[str] = None) -> list[tuple[str, str]]: """Convert headers into the format expected by WSGI servers. Args: @@ -1260,7 +1434,7 @@ class ResponseOptions: 'static_media_types', ) - def __init__(self): + def __init__(self) -> None: self.secure_cookies_by_default = True self.default_media_type = DEFAULT_MEDIA_TYPE self.media_handlers = Handlers() diff --git a/falcon/response_helpers.py b/falcon/response_helpers.py index 2e59ba78f..8e1d90207 100644 --- a/falcon/response_helpers.py +++ b/falcon/response_helpers.py @@ -14,11 +14,21 @@ """Utilities for the Response class.""" +from __future__ import annotations + +from typing import Any, Callable, Iterable, Optional, TYPE_CHECKING + +from falcon.typing import RangeSetHeader from falcon.util import uri from falcon.util.misc import secure_filename +if TYPE_CHECKING: + from falcon import Response + -def header_property(name, doc, transform=None): +def header_property( + name: str, doc: str, transform: Optional[Callable[[Any], str]] = None +) -> Any: """Create a header getter/setter. 
Args: @@ -32,7 +42,7 @@ def header_property(name, doc, transform=None): """ normalized_name = name.lower() - def fget(self): + def fget(self: Response) -> Optional[str]: try: return self._headers[normalized_name] except KeyError: @@ -40,7 +50,7 @@ def fget(self): if transform is None: - def fset(self, value): + def fset(self: Response, value: Optional[Any]) -> None: if value is None: try: del self._headers[normalized_name] @@ -51,7 +61,7 @@ def fset(self, value): else: - def fset(self, value): + def fset(self: Response, value: Optional[Any]) -> None: if value is None: try: del self._headers[normalized_name] @@ -60,31 +70,27 @@ def fset(self, value): else: self._headers[normalized_name] = transform(value) - def fdel(self): + def fdel(self: Response) -> None: del self._headers[normalized_name] return property(fget, fset, fdel, doc) -def format_range(value): +def format_range(value: RangeSetHeader) -> str: """Format a range header tuple per the HTTP spec. Args: value: ``tuple`` passed to `req.range` """ - - # PERF(kgriffs): % was found to be faster than str.format(), - # string concatenation, and str.join() in this case. - if len(value) == 4: - result = '%s %s-%s/%s' % (value[3], value[0], value[1], value[2]) + result = f'{value[3]} {value[0]}-{value[1]}/{value[2]}' else: - result = 'bytes %s-%s/%s' % (value[0], value[1], value[2]) + result = f'bytes {value[0]}-{value[1]}/{value[2]}' return result -def format_content_disposition(value, disposition_type='attachment'): +def format_content_disposition(value: str, disposition_type: str = 'attachment') -> str: """Format a Content-Disposition header given a filename.""" # NOTE(vytas): RFC 6266, Appendix D. 
@@ -111,7 +117,7 @@ def format_content_disposition(value, disposition_type='attachment'): ) -def format_etag_header(value): +def format_etag_header(value: str) -> str: """Format an ETag header, wrap it with " " in case of need.""" if value[-1] != '"': @@ -120,12 +126,12 @@ def format_etag_header(value): return value -def format_header_value_list(iterable): +def format_header_value_list(iterable: Iterable[str]) -> str: """Join an iterable of strings with commas.""" return ', '.join(iterable) -def is_ascii_encodable(s): +def is_ascii_encodable(s: str) -> bool: """Check if argument encodes to ascii without error.""" try: s.encode('ascii') diff --git a/falcon/routing/static.py b/falcon/routing/static.py index d07af4211..76ef4ed14 100644 --- a/falcon/routing/static.py +++ b/falcon/routing/static.py @@ -249,7 +249,7 @@ async def __call__(self, req: asgi.Request, resp: asgi.Response, **kw: Any) -> N super().__call__(req, resp, **kw) # NOTE(kgriffs): Fixup resp.stream so that it is non-blocking - resp.stream = _AsyncFileReader(resp.stream) + resp.stream = _AsyncFileReader(resp.stream) # type: ignore[assignment,arg-type] class _AsyncFileReader: @@ -259,8 +259,8 @@ def __init__(self, file: IO[bytes]) -> None: self._file = file self._loop = asyncio.get_running_loop() - async def read(self, size=-1): + async def read(self, size: int = -1) -> bytes: return await self._loop.run_in_executor(None, partial(self._file.read, size)) - async def close(self): + async def close(self) -> None: await self._loop.run_in_executor(None, self._file.close) diff --git a/falcon/typing.py b/falcon/typing.py index e83329a84..8e8a70264 100644 --- a/falcon/typing.py +++ b/falcon/typing.py @@ -46,6 +46,7 @@ if TYPE_CHECKING: from falcon.asgi import Request as AsgiRequest from falcon.asgi import Response as AsgiResponse + from falcon.asgi import SSEvent from falcon.asgi import WebSocket from falcon.asgi_spec import AsgiEvent from falcon.asgi_spec import AsgiSendMsg @@ -118,6 +119,7 @@ async def 
__call__( ResponseStatus = Union[http.HTTPStatus, str, int] StoreArgument = Optional[Dict[str, Any]] Resource = object +RangeSetHeader = Union[Tuple[int, int, int], Tuple[int, int, int, str]] class ResponderMethod(Protocol): @@ -175,6 +177,11 @@ async def __call__( AsgiProcessResourceWsMethod = Callable[ ['AsgiRequest', 'WebSocket', Resource, Dict[str, Any]], Awaitable[None] ] +SseEmitter = AsyncIterator[Optional['SSEvent']] +ResponseCallbacks = Union[ + Tuple[Callable[[], None], Literal[False]], + Tuple[Callable[[], Awaitable[None]], Literal[True]], +] class AsgiResponderCallable(Protocol): diff --git a/pyproject.toml b/pyproject.toml index ef17db4f0..0e1186646 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,11 +42,7 @@ [[tool.mypy.overrides]] module = [ - "falcon.asgi.response", "falcon.media.validators.*", - "falcon.responders", - "falcon.response_helpers", - "falcon.response", "falcon.routing.*", "falcon.routing.converters", "falcon.testing.*", diff --git a/tests/test_cors_middleware.py b/tests/test_cors_middleware.py index 244d22398..7a978d592 100644 --- a/tests/test_cors_middleware.py +++ b/tests/test_cors_middleware.py @@ -28,6 +28,12 @@ def on_delete(self, req, resp): resp.text = "I'm a CORS test response" +class CORSOptionsResource: + def on_options(self, req, resp): + # No allow header set + resp.set_header('Content-Length', '0') + + class TestCorsMiddleware: def test_disabled_cors_should_not_add_any_extra_headers(self, client): client.app.add_route('/', CORSHeaderResource()) @@ -80,6 +86,22 @@ def test_enabled_cors_handles_preflighting(self, cors_client): result.headers['Access-Control-Max-Age'] == '86400' ) # 24 hours in seconds + def test_enabled_cors_handles_preflighting_custom_option(self, cors_client): + cors_client.app.add_route('/', CORSOptionsResource()) + result = cors_client.simulate_options( + headers=( + ('Origin', 'localhost'), + ('Access-Control-Request-Method', 'GET'), + ('Access-Control-Request-Headers', 'X-PINGOTHER, 
Content-Type'), + ) + ) + assert 'Access-Control-Allow-Methods' not in result.headers + assert ( + result.headers['Access-Control-Allow-Headers'] + == 'X-PINGOTHER, Content-Type' + ) + assert result.headers['Access-Control-Max-Age'] == '86400' + def test_enabled_cors_handles_preflighting_no_headers_in_req(self, cors_client): cors_client.app.add_route('/', CORSHeaderResource()) result = cors_client.simulate_options( diff --git a/tests/test_headers.py b/tests/test_headers.py index 67a80e147..f7ba41d72 100644 --- a/tests/test_headers.py +++ b/tests/test_headers.py @@ -56,7 +56,8 @@ def on_get(self, req, resp): resp.last_modified = self.last_modified resp.retry_after = 3601 - # Relative URI's are OK per http://goo.gl/DbVqR + # Relative URI's are OK per + # https://datatracker.ietf.org/doc/html/rfc7231#section-7.1.2 resp.location = '/things/87' resp.content_location = '/things/78' From ed2fa8782d3432947733ce2e02852f6e162d7036 Mon Sep 17 00:00:00 2001 From: Vytautas Liuolia Date: Fri, 30 Aug 2024 20:57:58 +0200 Subject: [PATCH 07/12] style: fix spelling in multipart.py --- falcon/media/multipart.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/falcon/media/multipart.py b/falcon/media/multipart.py index 19add56e8..17a990265 100644 --- a/falcon/media/multipart.py +++ b/falcon/media/multipart.py @@ -71,7 +71,7 @@ # TODO(vytas): Consider supporting -charset- stuff. # Does anyone use that (?) class BodyPart: - """Represents a body part in a multipart form in a ASGI application. + """Represents a body part in a multipart form in an ASGI application. 
Note: :class:`BodyPart` is meant to be instantiated directly only by the From 719626abca78160ab357dcd2a1f21b8314f25aeb Mon Sep 17 00:00:00 2001 From: Vytautas Liuolia Date: Fri, 30 Aug 2024 21:05:13 +0200 Subject: [PATCH 08/12] style(tests): explain referencing the same property multiple times --- tests/test_media_multipart.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_media_multipart.py b/tests/test_media_multipart.py index 78e87abca..31c7bbff8 100644 --- a/tests/test_media_multipart.py +++ b/tests/test_media_multipart.py @@ -289,9 +289,13 @@ def test_body_part_properties(): for part in form: if part.content_type == 'application/json': + # NOTE(vytas): This is not a typo, but a test that the name + # property can be safely referenced multiple times. assert part.name == part.name == 'document' elif part.name == 'file1': - assert part.filename == 'test.txt' + # NOTE(vytas): This is not a typo, but a test that the filename + # property can be safely referenced multiple times. + assert part.filename == part.filename == 'test.txt' assert part.secure_filename == part.filename From df08ffd85ad230d6148cea81c9f346381e6010a5 Mon Sep 17 00:00:00 2001 From: Federico Caselli Date: Fri, 30 Aug 2024 21:45:32 +0200 Subject: [PATCH 09/12] style: fix linter errors --- falcon/asgi/response.py | 2 +- falcon/response.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/falcon/asgi/response.py b/falcon/asgi/response.py index 204410b47..73fcfbfbf 100644 --- a/falcon/asgi/response.py +++ b/falcon/asgi/response.py @@ -131,7 +131,7 @@ async def emitter(): When hosting an app that emits Server-Sent Events, the web server should be set with a relatively long keep-alive TTL to minimize the overhead of connection renegotiations. 
- """ + """ # noqa: D400 D205 return self._sse @sse.setter diff --git a/falcon/response.py b/falcon/response.py index 03f15d901..d79b57655 100644 --- a/falcon/response.py +++ b/falcon/response.py @@ -266,7 +266,7 @@ def media(self) -> Any: Note: See also :ref:`media` for more information regarding media handling. - """ + """ # noqa D205 return self._media @media.setter From 540927f19e9de7fed55e8c0af2fff5da07cbd0dc Mon Sep 17 00:00:00 2001 From: Federico Caselli Date: Tue, 10 Sep 2024 19:10:59 +0200 Subject: [PATCH 10/12] chore: revert behavioral change to cors middleware. --- falcon/middleware.py | 6 +----- tests/test_cors_middleware.py | 1 + 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/falcon/middleware.py b/falcon/middleware.py index 3ea81a81d..0e87275ed 100644 --- a/falcon/middleware.py +++ b/falcon/middleware.py @@ -120,11 +120,7 @@ def process_response( 'Access-Control-Request-Headers', default='*' ) - if allow is not None: - # NOTE: not sure if it's more appropriate to raise an exception here - # This can happen only when a responder class defines a custom - # on_option responder method and does not set the 'Allow' header. 
- resp.set_header('Access-Control-Allow-Methods', allow) + resp.set_header('Access-Control-Allow-Methods', str(allow)) resp.set_header('Access-Control-Allow-Headers', allow_headers) resp.set_header('Access-Control-Max-Age', '86400') # 24 hours diff --git a/tests/test_cors_middleware.py b/tests/test_cors_middleware.py index 7a978d592..4594242be 100644 --- a/tests/test_cors_middleware.py +++ b/tests/test_cors_middleware.py @@ -86,6 +86,7 @@ def test_enabled_cors_handles_preflighting(self, cors_client): result.headers['Access-Control-Max-Age'] == '86400' ) # 24 hours in seconds + @pytest.mark.xfail(reason='will be fixed in 2325') def test_enabled_cors_handles_preflighting_custom_option(self, cors_client): cors_client.app.add_route('/', CORSOptionsResource()) result = cors_client.simulate_options( From 42bedc6292c713a8f6233ede5aad0d1ec27e3b53 Mon Sep 17 00:00:00 2001 From: Federico Caselli Date: Tue, 10 Sep 2024 20:13:08 +0200 Subject: [PATCH 11/12] feat: improve cors header to properly handle missing allow headers The static resource will now properly support CORS requests. Fixes: #2325 --- docs/_newsfragments/2325.newandimproved.rst | 4 ++ falcon/app.py | 9 ++++ falcon/middleware.py | 23 ++++++-- falcon/routing/static.py | 12 +++-- tests/test_cors_middleware.py | 58 ++++++++++++++++++--- tests/test_static.py | 16 ++++++ 6 files changed, 109 insertions(+), 13 deletions(-) create mode 100644 docs/_newsfragments/2325.newandimproved.rst diff --git a/docs/_newsfragments/2325.newandimproved.rst b/docs/_newsfragments/2325.newandimproved.rst new file mode 100644 index 000000000..62a61b81e --- /dev/null +++ b/docs/_newsfragments/2325.newandimproved.rst @@ -0,0 +1,4 @@ +The :class:`~CORSMiddleware` now properly handles the missing ``Allow`` +header case, by denying the preflight CORS request. +The static resource has been updated to properly support CORS requests, +by allowing GET requests.
diff --git a/falcon/app.py b/falcon/app.py index b66247051..6ef6cda25 100644 --- a/falcon/app.py +++ b/falcon/app.py @@ -759,6 +759,15 @@ def add_sink(self, sink: SinkCallable, prefix: SinkPrefix = r'/') -> None: impractical. For example, you might use a sink to create a smart proxy that forwards requests to one or more backend services. + Note: + To support CORS preflight requests when using the default CORS middleware, + either by setting ``App.cors_enable=True`` or by adding the + :class:`~.CORSMiddleware` to the ``App.middleware``, the sink should + set the ``Allow`` header in the response to the allowed + method values when serving an ``OPTIONS`` request. If the ``Allow`` header + is missing from the response, the default CORS middleware will deny the + preflight request. + Args: sink (callable): A callable taking the form ``func(req, resp, **kwargs)``. diff --git a/falcon/middleware.py b/falcon/middleware.py index 0e87275ed..d457a44b8 100644 --- a/falcon/middleware.py +++ b/falcon/middleware.py @@ -17,6 +17,15 @@ class CORSMiddleware(object): * https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS * https://www.w3.org/TR/cors/#resource-processing-model + Note: + Falcon will automatically add OPTIONS responders if they are missing from the + responder instances added to the routes. When providing a custom ``on_options`` + method, the ``Allow`` header in the response should be set to the allowed + method values. If the ``Allow`` header is missing from the response, + this middleware will deny the preflight request. + + This is also valid when using a sink function. + Keyword Arguments: allow_origins (Union[str, Iterable[str]]): List of origins to allow (case sensitive). The string ``'*'`` acts as a wildcard, matching every origin.
@@ -120,9 +129,17 @@ def process_response( 'Access-Control-Request-Headers', default='*' ) - resp.set_header('Access-Control-Allow-Methods', str(allow)) - resp.set_header('Access-Control-Allow-Headers', allow_headers) - resp.set_header('Access-Control-Max-Age', '86400') # 24 hours + if allow is None: + # there is no allow set, remove all access control headers + resp.delete_header('Access-Control-Allow-Methods') + resp.delete_header('Access-Control-Allow-Headers') + resp.delete_header('Access-Control-Max-Age') + resp.delete_header('Access-Control-Expose-Headers') + resp.delete_header('Access-Control-Allow-Origin') + else: + resp.set_header('Access-Control-Allow-Methods', allow) + resp.set_header('Access-Control-Allow-Headers', allow_headers) + resp.set_header('Access-Control-Max-Age', '86400') # 24 hours async def process_response_async(self, *args: Any) -> None: self.process_response(*args) diff --git a/falcon/routing/static.py b/falcon/routing/static.py index 76ef4ed14..cc2a5ab92 100644 --- a/falcon/routing/static.py +++ b/falcon/routing/static.py @@ -183,6 +183,12 @@ def match(self, path: str) -> bool: def __call__(self, req: Request, resp: Response, **kw: Any) -> None: """Resource responder for this route.""" assert not kw + if req.method == 'OPTIONS': + # it's likely a CORS request. Set the allow header to the appropriate value. 
+ resp.set_header('Allow', 'GET') + resp.set_header('Content-Length', '0') + return + without_prefix = req.path[len(self._prefix) :] # NOTE(kgriffs): Check surrounding whitespace and strip trailing @@ -247,9 +253,9 @@ class StaticRouteAsync(StaticRoute): async def __call__(self, req: asgi.Request, resp: asgi.Response, **kw: Any) -> None: # type: ignore[override] super().__call__(req, resp, **kw) - - # NOTE(kgriffs): Fixup resp.stream so that it is non-blocking - resp.stream = _AsyncFileReader(resp.stream) # type: ignore[assignment,arg-type] + if resp.stream is not None: # None when in an option request + # NOTE(kgriffs): Fixup resp.stream so that it is non-blocking + resp.stream = _AsyncFileReader(resp.stream) # type: ignore[assignment,arg-type] class _AsyncFileReader: diff --git a/tests/test_cors_middleware.py b/tests/test_cors_middleware.py index 4594242be..9aff6abf6 100644 --- a/tests/test_cors_middleware.py +++ b/tests/test_cors_middleware.py @@ -1,3 +1,5 @@ +from pathlib import Path + import pytest import falcon @@ -86,8 +88,7 @@ def test_enabled_cors_handles_preflighting(self, cors_client): result.headers['Access-Control-Max-Age'] == '86400' ) # 24 hours in seconds - @pytest.mark.xfail(reason='will be fixed in 2325') - def test_enabled_cors_handles_preflighting_custom_option(self, cors_client): + def test_enabled_cors_handles_preflight_custom_option(self, cors_client): cors_client.app.add_route('/', CORSOptionsResource()) result = cors_client.simulate_options( headers=( @@ -97,11 +98,10 @@ def test_enabled_cors_handles_preflighting_custom_option(self, cors_client): ) ) assert 'Access-Control-Allow-Methods' not in result.headers - assert ( - result.headers['Access-Control-Allow-Headers'] - == 'X-PINGOTHER, Content-Type' - ) - assert result.headers['Access-Control-Max-Age'] == '86400' + assert 'Access-Control-Allow-Headers' not in result.headers + assert 'Access-Control-Max-Age' not in result.headers + assert 'Access-Control-Expose-Headers' not in 
result.headers + assert 'Access-Control-Allow-Origin' not in result.headers def test_enabled_cors_handles_preflighting_no_headers_in_req(self, cors_client): cors_client.app.add_route('/', CORSHeaderResource()) @@ -117,6 +117,50 @@ def test_enabled_cors_handles_preflighting_no_headers_in_req(self, cors_client): result.headers['Access-Control-Max-Age'] == '86400' ) # 24 hours in seconds + def test_enabled_cors_static_route(self, cors_client): + cors_client.app.add_static_route('/static', Path(__file__).parent) + result = cors_client.simulate_options( + f'/static/{Path(__file__).name}', + headers=( + ('Origin', 'localhost'), + ('Access-Control-Request-Method', 'GET'), + ), + ) + + assert result.headers['Access-Control-Allow-Methods'] == 'GET' + assert result.headers['Access-Control-Allow-Headers'] == '*' + assert result.headers['Access-Control-Max-Age'] == '86400' + assert result.headers['Access-Control-Allow-Origin'] == '*' + + @pytest.mark.parametrize('support_options', [True, False]) + def test_enabled_cors_sink_route(self, cors_client, support_options): + def my_sink(req, resp): + if req.method == 'OPTIONS' and support_options: + resp.set_header('ALLOW', 'GET') + else: + resp.text = 'my sink' + + cors_client.app.add_sink(my_sink, '/sink') + result = cors_client.simulate_options( + '/sink/123', + headers=( + ('Origin', 'localhost'), + ('Access-Control-Request-Method', 'GET'), + ), + ) + + if support_options: + assert result.headers['Access-Control-Allow-Methods'] == 'GET' + assert result.headers['Access-Control-Allow-Headers'] == '*' + assert result.headers['Access-Control-Max-Age'] == '86400' + assert result.headers['Access-Control-Allow-Origin'] == '*' + else: + assert 'Access-Control-Allow-Methods' not in result.headers + assert 'Access-Control-Allow-Headers' not in result.headers + assert 'Access-Control-Max-Age' not in result.headers + assert 'Access-Control-Expose-Headers' not in result.headers + assert 'Access-Control-Allow-Origin' not in result.headers + 
@pytest.fixture(scope='function') def make_cors_client(asgi, util): diff --git a/tests/test_static.py b/tests/test_static.py index cfeffc3b3..1b38d2e64 100644 --- a/tests/test_static.py +++ b/tests/test_static.py @@ -617,3 +617,19 @@ def test_file_closed(client, patch_open): assert patch_open.current_file is not None assert patch_open.current_file.closed + + +def test_options_request(util, asgi, patch_open): + patch_open() + app = util.create_app(asgi, cors_enable=True) + app.add_static_route('/static', '/var/www/statics') + client = testing.TestClient(app) + + resp = client.simulate_options( + path='/static/foo/bar.txt', + headers={'Origin': 'localhost', 'Access-Control-Request-Method': 'GET'}, + ) + assert resp.status_code == 200 + assert resp.text == '' + assert int(resp.headers['Content-Length']) == 0 + assert resp.headers['Access-Control-Allow-Methods'] == 'GET' From e571f062ac935afe6c5356d64e63be39e8e27839 Mon Sep 17 00:00:00 2001 From: Vytautas Liuolia Date: Mon, 9 Sep 2024 07:40:22 +0200 Subject: [PATCH 12/12] chore: do not build rapidjson on PyPy --- requirements/tests | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements/tests b/requirements/tests index 5be00de33..ada7c3729 100644 --- a/requirements/tests +++ b/requirements/tests @@ -22,7 +22,8 @@ mujson ujson # it's slow to compile on emulated architectures; wheels missing for some EoL interpreters -python-rapidjson; platform_machine != 's390x' and platform_machine != 'aarch64' +# (and there is a new issue with building on PyPy in Actions, but we don't really need to test it with PyPy) +python-rapidjson; platform_python_implementation != 'PyPy' and platform_machine != 's390x' and platform_machine != 'aarch64' # wheels are missing some EoL interpreters and non-x86 platforms; build would fail unless rust is available orjson; platform_python_implementation != 'PyPy' and platform_machine != 's390x' and platform_machine != 'aarch64'