diff --git a/frontik/handler.py b/frontik/handler.py index 64bfcf64f..7c94f3b6a 100644 --- a/frontik/handler.py +++ b/frontik/handler.py @@ -99,8 +99,8 @@ class PageHandler(RequestHandler): def __init__(self, application: FrontikApplication, request: HTTPServerRequest, **kwargs: Any) -> None: self.name = self.__class__.__name__ - self.request = request_context.get_request() # type: ignore - self.request_id = request.request_id = request_context.get_request_id() # type: ignore + self.request_id: str | None = request_context.get_request_id() + request.request_id = self.request_id # type: ignore self.config = application.config self.log = handler_logger self.text: Any = None @@ -495,18 +495,19 @@ async def _postprocess(self) -> Any: return postprocessed_result def on_connection_close(self): - request_context.initialize(self.request, self.request_id) # type: ignore - - super().on_connection_close() + token = request_context.initialize(self.request, self.request_id) + try: + super().on_connection_close() - self.finish_group.abort() - self.set_status(CLIENT_CLOSED_REQUEST, 'Client closed the connection: aborting request') + self.finish_group.abort() + self.set_status(CLIENT_CLOSED_REQUEST, 'Client closed the connection: aborting request') - self.stages_logger.commit_stage('page') - self.stages_logger.flush_stages(self.get_status()) + self.stages_logger.commit_stage('page') + self.stages_logger.flush_stages(self.get_status()) - super().finish() - self.cleanup() + self.finish() + finally: + request_context.reset(token) def on_finish(self): self.stages_logger.commit_stage('flush') diff --git a/frontik/request_context.py b/frontik/request_context.py index a650dd2b4..b73806910 100644 --- a/frontik/request_context.py +++ b/frontik/request_context.py @@ -13,8 +13,8 @@ class _Context: __slots__ = ('request', 'request_id', 'handler_name', 'log_handler') def __init__(self, request: HTTPServerRequest | None, request_id: str | None) -> None: - self.request = request - self.request_id = request_id + self.request: HTTPServerRequest | None = request + self.request_id: str | None = request_id self.handler_name: str | None = None self.log_handler: DebugBufferedHandler | None = None @@ -22,7 +22,7 @@ def __init__(self, request: HTTPServerRequest | None, request_id: str | None) -> _context = contextvars.ContextVar('context', default=_Context(None, None)) -def initialize(request: HTTPServerRequest, request_id: str) -> contextvars.Token: +def initialize(request: HTTPServerRequest | None, request_id: str | None) -> contextvars.Token: return _context.set(_Context(request, request_id)) @@ -30,7 +30,7 @@ def reset(token: contextvars.Token) -> None: _context.reset(token) -def get_request(): +def get_request() -> HTTPServerRequest | None: return _context.get().request