diff --git a/frontik/app.py b/frontik/app.py index b0ca803de..b4dded341 100644 --- a/frontik/app.py +++ b/frontik/app.py @@ -10,6 +10,9 @@ from functools import partial from threading import Lock from typing import Any, Optional, Union +import json +import inspect +import re from aiokafka import AIOKafkaProducer from http_client import AIOHttpClientWrapper, HttpClientFactory @@ -24,33 +27,118 @@ import frontik.producers.xml_producer from frontik import integrations, media_types, request_context from frontik.debug import DebugTransform, get_frontik_and_apps_versions -from frontik.handler import ErrorHandler, PageHandler +from frontik.handler import PageHandler, FinishPageSignal, RedirectPageSignal, build_error_data from frontik.handler_return_values import ReturnedValueHandlers, get_default_returned_value_handlers from frontik.integrations.statsd import StatsDClient, StatsDClientStub, create_statsd_client from frontik.loggers import CUSTOM_JSON_EXTRA, JSON_REQUESTS_LOGGER from frontik.options import options from frontik.process import WorkerState -from frontik.routing import FileMappingRouter, FrontikRouter +from frontik.routing import routers, normal_routes, regex_mapping, FrontikRouter, FrontikRegexRouter from frontik.service_discovery import UpstreamManager from frontik.util import check_request_id, generate_uniq_timestamp_request_id - -app_logger = logging.getLogger('http_client') - - -class VersionHandler(RequestHandler): - def get(self): - self.application: FrontikApplication - self.set_header('Content-Type', 'text/xml') - self.write( - etree.tostring(get_frontik_and_apps_versions(self.application), encoding='utf-8', xml_declaration=True), - ) - - -class StatusHandler(RequestHandler): - def get(self): - self.application: FrontikApplication - self.set_header('Content-Type', media_types.APPLICATION_JSON) - self.finish(self.application.get_current_status()) +from fastapi import FastAPI, APIRouter, Request +from fastapi.routing import APIRoute +import pkgutil +from http_client import HttpClient +from starlette.middleware.base import Response +from fastapi import Depends +import os +from inspect import ismodule +from starlette.datastructures import MutableHeaders +from frontik.json_builder import json_decode +from frontik.handler import get_current_handler + +app_logger = logging.getLogger('app_logger') + +_core_router = FrontikRouter() +router = FrontikRouter() +regex_router = FrontikRegexRouter() +routers.extend((_core_router, router, regex_router)) + + +def setup_page_handler(request: Request, cls: type(PageHandler)): + # create legacy PageHandler and put to request + handler = cls( + request.app.frontik_app, + request.query_params, + request.cookies, + request.headers, + request.state.body_bytes, + request.state.start_time, + request.url.path, + request.state.path_params, + request.client.host, + request.method, + ) + + request.state.handler = handler + return handler + + +def _data_to_chunk(data, headers) -> bytes: + if isinstance(data, str): + chunk = data.encode("utf-8") + elif isinstance(data, dict): + chunk = json.dumps(data).replace("", "<\\/") + chunk = chunk.encode("utf-8") + headers["Content-Type"] = "application/json; charset=UTF-8" + elif isinstance(data, bytes): + chunk = data + else: + raise RuntimeError('unexpected type of chunk') + return chunk + + +async def core_middleware(request: Request, call_next): + request.state.start_time = time.time() + request.state.body_bytes = await request.body() + + request_id = request.headers.get('X-Request-Id') or 
FrontikApplication.next_request_id() + if options.validate_request_id: + check_request_id(request_id) + + with request_context.request_context(request, request_id): + route = normal_routes.get((request.url.path, request.method)) + if route is None: + return await call_next(request) # если в нормальных не нашли, пусть фолбэчнется на регекс роутер + + page_cls = route[1] # from normal router + request.state.path_params = {} + setup_page_handler(request, page_cls) + + _call_next = route[0].get_route_handler() + response = await process_request(request, _call_next) + return response + + +async def regex_router_fallback(request: Request, _): + for pattern, route, cls in regex_mapping: + route: APIRoute + match = pattern.match(request.url.path) + if match and next(iter(route.methods), None) == request.method: + request.state.path_params = match.groupdict() + setup_page_handler(request, cls) + call_next = route.get_route_handler() + response = await process_request(request, call_next) + return response + + rid = request_context.get_request_id() + status, headers, content = build_error_data(rid, 404, 'Not Found') + return Response(status_code=status, headers=headers, content=content) + + +@_core_router.get('/version', cls=PageHandler) +async def get_version(self=get_current_handler): + self.set_header('Content-Type', 'text/xml') + self.finish( + etree.tostring(get_frontik_and_apps_versions(self.application), encoding='utf-8', xml_declaration=True), + ) + + +@_core_router.get('/status', cls=PageHandler) +async def get_status(self=get_current_handler): + self.set_header('Content-Type', media_types.APPLICATION_JSON) + self.finish(self.application.get_current_status()) class PydevdHandler(RequestHandler): @@ -60,8 +148,8 @@ def get(self): return try: - debugger_ip = self.get_argument('debugger_ip', self.request.remote_ip) - debugger_port = self.get_argument('debugger_port', '32223') + debugger_ip = self.get_query_argument('debugger_ip', self.request.remote_ip) + debugger_port = self.get_query_argument('debugger_port', '32223') self.settrace(debugger_ip, int(debugger_port)) self.trace_page(debugger_ip, debugger_port) @@ -86,7 +174,7 @@ def error_page(self) -> None: self.finish(traceback.format_exc()) -class FrontikApplication(Application): +class FrontikApplication: request_id = '' class DefaultConfig: @@ -107,15 +195,6 @@ def __init__(self, app_root: str, **settings: Any) -> None: self.available_integrations: list[integrations.Integration] = [] self.tornado_http_client: Optional[AIOHttpClientWrapper] = None self.http_client_factory: HttpClientFactory - self.router = FrontikRouter(self) - - core_handlers: list[Any] = [ - (r'/version/?', VersionHandler), - (r'/status/?', StatusHandler), - (r'.*', self.router), - ] - if options.debug: - core_handlers.insert(0, (r'/pydevd/?', PydevdHandler)) self.statsd_client: Union[StatsDClient, StatsDClientStub] = create_statsd_client(options, self) @@ -126,8 +205,6 @@ def __init__(self, app_root: str, **settings: Any) -> None: self.returned_value_handlers: ReturnedValueHandlers = get_default_returned_value_handlers() - super().__init__(core_handlers) - def create_upstream_manager( self, upstreams: dict[str, Upstream], @@ -146,8 +223,6 @@ def create_upstream_manager( self.upstream_manager.send_updates() # initial full state sending async def init(self) -> None: - self.transforms.insert(0, partial(DebugTransform, self)) # type: ignore - self.available_integrations, integration_futures = integrations.load_integrations(self) await asyncio.gather(*[future for future in 
integration_futures if future]) @@ -182,36 +257,6 @@ async def init(self) -> None: if self.worker_state.single_worker_mode: self.worker_state.master_done.value = True - def find_handler(self, request, **kwargs): - request_id = request.headers.get('X-Request-Id') - if request_id is None: - request_id = FrontikApplication.next_request_id() - if options.validate_request_id: - check_request_id(request_id) - - def wrapped_in_context(func: Callable) -> Callable: - def wrapper(*args, **kwargs): - with request_context.request_context(request, request_id): - return func(*args, **kwargs) - - return wrapper - - delegate: httputil.HTTPMessageDelegate = wrapped_in_context(super().find_handler)(request, **kwargs) - delegate.headers_received = wrapped_in_context(delegate.headers_received) # type: ignore - delegate.data_received = wrapped_in_context(delegate.data_received) # type: ignore - delegate.finish = wrapped_in_context(delegate.finish) # type: ignore - delegate.on_connection_close = wrapped_in_context(delegate.on_connection_close) # type: ignore - - return delegate - - def reverse_url(self, name: str, *args: Any, **kwargs: Any) -> str: - return self.router.reverse_url(name, *args, **kwargs) - - def application_urls(self) -> list[tuple]: - return [('', FileMappingRouter(importlib.import_module(f'{self.app_module}.pages')))] - - def application_404_handler(self, request: HTTPServerRequest) -> tuple[type[PageHandler], dict]: - return ErrorHandler, {'status_code': 404} def application_config(self) -> DefaultConfig: return FrontikApplication.DefaultConfig() @@ -248,19 +293,15 @@ def get_current_status(self) -> dict[str, str]: return {'uptime': uptime_value, 'datacenter': http_client_options.datacenter} - def log_request(self, handler): - if not options.log_json: - super().log_request(handler) - return - - request_time = int(1000.0 * handler.request.request_time()) + def log_request(self, handler, request: Request): + request_time = int(1000.0 * (time.time() - handler.request_start_time)) extra = { - 'ip': handler.request.remote_ip, + 'ip': request.client.host, 'rid': request_context.get_request_id(), 'status': handler.get_status(), 'time': request_time, - 'method': handler.request.method, - 'uri': handler.request.uri, + 'method': request.method, + 'uri': str(request.url), } handler_name = request_context.get_handler_name() @@ -271,3 +312,76 @@ def log_request(self, handler): def get_kafka_producer(self, producer_name: str) -> Optional[AIOKafkaProducer]: # pragma: no cover pass + + +async def process_request(request, call_next): + handler = request.state.handler + status = 200 + headers = {} + content = None + + try: + request_context.set_handler(handler) + + handler.prepare() + handler.stages_logger.commit_stage('prepare') + _response = await call_next(request) + + handler._handler_finished_notification() + await handler.finish_group.get_gathering_future() + await handler.finish_group.get_finish_future() + handler.stages_logger.commit_stage('page') + + render_result = await handler._postprocess() + handler.stages_logger.commit_stage('postprocess') + + headers = handler.resp_headers + status = handler.get_status() + + debug_transform = DebugTransform(request.app.frontik_app, request) + if debug_transform.is_enabled(): + chunk = _data_to_chunk(render_result, headers) + status, headers, render_result = debug_transform.transform_chunk(status, headers, chunk) + + content = render_result + + except FinishPageSignal as finish_ex: + handler._handler_finished_notification() + headers = handler.resp_headers + 
chunk = _data_to_chunk(finish_ex.data, headers) + status = handler.get_status() + content = chunk + + except RedirectPageSignal as redirect_ex: + handler._handler_finished_notification() + headers = handler.resp_headers + url = redirect_ex.url + status = redirect_ex.status + headers["Location"] = url.encode('utf-8') + + except Exception as ex: + try: + status, headers, content = await handler._handle_request_exception(ex) + except Exception as exc: + app_logger.exception(f'request processing has failed') + status, headers, content = build_error_data(handler.request_id) + + finally: + handler.cleanup() + + if status in (204, 304) or (100 <= status < 200): + for h in ('Content-Encoding', 'Content-Language', 'Content-Type'): + if h in headers: + headers.pop(h) + content = None + + response = Response(status_code=status, headers=headers, content=content) + + for key, values in handler.resp_cookies.items(): + response.set_cookie(key, **values) + + handler.finish_group.abort() + request.app.frontik_app.log_request(handler, request) + handler.on_finish(status) + + return response diff --git a/frontik/debug.py b/frontik/debug.py index 33fdebb1f..a2f783f0b 100644 --- a/frontik/debug.py +++ b/frontik/debug.py @@ -34,6 +34,8 @@ from frontik.options import options from frontik.version import version as frontik_version from frontik.xml_util import dict_to_xml +from fastapi import Request +from starlette.datastructures import Headers if TYPE_CHECKING: from typing import Any @@ -203,7 +205,7 @@ def _params_to_xml(url: str) -> etree.Element: return params -def _headers_to_xml(request_or_response_headers: dict | HTTPHeaders) -> etree.Element: +def _headers_to_xml(request_or_response_headers: dict | Headers) -> etree.Element: headers = etree.Element('headers') for name, value in request_or_response_headers.items(): if name != 'Cookie': @@ -365,52 +367,42 @@ def _produce_one(self, record: logging.LogRecord) -> etree.Element: DEBUG_XSL = os.path.join(os.path.dirname(__file__), 'debug/debug.xsl') -class DebugTransform(OutputTransform): - def __init__(self, application: FrontikApplication, request: HTTPServerRequest) -> None: +# class DebugTransform(OutputTransform): +class DebugTransform: + def __init__(self, application: FrontikApplication, request: Request) -> None: self.application = application - self.request = request + self.request: Request = request def is_enabled(self) -> bool: - return getattr(self.request, '_debug_enabled', False) + return getattr(self.request.state.handler, '_debug_enabled', False) def is_inherited(self) -> bool: - return getattr(self.request, '_debug_inherited', False) + return getattr(self.request.state.handler, '_debug_inherited', False) - def transform_first_chunk(self, status_code, headers, chunk, finishing): + def transform_chunk(self, status_code, original_headers, chunk): if not self.is_enabled(): - return status_code, headers, chunk + return status_code, original_headers, chunk self.status_code = status_code - self.headers = headers + self.headers = original_headers self.chunks = [chunk] if not self.is_inherited(): - headers = HTTPHeaders({'Content-Type': media_types.TEXT_HTML}) + wrap_headers = {'Content-Type': media_types.TEXT_HTML} else: - headers = HTTPHeaders({'Content-Type': media_types.APPLICATION_XML, DEBUG_HEADER_NAME: 'true'}) + wrap_headers = {'Content-Type': media_types.APPLICATION_XML, DEBUG_HEADER_NAME: 'true'} - return 200, headers, self.produce_debug_body(finishing) - - def transform_chunk(self, chunk: bytes, finishing: bool) -> bytes: - if not 
self.is_enabled(): - return chunk - - self.chunks.append(chunk) - - return self.produce_debug_body(finishing) - - def produce_debug_body(self, finishing: bool) -> bytes: - if not finishing: - return b'' + return 200, wrap_headers, self.produce_debug_body() + def produce_debug_body(self) -> bytes: start_time = time.time() debug_log_data = request_context.get_log_handler().produce_all() # type: ignore debug_log_data.set('code', str(int(self.status_code))) debug_log_data.set('handler-name', request_context.get_handler_name()) - debug_log_data.set('started', _format_number(self.request._start_time)) - debug_log_data.set('request-id', str(self.request.request_id)) # type: ignore - debug_log_data.set('stages-total', _format_number((time.time() - self.request._start_time) * 1000)) + debug_log_data.set('started', _format_number(self.request.state.start_time)) + debug_log_data.set('request-id', str(self.request.state.handler.request_id)) # type: ignore + debug_log_data.set('stages-total', _format_number((time.time() - self.request.state.start_time) * 1000)) try: debug_log_data.append(E.versions(_pretty_print_xml(get_frontik_and_apps_versions(self.application)))) @@ -427,7 +419,7 @@ def produce_debug_body(self, finishing: bool) -> bytes: debug_log_data.append( E.request( E.method(self.request.method), - _params_to_xml(self.request.uri), # type: ignore + _params_to_xml(str(self.request.url)), # type: ignore _headers_to_xml(self.request.headers), _cookies_to_xml(self.request.headers), # type: ignore ), @@ -451,7 +443,7 @@ def produce_debug_body(self, finishing: bool) -> bytes: upstream.set('bgcolor', bgcolor) upstream.set('fgcolor', fgcolor) - if not getattr(self.request, '_debug_inherited', False): + if not getattr(self.request.state.handler, '_debug_inherited', False): try: transform = etree.XSLT(etree.parse(DEBUG_XSL)) log_document = utf8(str(transform(debug_log_data))) @@ -476,16 +468,16 @@ def __init__(self, handler: PageHandler) -> None: debug_value = frontik.util.get_cookie_or_url_param_value(handler, 'debug') self.mode_values = debug_value.split(',') if debug_value is not None else '' - self.inherited = handler.request.headers.get(DEBUG_HEADER_NAME) + self.inherited = handler.get_header(DEBUG_HEADER_NAME, False) if self.inherited: debug_log.debug('debug mode is inherited due to %s request header', DEBUG_HEADER_NAME) - handler.request._debug_inherited = True # type: ignore + handler._debug_inherited = True # type: ignore if debug_value is not None or self.inherited: handler.require_debug_access() - self.enabled = handler.request._debug_enabled = True # type: ignore + self.enabled = handler._debug_enabled = True # type: ignore self.pass_debug = 'nopass' not in self.mode_values or self.inherited self.profile_xslt = 'xslt' in self.mode_values diff --git a/frontik/futures.py b/frontik/futures.py index cafe7dd66..8558c92df 100644 --- a/frontik/futures.py +++ b/frontik/futures.py @@ -5,6 +5,7 @@ import time from functools import partial, wraps from typing import TYPE_CHECKING, Optional +from asyncio import Task from tornado.concurrent import Future from tornado.ioloop import IOLoop @@ -44,6 +45,9 @@ def is_finished(self) -> bool: return self._finished def abort(self) -> None: + if self._finished: + return + async_logger.info('aborting %s', self) self._finished = True if not self._future.done(): @@ -122,10 +126,9 @@ def _handle_future(callback, future): future.result() callback() - def add_future(self, future: Future) -> Future: - IOLoop.current().add_future(future, partial(self._handle_future, 
self.add_notification())) - self._futures.append(future) - return future + def add_future(self, task: Task): + task.add_done_callback(partial(self._handle_future, self.add_notification())) + self._futures.append(task) def get_finish_future(self) -> Future: return self._future diff --git a/frontik/handler.py b/frontik/handler.py index f818f33f2..4e06d93bc 100644 --- a/frontik/handler.py +++ b/frontik/handler.py @@ -9,6 +9,7 @@ from asyncio.futures import Future from functools import wraps from typing import TYPE_CHECKING, Any, Optional, Union +import sys import tornado.httputil import tornado.web @@ -18,6 +19,8 @@ from tornado.ioloop import IOLoop from tornado.web import Finish, RequestHandler +import datetime +from tornado.httputil import parse_body_arguments, format_timestamp import frontik.auth import frontik.handler_active_limit import frontik.producers.json_producer @@ -26,7 +29,6 @@ from frontik import media_types, request_context from frontik.auth import DEBUG_AUTH_HEADER_NAME from frontik.debug import DEBUG_HEADER_NAME, DebugMode -from frontik.dependency_manager import APIRouter, execute_page_method_with_dependencies from frontik.futures import AbortAsyncGroup, AsyncGroup from frontik.http_status import ALLOWED_STATUSES, CLIENT_CLOSED_REQUEST from frontik.json_builder import FrontikJsonDecodeError, json_decode @@ -36,6 +38,12 @@ from frontik.util import gather_dict, make_url from frontik.validator import BaseValidationModel, Validators from frontik.version import version as frontik_version +from collections.abc import Callable, Coroutine +from fastapi import Request +from fastapi import HTTPException +from starlette.datastructures import QueryParams, Headers +from fastapi import Depends + if TYPE_CHECKING: from collections.abc import Callable, Coroutine @@ -66,17 +74,31 @@ def __init__(self) -> None: super().__init__(400, 'Failed to parse json in request body') -class DefaultValueError(Exception): - def __init__(self, *args: object) -> None: +class DefaultValueError(tornado.web.HTTPError): + def __init__(self, arg_name: str) -> None: + super().__init__(400, "Missing argument %s" % arg_name) + self.arg_name = arg_name + + +class FinishPageSignal(Exception): + def __init__(self, data: None, *args: object) -> None: + super().__init__(*args) + self.data = data + + +class RedirectPageSignal(Exception): + def __init__(self, url: str, status: int, *args: object) -> None: super().__init__(*args) + self.url = url + self.status = status _ARG_DEFAULT = object() MEDIA_TYPE_PARAMETERS_SEPARATOR_RE = r' *; *' OUTER_TIMEOUT_MS_HEADER = 'X-Outer-Timeout-Ms' +_remove_control_chars_regex = re.compile(r"[\x00-\x08\x0e-\x1f]") handler_logger = logging.getLogger('handler') -router = APIRouter() def _fail_fast_policy(fail_fast: bool, waited: bool, host: str, path: str) -> bool: @@ -91,18 +113,44 @@ def _fail_fast_policy(fail_fast: bool, waited: bool, host: str, path: str) -> bo return fail_fast -class PageHandler(RequestHandler): +class PageHandler: returned_value_handlers: ReturnedValueHandlers = [] - def __init__(self, application: FrontikApplication, request: HTTPServerRequest, **kwargs: Any) -> None: - self.name = self.__class__.__name__ - self.request_id: str = request_context.get_request_id() # type: ignore - request.request_id = self.request_id # type: ignore + def __init__( + self, + application: FrontikApplication, + q_params: QueryParams = None, + c_params: dict[str, str] = None, + h_params: Headers = None, + body_bytes: bytes = None, + request_start_time: float = None, + path: str = None, + 
path_params: dict = None, + remote_ip: str = None, + method: str = None, + ) -> None: # request: Request + self.q_params = q_params + self.c_params = c_params or {} + self.h_params: Headers = h_params + self.body_bytes = body_bytes + self._json_body = None + self.body_arguments = {} + self.files = {} + self.parse_body_bytes() + self.path = path + self.path_params = path_params + self.request_start_time = request_start_time + self.remote_ip = h_params.get('x-real-ip', None) + self.method = method + + self.resp_cookies: dict[str, dict] = {} + self.config = application.config self.log = handler_logger self.text: Any = None - super().__init__(application, request, **kwargs) + self.application = application + self._finished = False self.statsd_client: StatsDClient | StatsDClientStub @@ -112,32 +160,34 @@ def __init__(self, application: FrontikApplication, request: HTTPServerRequest, if not self.returned_value_handlers: self.returned_value_handlers = list(application.returned_value_handlers) - self.stages_logger = StagesLogger(request, self.statsd_client) + self.stages_logger = StagesLogger(request_start_time, self.statsd_client) self._debug_access: Optional[bool] = None self._render_postprocessors: list = [] self._postprocessors: list = [] - self._mandatory_cookies: dict = {} - self._mandatory_headers = tornado.httputil.HTTPHeaders() - self._validation_model: type[BaseValidationModel | BaseModel] = BaseValidationModel self.timeout_checker = None self.use_adaptive_strategy = False - outer_timeout = request.headers.get(OUTER_TIMEOUT_MS_HEADER) + outer_timeout = h_params.get(OUTER_TIMEOUT_MS_HEADER) if outer_timeout: self.timeout_checker = get_timeout_checker( - request.headers.get(USER_AGENT_HEADER), + h_params.get(USER_AGENT_HEADER), float(outer_timeout), - request.request_time, + request_start_time, ) + self._status = 200 + self._reason = '' + def __repr__(self): return f'{self.__module__}.{self.__class__.__name__}' - def prepare(self) -> None: - self.application: FrontikApplication + def prepare(self): + self.request_id: str = request_context.get_request_id() # type: ignore + self.resp_headers = set_default_headers(self.request_id) + self.active_limit = frontik.handler_active_limit.ActiveHandlersLimit(self.statsd_client) self.debug_mode = DebugMode(self) self.finish_group = AsyncGroup(lambda: None, name='finish') @@ -156,60 +206,99 @@ def prepare(self) -> None: self._handler_finished_notification = self.finish_group.add_notification() - super().prepare() + # Simple getters and setters - def require_debug_access(self, login: Optional[str] = None, passwd: Optional[str] = None) -> None: - if self._debug_access is None: - if options.debug: - debug_access = True - else: - check_login = login if login is not None else options.debug_login - check_passwd = passwd if passwd is not None else options.debug_password - frontik.auth.check_debug_auth(self, check_login, check_passwd) - debug_access = True - - self._debug_access = debug_access - - def set_default_headers(self): - self._headers = tornado.httputil.HTTPHeaders({ - 'Server': f'Frontik/{frontik_version}', - 'X-Request-Id': self.request_id, - }) + def get_request_headers(self) -> Headers: + return self.h_params - def decode_argument(self, value: bytes, name: Optional[str] = None) -> str: - try: - return super().decode_argument(value, name) - except (UnicodeError, tornado.web.HTTPError): - self.log.warning('cannot decode utf-8 query parameter, trying other charsets') - - try: - return frontik.util.decode_string_from_charset(value) - except 
UnicodeError: - self.log.exception('cannot decode argument, ignoring invalid chars') - return value.decode('utf-8', 'ignore') - - def get_body_argument(self, name: str, default: Any = _ARG_DEFAULT, strip: bool = True) -> Optional[str]: - if self._get_request_mime_type(self.request) == media_types.APPLICATION_JSON: - if name not in self.json_body and default == _ARG_DEFAULT: - raise tornado.web.MissingArgumentError(name) + def get_path_argument(self, name, default=_ARG_DEFAULT): + value = self.path_params.get(name, None) + if value is None: + if default == _ARG_DEFAULT: + raise DefaultValueError(name) + return default + value = _remove_control_chars_regex.sub(" ", value) + return value - result = self.json_body.get(name, default) + def get_query_argument( + self, + name: str, + default: Any = _ARG_DEFAULT, + strip: bool = True, + ) -> Optional[str]: + args = self._get_arguments(name, strip=strip) + if not args: + if default == _ARG_DEFAULT: + raise DefaultValueError(name) + return default + return args[-1] + + def get_query_arguments(self, name: Optional[str] = None, strip: bool = True) -> Union[list[str], dict[str, str]]: + if name is None: + return self._get_all_arguments(strip) + return self._get_arguments(name, strip) + + def _get_all_arguments(self, strip: bool = True) -> dict[str, str]: + qargs_list = self.q_params.multi_items() + values = {} + for qarg_k, qarg_v in qargs_list: + v = _remove_control_chars_regex.sub(" ", qarg_v) + if strip: + v = v.strip() + values[qarg_k] = v + + return values + + def _get_arguments(self, name: str, strip: bool = True) -> list[str]: + qargs_list = self.q_params.multi_items() + values = [] + for qarg_k, qarg_v in qargs_list: + if qarg_k != name: + continue + + # Get rid of any weird control chars (unless decoding gave + # us bytes, in which case leave it alone) + v = _remove_control_chars_regex.sub(" ", qarg_v) + if strip: + v = v.strip() + values.append(v) + + return values - if strip and isinstance(result, str): - return result.strip() + def get_str_argument( + self, + name: str, + default: Any = _ARG_DEFAULT, + path_safe: bool = True, + **kwargs: Any, + ) -> Optional[Union[str, list[str]]]: + if path_safe: + return self.get_validated_argument(name, Validators.PATH_SAFE_STRING, default=default, **kwargs) + return self.get_validated_argument(name, Validators.STRING, default=default, **kwargs) - return result + def get_int_argument( + self, + name: str, + default: Any = _ARG_DEFAULT, + **kwargs: Any, + ) -> Optional[Union[int, list[int]]]: + return self.get_validated_argument(name, Validators.INTEGER, default=default, **kwargs) - if default == _ARG_DEFAULT: - return super().get_body_argument(name, strip=strip) - return super().get_body_argument(name, default, strip) + def get_bool_argument( + self, + name: str, + default: Any = _ARG_DEFAULT, + **kwargs: Any, + ) -> Optional[Union[bool, list[bool]]]: + return self.get_validated_argument(name, Validators.BOOLEAN, default=default, **kwargs) - def set_validation_model(self, model: type[Union[BaseValidationModel, BaseModel]]) -> None: - if issubclass(model, BaseModel): - self._validation_model = model - else: - msg = 'model is not subclass of BaseClass' - raise TypeError(msg) + def get_float_argument( + self, + name: str, + default: Any = _ARG_DEFAULT, + **kwargs: Any, + ) -> Optional[Union[float, list[float]]]: + return self.get_validated_argument(name, Validators.FLOAT, default=default, **kwargs) def get_validated_argument( self, @@ -226,7 +315,7 @@ def get_validated_argument( params = {validator: 
default} validated_default = self._validation_model(**params).model_dump().get(validator) except ValidationError: - raise DefaultValueError() + raise DefaultValueError(name) else: validated_default = default @@ -236,9 +325,9 @@ def get_validated_argument( elif from_body: value = self.get_body_argument(name, validated_default, strip) elif array: - value = self.get_arguments(name, strip) + value = self.get_query_arguments(name, strip) else: - value = self.get_argument(name, validated_default, strip) + value = self.get_query_argument(name, validated_default, strip) try: params = {validator: value} @@ -250,199 +339,230 @@ def get_validated_argument( return validated_value - def get_str_argument( - self, - name: str, - default: Any = _ARG_DEFAULT, - path_safe: bool = True, - **kwargs: Any, - ) -> Optional[Union[str, list[str]]]: - if path_safe: - return self.get_validated_argument(name, Validators.PATH_SAFE_STRING, default=default, **kwargs) - return self.get_validated_argument(name, Validators.STRING, default=default, **kwargs) + def get_body_arguments(self, name: str = None, strip: bool = True) -> Union[list[str], dict[str, list[str]]]: + # только для не джсона + if name is None: + return self._get_all_body_arguments(strip) + return self._get_body_arguments(name, strip) + + def _get_all_body_arguments(self, strip) -> dict[str, list[str]]: + result = {} + for key, values in self.body_arguments.items(): + result[key] = [] + for v in values: + s = self.decode_argument(v) + if isinstance(s, str): + s = _remove_control_chars_regex.sub(" ", s) + if strip: + s = s.strip() + result[key].append(s) + return result - def get_int_argument( - self, - name: str, - default: Any = _ARG_DEFAULT, - **kwargs: Any, - ) -> Optional[Union[int, list[int]]]: - return self.get_validated_argument(name, Validators.INTEGER, default=default, **kwargs) + def get_body_argument(self, name: str, default: Any = _ARG_DEFAULT, strip: bool = True) -> Optional[str]: + if self._get_request_mime_type() == media_types.APPLICATION_JSON: + if name not in self.json_body and default == _ARG_DEFAULT: + raise DefaultValueError(name) - def get_bool_argument( - self, - name: str, - default: Any = _ARG_DEFAULT, - **kwargs: Any, - ) -> Optional[Union[bool, list[bool]]]: - return self.get_validated_argument(name, Validators.BOOLEAN, default=default, **kwargs) + result = self.json_body.get(name, default) - def get_float_argument( - self, - name: str, - default: Any = _ARG_DEFAULT, - **kwargs: Any, - ) -> Optional[Union[float, list[float]]]: - return self.get_validated_argument(name, Validators.FLOAT, default=default, **kwargs) + if strip and isinstance(result, str): + return result.strip() - def _get_request_mime_type(self, request: HTTPServerRequest) -> str: - content_type = request.headers.get('Content-Type', '') - return re.split(MEDIA_TYPE_PARAMETERS_SEPARATOR_RE, content_type)[0] + return result - def set_status(self, status_code: int, reason: Optional[str] = None) -> None: - status_code = status_code if status_code in ALLOWED_STATUSES else http.client.SERVICE_UNAVAILABLE - super().set_status(status_code, reason=reason) - - def redirect(self, url, *args, allow_protocol_relative=False, **kwargs): - if not allow_protocol_relative and url.startswith('//'): - # A redirect with two initial slashes is a "protocol-relative" URL. - # This means the next path segment is treated as a hostname instead - # of a part of the path, making this effectively an open redirect. - # Reject paths starting with two slashes to prevent this. 
- # This is only reachable under certain configurations. - raise tornado.web.HTTPError(403, 'cannot redirect path with two initial slashes') - self.log.info('redirecting to: %s', url) - return super().redirect(url, *args, **kwargs) + if default == _ARG_DEFAULT: + return self._get_body_argument(name, strip=strip) + return self._get_body_argument(name, default, strip) - def reverse_url(self, name: str, *args: Any, **kwargs: Any) -> str: - return self.application.reverse_url(name, *args, **kwargs) + def _get_body_argument( + self, + name: str, + default: Any = _ARG_DEFAULT, + strip: bool = True, + ) -> Optional[str]: + args = self._get_body_arguments(name, strip=strip) + if not args: + if default == _ARG_DEFAULT: + raise DefaultValueError(name) + return default + return args[-1] + + def _get_body_arguments(self, name: str, strip: bool = True) -> list[str]: + values = [] + for v in self.body_arguments.get(name, []): + s = self.decode_argument(v, name=name) + if isinstance(s, str): + s = _remove_control_chars_regex.sub(" ", s) + if strip: + s = s.strip() + values.append(s) + return values + + def parse_body_bytes(self): + if self._get_request_mime_type() == media_types.APPLICATION_JSON: # если джсон то парсим сами + # _ = self.json_body + return # on_demand распарсим + else: + parse_body_arguments( + self.get_header('Content-Type', ''), + self.body_bytes, + self.body_arguments, + self.files, + self.h_params, + ) @property def json_body(self): - if not hasattr(self, '_json_body'): + if self._json_body is None: self._json_body = self._get_json_body() return self._json_body def _get_json_body(self) -> Any: try: - return json_decode(self.request.body) + return json_decode(self.body_bytes) except FrontikJsonDecodeError as _: raise JSONBodyParseError() - @classmethod - def add_callback(cls, callback: Callable, *args: Any, **kwargs: Any) -> None: - IOLoop.current().add_callback(callback, *args, **kwargs) - - @classmethod - def add_timeout(cls, deadline: float, callback: Callable, *args: Any, **kwargs: Any) -> Any: - return IOLoop.current().add_timeout(deadline, callback, *args, **kwargs) - - @staticmethod - def remove_timeout(timeout): - IOLoop.current().remove_timeout(timeout) - - @classmethod - def add_future(cls, future: Future, callback: Callable) -> None: - IOLoop.current().add_future(future, callback) - - # Requests handling + def decode_argument(self, value: bytes, name: Optional[str] = None) -> str: + try: + return value.decode("utf-8") + except UnicodeError: + self.log.warning(f'cannot decode utf-8 body parameter {name}, trying other charsets') - async def _execute(self, transforms, *args, **kwargs): - request_context.set_handler(self) try: - return await super()._execute(transforms, *args, **kwargs) - except Exception as ex: - self._handle_request_exception(ex) - return True + return frontik.util.decode_string_from_charset(value) + except UnicodeError: + self.log.exception(f'cannot decode body parameter {name}, ignoring invalid chars') + return value.decode('utf-8', 'ignore') - async def get(self, *args, **kwargs): - await self._execute_page(self.get_page) + def get_header(self, param_name, default=None): + return self.h_params.get(param_name.lower(), default) - async def post(self, *args, **kwargs): - await self._execute_page(self.post_page) + def set_header(self, k, v): + self.resp_headers[k] = v - async def put(self, *args, **kwargs): - await self._execute_page(self.put_page) + def _get_request_mime_type(self) -> str: + content_type = self.get_header('Content-Type', '') + return 
re.split(MEDIA_TYPE_PARAMETERS_SEPARATOR_RE, content_type)[0] - async def delete(self, *args, **kwargs): - await self._execute_page(self.delete_page) + def clear_header(self, name: str) -> None: + if name in self.resp_headers: + del self.resp_headers[name] - async def head(self, *args, **kwargs): - await self._execute_page(self.get_page) + def clear_cookie(self, name: str, path: str = '/', domain: Optional[str] = None) -> None: # type: ignore + expires = datetime.datetime.utcnow() - datetime.timedelta(days=365) + self.set_cookie(name, value="", expires=expires, path=path, domain=domain) - def options(self, *args, **kwargs): - self.return_405() + def get_cookie(self, param_name, default): + return self.c_params.get(param_name, default) - async def _execute_page(self, page_handler_method: Callable[[], Coroutine[Any, Any, None]]) -> None: - self.stages_logger.commit_stage('prepare') + def set_cookie( + self, + name: str, + value: Union[str, bytes], + domain: Optional[str] = None, + expires: Optional[Union[float, tuple, datetime.datetime]] = None, + path: str = "/", + expires_days: Optional[float] = None, + # Keyword-only args start here for historical reasons. + *, + max_age: Optional[int] = None, + httponly: bool = False, + secure: bool = False, + samesite: Optional[str] = None, + ) -> None: + name = str(name) + value = str(value) + if re.search(r"[\x00-\x20]", name + value): + # Don't let us accidentally inject bad stuff + raise ValueError("Invalid cookie %r: %r" % (name, value)) + + if name in self.resp_cookies: + del self.resp_cookies[name] + self.resp_cookies[name] = {'value': value} + morsel = self.resp_cookies[name] + if domain: + morsel["domain"] = domain + if expires_days is not None and not expires: + expires = datetime.datetime.utcnow() + datetime.timedelta(days=expires_days) + if expires: + morsel["expires"] = format_timestamp(expires) + if path: + morsel["path"] = path + if max_age: + # Note change from _ to -. + morsel["max_age"] = str(max_age) + if httponly: + # Note that SimpleCookie ignores the value here. The presense of an + # httponly (or secure) key is treated as true. 
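+            # (The "SimpleCookie" note above appears to be carried over from the
+            # Tornado implementation; resp_cookies here is a plain dict whose
+            # entries are later expanded as response.set_cookie(key, **values)
+            # in process_request, so these keys are assumed to match Starlette's
+            # set_cookie keyword arguments.)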
+ morsel["httponly"] = True + if secure: + morsel["secure"] = True + if samesite: + morsel["samesite"] = samesite - returned_value: ReturnedValue = await execute_page_method_with_dependencies(self, page_handler_method) - for returned_value_handler in self.returned_value_handlers: - returned_value_handler(self, returned_value) + # Requests handling - self._handler_finished_notification() - await self.finish_group.get_gathering_future() - await self.finish_group.get_finish_future() + def require_debug_access(self, login: Optional[str] = None, passwd: Optional[str] = None) -> None: + if self._debug_access is None: + if options.debug: + debug_access = True + else: + check_login = login if login is not None else options.debug_login + check_passwd = passwd if passwd is not None else options.debug_password + frontik.auth.check_debug_auth(self, check_login, check_passwd) + debug_access = True - render_result = await self._postprocess() - if render_result is not None: - self.write(render_result) + self._debug_access = debug_access - @router.get() - async def get_page(self): - """This method can be implemented in the subclass""" - self.return_405() + def set_status(self, status_code: int, reason: Optional[str] = None) -> None: + status_code = status_code if status_code in ALLOWED_STATUSES else http.client.SERVICE_UNAVAILABLE - @router.post() - async def post_page(self): - """This method can be implemented in the subclass""" - self.return_405() + self._status = status_code + self._reason = reason - @router.put() - async def put_page(self): - """This method can be implemented in the subclass""" - self.return_405() + def get_status(self) -> int: + return self._status - @router.delete() - async def delete_page(self): - """This method can be implemented in the subclass""" - self.return_405() + def redirect(self, url: str, permanent: bool = False, status: Optional[int] = None): + if url.startswith('//'): + raise RuntimeError('403 cannot redirect path with two initial slashes') + self.log.info('redirecting to: %s', url) + if status is None: + status = 301 if permanent else 302 + else: + assert isinstance(status, int) and 300 <= status <= 399 + raise RedirectPageSignal(url, status) - def return_405(self) -> None: - allowed_methods = [name for name in ('get', 'post', 'put', 'delete') if f'{name}_page' in vars(self.__class__)] - self.set_header('Allow', ', '.join(allowed_methods)) - self.set_status(405) - self.finish() + def finish(self, data: Optional[Union[str, bytes, dict]] = None) -> Future[None]: + raise FinishPageSignal(data) - def get_page_fail_fast(self, request_result: RequestResult) -> None: - self.__return_error(request_result.status_code, error_info={'is_fail_fast': True}) + async def get_page_fail_fast(self, request_result: RequestResult): + return await self.__return_error(request_result.status_code, error_info={'is_fail_fast': True}) - def post_page_fail_fast(self, request_result: RequestResult) -> None: - self.__return_error(request_result.status_code, error_info={'is_fail_fast': True}) + async def post_page_fail_fast(self, request_result: RequestResult): + return await self.__return_error(request_result.status_code, error_info={'is_fail_fast': True}) - def put_page_fail_fast(self, request_result: RequestResult) -> None: - self.__return_error(request_result.status_code, error_info={'is_fail_fast': True}) + async def put_page_fail_fast(self, request_result: RequestResult): + return await self.__return_error(request_result.status_code, error_info={'is_fail_fast': True}) - def 
delete_page_fail_fast(self, request_result: RequestResult) -> None: - self.__return_error(request_result.status_code, error_info={'is_fail_fast': True}) + async def delete_page_fail_fast(self, request_result: RequestResult): + return await self.__return_error(request_result.status_code, error_info={'is_fail_fast': True}) - def __return_error(self, response_code: int, **kwargs: Any) -> None: - self.send_error(response_code if 300 <= response_code < 500 else 502, **kwargs) + async def __return_error(self, response_code: int, **kwargs: Any) -> tuple[int, dict, Any]: + return await self.send_error(response_code if 300 <= response_code < 500 else 502, **kwargs) # Finish page def is_finished(self) -> bool: return self._finished - def check_finished(self, callback: Callable) -> Callable: - @wraps(callback) - def wrapper(*args, **kwargs): - if self.is_finished(): - self.log.warning('page was already finished, %s ignored', callback) - else: - return callback(*args, **kwargs) - - return wrapper - - def finish_with_postprocessors(self) -> None: + async def finish_with_postprocessors(self) -> tuple[int, dict, Any]: if not self.finish_group.get_finish_future().done(): self.finish_group.abort() - def _cb(future): - if future.result() is not None: - self.finish(future.result()) - - asyncio.create_task(self._postprocess()).add_done_callback(_cb) + content = await self._postprocess() + return self.get_status(), self.resp_headers, content def run_task(self: PageHandler, coro: Coroutine) -> Task: task = asyncio.create_task(coro) @@ -479,39 +599,33 @@ async def _postprocess(self) -> Any: ) return postprocessed_result - def on_connection_close(self): - with request_context.request_context(self.request, self.request_id): - super().on_connection_close() - - self.finish_group.abort() - self.set_status(CLIENT_CLOSED_REQUEST, 'Client closed the connection: aborting request') - - self.stages_logger.commit_stage('page') - self.stages_logger.flush_stages(self.get_status()) - - self.finish() - - def on_finish(self): + def on_finish(self, status: int): self.stages_logger.commit_stage('flush') - self.stages_logger.flush_stages(self.get_status()) + self.stages_logger.flush_stages(status) - def _handle_request_exception(self, e: BaseException) -> None: + async def _handle_request_exception(self, e: BaseException) -> tuple[int, dict, Any]: if isinstance(e, AbortAsyncGroup): self.log.info('page was aborted, skipping postprocessing') - return + raise e if isinstance(e, FinishWithPostprocessors): if e.wait_finish_group: self._handler_finished_notification() - self.add_future(self.finish_group.get_finish_future(), lambda _: self.finish_with_postprocessors()) - else: - self.finish_with_postprocessors() - return + await self.finish_group.get_finish_future() + return await self.finish_with_postprocessors() + + if isinstance(e, HTTPErrorWithPostprocessors): + self.set_status(e.status_code) + return await self.finish_with_postprocessors() + + if isinstance(e, tornado.web.HTTPError): + self.set_status(e.status_code) + return await self.write_error(e.status_code, exc_info=sys.exc_info()) if self._finished and not isinstance(e, Finish): # tornado will handle Finish by itself # any other errors can't complete after handler is finished - return + raise e if isinstance(e, FailFastError): request = e.failed_result.request @@ -532,51 +646,42 @@ def _handle_request_exception(self, e: BaseException) -> None: ) try: - error_method_name = f'{self.request.method.lower()}_page_fail_fast' # type: ignore + error_method_name = 
f'{self.method.lower()}_page_fail_fast' # type: ignore method = getattr(self, error_method_name, None) if callable(method): - method(e.failed_result) + return await method(e.failed_result) else: - self.__return_error(e.failed_result.status_code, error_info={'is_fail_fast': True}) + return await self.__return_error(e.failed_result.status_code, error_info={'is_fail_fast': True}) except Exception as exc: - super()._handle_request_exception(exc) + raise exc else: - super()._handle_request_exception(e) + raise e - def send_error(self, status_code: int = 500, **kwargs: Any) -> None: + async def send_error(self, status_code: int = 500, **kwargs: Any) -> tuple[int, dict, Any]: """`send_error` is adapted to support `write_error` that can call `finish` asynchronously. """ self.stages_logger.commit_stage('page') - if self._headers_written: - super().send_error(status_code, **kwargs) - return - - reason = kwargs.get('reason') + self._reason = kwargs.get('reason') if 'exc_info' in kwargs: exception = kwargs['exc_info'][1] if isinstance(exception, tornado.web.HTTPError) and exception.reason: - reason = exception.reason + self._reason = exception.reason else: exception = None if not isinstance(exception, HTTPErrorWithPostprocessors): - self.clear() + set_default_headers(self.request_id) - self.set_status(status_code, reason=reason) + self.set_status(status_code, reason=self._reason) - try: - self.write_error(status_code, **kwargs) - except Exception: - self.log.exception('Uncaught exception in write_error') - if not self._finished: - self.finish() + return await self.write_error(status_code, **kwargs) - def write_error(self, status_code: int = 500, **kwargs: Any) -> None: + async def write_error(self, status_code: int = 500, **kwargs: Any) -> tuple[int, dict, Any]: """ `write_error` can call `finish` asynchronously if HTTPErrorWithPostprocessors is raised. 
""" @@ -584,62 +689,17 @@ def write_error(self, status_code: int = 500, **kwargs: Any) -> None: exception = kwargs['exc_info'][1] if 'exc_info' in kwargs else None if isinstance(exception, HTTPErrorWithPostprocessors): - self.finish_with_postprocessors() - return + return await self.finish_with_postprocessors() - self.set_header('Content-Type', media_types.TEXT_HTML) - super().write_error(status_code, **kwargs) + return build_error_data(self.request_id, status_code, self._reason) def cleanup(self) -> None: + self._finished = True if hasattr(self, 'active_limit'): self.active_limit.release() - def finish(self, chunk: Optional[Union[str, bytes, dict]] = None) -> Future[None]: - self.stages_logger.commit_stage('postprocess') - for name, value in self._mandatory_headers.items(): - self.set_header(name, value) - - for args, kwargs in self._mandatory_cookies.values(): - try: - self.set_cookie(*args, **kwargs) - except ValueError: - self.set_status(http.client.BAD_REQUEST) - - if self._status_code in (204, 304) or (100 <= self._status_code < 200): - self._write_buffer = [] - chunk = None - - finish_future = super().finish(chunk) - self.cleanup() - return finish_future - # postprocessors - def set_mandatory_header(self, name: str, value: str) -> None: - self._mandatory_headers[name] = value - - def set_mandatory_cookie( - self, - name: str, - value: str, - domain: Optional[str] = None, - expires: Optional[str] = None, - path: str = '/', - expires_days: Optional[int] = None, - **kwargs: Any, - ) -> None: - self._mandatory_cookies[name] = ((name, value, domain, expires, path, expires_days), kwargs) - - def clear_header(self, name: str) -> None: - if name in self._mandatory_headers: - del self._mandatory_headers[name] - super().clear_header(name) - - def clear_cookie(self, name: str, path: str = '/', domain: Optional[str] = None) -> None: # type: ignore - if name in self._mandatory_cookies: - del self._mandatory_cookies[name] - super().clear_cookie(name, path=path, domain=domain) - async def _run_postprocessors(self, postprocessors: list) -> bool: for p in postprocessors: if asyncio.iscoroutinefunction(p): @@ -677,7 +737,7 @@ def add_postprocessor(self, postprocessor: Any) -> None: async def _generic_producer(self): self.log.debug('finishing plaintext') - if self._headers.get('Content-Type') is None: + if self.resp_headers.get('Content-Type') is None: self.set_header('Content-Type', media_types.TEXT_HTML) return self.text, None @@ -708,13 +768,10 @@ def modify_http_client_request(self, balanced_request: RequestBuilder) -> None: balanced_request.path = make_url(balanced_request.path, debug_timestamp=int(time.time())) for header_name in ('Authorization', DEBUG_AUTH_HEADER_NAME): - authorization = self.request.headers.get(header_name) + authorization = self.get_header(header_name) if authorization is not None: balanced_request.headers[header_name] = authorization - def group(self, futures: dict) -> Task: - return self.run_task(gather_dict(coro_dict=futures)) - def get_url( self, host: str, @@ -952,15 +1009,54 @@ def _execute_http_client_method( return future -class ErrorHandler(PageHandler, tornado.web.ErrorHandler): - pass +@Depends +async def get_current_handler(request: Request) -> PageHandler: + return request.state.handler -class RedirectHandler(PageHandler, tornado.web.RedirectHandler): - @router.get() - def get_page(self): - tornado.web.RedirectHandler.get(self) +class RequestCancelledMiddleware: + # https://github.com/tiangolo/fastapi/discussions/11360 + def __init__(self, app): + self.app = app 
+ async def __call__(self, scope, receive, send): + if scope["type"] != "http": + await self.app(scope, receive, send) + return -async def get_current_handler(request: Request) -> PageHandler: - return request['handler'] + queue = asyncio.Queue() + + async def message_poller(sentinel, handler_task): + nonlocal queue + while True: + message = await receive() + if message["type"] == "http.disconnect": + handler_task.cancel() + return sentinel + + await queue.put(message) + + sentinel = object() + handler_task = asyncio.create_task(self.app(scope, queue.get, send)) + poller_task = asyncio.create_task(message_poller(sentinel, handler_task)) + poller_task.done() + + try: + return await handler_task + except asyncio.CancelledError: + pass + # handler_logger.info(f'Cancelling request due to client has disconnected') + + +def set_default_headers(request_id): + return { + 'Server': f'Frontik/{frontik_version}', + 'X-Request-Id': request_id, + } + + +def build_error_data(request_id, status_code: int = 500, message='Internal Server Error') -> tuple[int, dict, Any]: + headers = set_default_headers(request_id) + headers['Content-Type'] = media_types.TEXT_HTML + content = f'
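For reference, a page handler under the new routing model would presumably be declared against the module-level routers added in frontik/app.py, following the same pattern as the /version and /status core handlers in this patch. The module path, handler class and URL below are illustrative assumptions only, not part of the diff:

    from frontik import media_types
    from frontik.app import router
    from frontik.handler import PageHandler, get_current_handler


    class HelloPageHandler(PageHandler):
        pass


    # same decorator style as @_core_router.get('/version', cls=PageHandler) above
    @router.get('/hello', cls=HelloPageHandler)
    async def get_hello(self=get_current_handler):
        self.set_header('Content-Type', media_types.APPLICATION_JSON)
        # finish() raises FinishPageSignal, which process_request converts into the response
        self.finish({'hello': 'world'})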