Skip to content

Commit

Permalink
HH-198875 add current_handler dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
bokshitsky committed Nov 19, 2023
1 parent bf5a038 commit 706af53
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 76 deletions.
6 changes: 1 addition & 5 deletions frontik/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,12 +187,8 @@ def find_handler(self, request, **kwargs):

def wrapped_in_context(func: Callable) -> Callable:
def wrapper(*args, **kwargs):
token = request_context.initialize(request, request_id)

try:
with request_context.request_context(request, request_id):
return func(*args, **kwargs)
finally:
request_context.reset(token)

return wrapper

Expand Down
10 changes: 1 addition & 9 deletions frontik/dependency_manager/graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,16 +178,8 @@ def _register_dependency_params(
sub_dependency = _make_dependency_for_graph(graph, param.default.func, deep_scan)
if deep_scan:
_register_sub_dependency(graph, dependency, sub_dependency, add_to_args)
continue

elif issubclass(graph.handler_cls, param.annotation):
sub_dependency = _make_dependency_for_graph(graph, get_handler, deep_scan)
graph.special_deps.add(sub_dependency)
if deep_scan:
_register_sub_dependency(graph, dependency, sub_dependency, add_to_args)
continue

elif param_name == 'self':
elif issubclass(graph.handler_cls, param.annotation) or param_name == 'self':
sub_dependency = _make_dependency_for_graph(graph, get_handler, deep_scan)
graph.special_deps.add(sub_dependency)
if deep_scan:
Expand Down
7 changes: 2 additions & 5 deletions frontik/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def add_future(cls, future: Future, callback: Callable) -> None:
# Requests handling

async def _execute(self, transforms, *args, **kwargs):
request_context.set_handler_name(repr(self))
request_context.set_handler(self)
try:
return await super()._execute(transforms, *args, **kwargs)
except Exception as ex:
Expand Down Expand Up @@ -476,8 +476,7 @@ async def _postprocess(self) -> Any:
return postprocessed_result

def on_connection_close(self):
token = request_context.initialize(self.request, self.request_id)
try:
with request_context.request_context(self.request, self.request_id):
super().on_connection_close()

self.finish_group.abort()
Expand All @@ -487,8 +486,6 @@ def on_connection_close(self):
self.stages_logger.flush_stages(self.get_status())

self.finish()
finally:
request_context.reset(token)

def on_finish(self):
self.stages_logger.commit_stage('flush')
Expand Down
33 changes: 9 additions & 24 deletions frontik/loggers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import socket
import time
from functools import cache
from logging import Filter, Formatter, Handler
from logging.handlers import SysLogHandler
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -81,14 +82,10 @@ def format(self, record):
if custom_json:
json_message.update(custom_json)
else:
json_message.update(
{
'lvl': record.levelname,
'logger': record.name,
'mdc': mdc,
'msg': message,
},
)
json_message['lvl'] = record.levelname
json_message['logger'] = record.name
json_message['mdc'] = mdc
json_message['msg'] = message

if stack_trace:
json_message['exception'] = stack_trace
Expand Down Expand Up @@ -141,26 +138,14 @@ def format(self, record):
return super().format(record)


_STDERR_FORMATTER = None
_TEXT_FORMATTER = None


@cache
def get_stderr_formatter() -> StderrFormatter:
global _STDERR_FORMATTER

if _STDERR_FORMATTER is None:
_STDERR_FORMATTER = StderrFormatter(fmt=options.stderr_format, datefmt=options.stderr_dateformat)

return _STDERR_FORMATTER
return StderrFormatter(fmt=options.stderr_format, datefmt=options.stderr_dateformat)


@cache
def get_text_formatter() -> Formatter:
global _TEXT_FORMATTER

if _TEXT_FORMATTER is None:
_TEXT_FORMATTER = Formatter(options.log_text_format)

return _TEXT_FORMATTER
return Formatter(options.log_text_format)


def bootstrap_logger(
Expand Down
41 changes: 28 additions & 13 deletions frontik/request_context.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,38 @@
from __future__ import annotations

import contextvars
from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from collections.abc import Iterator

from tornado.httputil import HTTPServerRequest

from frontik.debug import DebugBufferedHandler
from frontik.handler import PageHandler


@dataclass(slots=True)
class _Context:
__slots__ = ('request', 'request_id', 'handler_name', 'log_handler')

def __init__(self, request: HTTPServerRequest | None, request_id: str | None) -> None:
self.request = request
self.request_id = request_id
self.handler_name: str | None = None
self.log_handler: DebugBufferedHandler | None = None
request: HTTPServerRequest | None
request_id: str | None
handler: PageHandler | None = None
handler_name: str | None = None
log_handler: DebugBufferedHandler | None = None


_context = contextvars.ContextVar('context', default=_Context(None, None))


def initialize(request: HTTPServerRequest, request_id: str) -> contextvars.Token:
return _context.set(_Context(request, request_id))


def reset(token: contextvars.Token) -> None:
_context.reset(token)
@contextmanager
def request_context(request: HTTPServerRequest, request_id: str) -> Iterator:
token = _context.set(_Context(request, request_id))
try:
yield
finally:
_context.reset(token)


def get_request():
Expand All @@ -46,6 +51,16 @@ def set_handler_name(handler_name: str) -> None:
_context.get().handler_name = handler_name


def set_handler(handler: PageHandler) -> None:
context = _context.get()
context.handler_name = repr(handler)
context.handler = handler


def current_handler() -> PageHandler:
return _context.get().handler # type: ignore[return-value]


def get_log_handler() -> DebugBufferedHandler | None:
return _context.get().log_handler

Expand Down
30 changes: 30 additions & 0 deletions tests/test_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@

import pytest

from frontik.app import FrontikApplication
from frontik.dependency_manager import async_deps, build_and_run_sub_graph, dep, execute_page_method_with_dependencies
from frontik.handler import PageHandler
from frontik.request_context import current_handler
from frontik.testing import FrontikTestBase
from tests import FRONTIK_ROOT


class TestPageHandler(PageHandler):
Expand Down Expand Up @@ -187,3 +191,29 @@ async def test_async_deps(self):
handler = AsyncDependencyHandler()
await execute_page_method_with_dependencies(handler, handler.get_page)
assert ['get_page', 'get_session', 'check_session'] == DEP_LOG


def depends_on_current(handler: PageHandler = dep(current_handler)) -> None:
handler.json.put({"from_dep": 1})


class HandlerWithDependencyWhichDependsOnCurrentHandler(PageHandler):
def get_page(self, handler=dep(depends_on_current)):
self.json.put({"from_handler": 2})


class TestDepApplication(FrontikTestBase):
class TestApp(FrontikApplication):
def application_urls(self) -> list[tuple]:
return [
('/inject_current', HandlerWithDependencyWhichDependsOnCurrentHandler),
]

@pytest.fixture(scope='class')
def test_app(self):
return self.TestApp(app='test_app', app_root=FRONTIK_ROOT)

async def test_dependency_uses_current_handler(self):
response = await self.fetch('/inject_current')
assert response.data["from_dep"] == 1
assert response.data["from_handler"] == 2
46 changes: 26 additions & 20 deletions tests/test_telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,36 +30,42 @@ def test_generate_trace_id_with_none_request_id(self) -> None:
self.assertIsNotNone(trace_id)

def test_generate_trace_id_with_hex_request_id(self) -> None:
request_context.initialize(HTTPServerRequest(), '163897206709842601f90a070699ac44')
trace_id = self.trace_id_generator.generate_trace_id()
self.assertEqual('0x163897206709842601f90a070699ac44', hex(trace_id))
with request_context.request_context(HTTPServerRequest(), '163897206709842601f90a070699ac44'):
trace_id = self.trace_id_generator.generate_trace_id()
self.assertEqual('0x163897206709842601f90a070699ac44', hex(trace_id))

def test_generate_trace_id_with_no_hex_request_id(self) -> None:
request_context.initialize(HTTPServerRequest(), 'non-hex-string-1234')
trace_id = self.trace_id_generator.generate_trace_id()
self.assertIsNotNone(trace_id)
with request_context.request_context(HTTPServerRequest(), 'non-hex-string-1234'):
trace_id = self.trace_id_generator.generate_trace_id()
self.assertIsNotNone(trace_id)

def test_generate_trace_id_with_no_str_request_id(self) -> None:
request_context.initialize(HTTPServerRequest(), 12345678910) # type: ignore
trace_id = self.trace_id_generator.generate_trace_id()
self.assertIsNotNone(trace_id)
with request_context.request_context(HTTPServerRequest(), 12345678910): # type: ignore
trace_id = self.trace_id_generator.generate_trace_id()
self.assertIsNotNone(trace_id)

def test_generate_trace_id_with_hex_request_id_and_postfix(self) -> None:
request_context.initialize(HTTPServerRequest(), '163897206709842601f90a070699ac44_some_postfix_string')
trace_id = self.trace_id_generator.generate_trace_id()
self.assertEqual('0x163897206709842601f90a070699ac44', hex(trace_id))
with request_context.request_context(
HTTPServerRequest(),
'163897206709842601f90a070699ac44_some_postfix_string',
):
trace_id = self.trace_id_generator.generate_trace_id()
self.assertEqual('0x163897206709842601f90a070699ac44', hex(trace_id))

def test_generate_trace_id_with_no_hex_request_id_in_first_32_characters(self) -> None:
request_context.initialize(HTTPServerRequest(), '16389720670_NOT_HEX_9842601f90a070699ac44_some_postfix_string')
trace_id = self.trace_id_generator.generate_trace_id()
self.assertIsNotNone(trace_id)
self.assertNotEqual('0x16389720670_NOT_HEX_9842601f90a0', hex(trace_id))
with request_context.request_context(
HTTPServerRequest(),
'16389720670_NOT_HEX_9842601f90a070699ac44_some_postfix_string',
):
trace_id = self.trace_id_generator.generate_trace_id()
self.assertIsNotNone(trace_id)
self.assertNotEqual('0x16389720670_NOT_HEX_9842601f90a0', hex(trace_id))

def test_generate_trace_id_with_request_id_len_less_32_characters(self) -> None:
request_context.initialize(HTTPServerRequest(), '163897206')
trace_id = self.trace_id_generator.generate_trace_id()
self.assertIsNotNone(trace_id)
self.assertNotEqual('0x163897206', hex(trace_id))
with request_context.request_context(HTTPServerRequest(), '163897206'):
trace_id = self.trace_id_generator.generate_trace_id()
self.assertIsNotNone(trace_id)
self.assertNotEqual('0x163897206', hex(trace_id))

def test_get_netloc(self) -> None:
self.assertEqual('balancer:7000', get_netloc('balancer:7000/xml/get-article/'))
Expand Down

0 comments on commit 706af53

Please sign in to comment.