Skip to content

Commit

Permalink
Refactored middleware functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
kalaspuff committed Mar 7, 2019
1 parent aad4ae5 commit cb9d155
Show file tree
Hide file tree
Showing 8 changed files with 40 additions and 73 deletions.
2 changes: 1 addition & 1 deletion examples/basic_examples/amqp_middleware_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from tomodachi.protocol import JsonBase


async def middleware_function(func: Callable, service: Any, message: Any, routing_key: str, *args: Any, **kwargs: Any) -> Any:
async def middleware_function(func: Callable, service: Any, message: Any, routing_key: str, context: Dict, *args: Any, **kwargs: Any) -> Any:
# Functionality before function is called
service.log('middleware before')

Expand Down
2 changes: 1 addition & 1 deletion examples/basic_examples/aws_sns_sqs_middleware_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from tomodachi.protocol import JsonBase


async def middleware_function(func: Callable, service: Any, message: Any, topic: str, *args: Any, **kwargs: Any) -> Any:
async def middleware_function(func: Callable, service: Any, message: Any, topic: str, context: Dict, *args: Any, **kwargs: Any) -> Any:
# Functionality before function is called
service.log('middleware before')

Expand Down
4 changes: 2 additions & 2 deletions examples/basic_examples/http_middleware_service.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import os
import asyncio
import tomodachi
from typing import Tuple, Callable, Union, Any
from typing import Tuple, Callable, Union, Any, Dict
from aiohttp import web
from tomodachi import http, http_error, http_static, websocket, HttpResponse
from tomodachi.discovery import DummyRegistry


async def middleware_function(func: Callable, service: Any, request: web.Request, *args: Any, **kwargs: Any) -> Any:
async def middleware_function(func: Callable, service: Any, request: web.Request, context: Dict, *args: Any, **kwargs: Any) -> Any:
# Functionality before function is called
service.log('middleware before')

Expand Down
2 changes: 1 addition & 1 deletion tests/services/http_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from tomodachi.discovery.dummy_registry import DummyRegistry


async def middleware_function(func: Callable, service: Any, request: web.Request) -> Any:
async def middleware_function(func: Callable, service: Any, request: web.Request, context: Dict, *args: Any, **kwargs: Any) -> Any:
if request.headers.get('X-Use-Middleware') == 'Set':
service.middleware_called = True

Expand Down
28 changes: 28 additions & 0 deletions tomodachi/helpers/middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import functools
import inspect
from typing import Callable, List, Any


async def execute_middlewares(func: Callable, routine_func: Callable, middlewares: List, *args: Any) -> Any:
if middlewares:
middleware_context = {}
async def middleware_bubble(idx: int = 0, *ma: Any, **mkw: Any) -> Any:
@functools.wraps(func)
async def _func(*a: Any, **kw: Any) -> Any:
return await middleware_bubble(idx + 1, *a, **kw)

if middlewares and len(middlewares) <= idx + 1:
_func = routine_func

middleware = middlewares[idx] # type: Callable

arg_len = len(inspect.getfullargspec(middleware).args) - (len(inspect.getfullargspec(middleware).defaults) if inspect.getfullargspec(middleware).defaults else 0)
middleware_arguments = [_func, *args, middleware_context][0:arg_len]

return await middleware(*middleware_arguments, *ma, **mkw)

return_value = await middleware_bubble()
else:
return_value = await routine_func()

return return_value
19 changes: 2 additions & 17 deletions tomodachi/transport/amqp.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import Any, Dict, Union, Optional, Callable, Match, Awaitable, List
from tomodachi.invoker import Invoker
from tomodachi.helpers.dict import merge_dicts
from tomodachi.helpers.middleware import execute_middlewares

MESSAGE_PROTOCOL_DEFAULT = '2594418c-5771-454a-a7f9-8f83ae82812a'

Expand Down Expand Up @@ -222,23 +223,7 @@ async def routine_func(*a: Any, **kw: Any) -> Any:
await cls.channel.basic_client_ack(delivery_tag)
return return_value

middlewares = context.get('message_middleware', []) # type: List[Callable]
if middlewares:
async def middleware_bubble(idx: int = 0, *ma: Any, **mkw: Any) -> Any:
@functools.wraps(func)
async def _func(*a: Any, **kw: Any) -> Any:
return await middleware_bubble(idx + 1, *a, **kw)

if middlewares and len(middlewares) <= idx + 1:
_func = routine_func

middleware = middlewares[idx] # type: Callable
return await middleware(_func, obj, message, routing_key, *ma, **mkw)

return_value = await middleware_bubble()
else:
return_value = await routine_func()

return_value = await execute_middlewares(func, routine_func, context.get('message_middleware', []), *(obj, message, routing_key))
return return_value

exchange_name = exchange_name or context.get('options', {}).get('amqp', {}).get('exchange_name', 'amq.topic')
Expand Down
19 changes: 2 additions & 17 deletions tomodachi/transport/aws_sns_sqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import Any, Dict, Union, Optional, Callable, List, Tuple, Match, Awaitable
from tomodachi.invoker import Invoker
from tomodachi.helpers.dict import merge_dicts
from tomodachi.helpers.middleware import execute_middlewares

DRAIN_MESSAGE_PAYLOAD = '__TOMODACHI_DRAIN__cdab4416-1727-4603-87c9-0ff8dddf1f22__'
MESSAGE_PROTOCOL_DEFAULT = 'e6fb6007-cf15-4cfd-af2e-1d1683374e70'
Expand Down Expand Up @@ -216,23 +217,7 @@ async def routine_func(*a: Any, **kw: Any) -> Any:
await cls.delete_message(cls, receipt_handle, queue_url, context)
return return_value

middlewares = context.get('message_middleware', []) # type: List[Callable]
if middlewares:
async def middleware_bubble(idx: int = 0, *ma: Any, **mkw: Any) -> Any:
@functools.wraps(func)
async def _func(*a: Any, **kw: Any) -> Any:
return await middleware_bubble(idx + 1, *a, **kw)

if middlewares and len(middlewares) <= idx + 1:
_func = routine_func

middleware = middlewares[idx] # type: Callable
return await middleware(_func, obj, message, message_topic, *ma, **mkw)

return_value = await middleware_bubble()
else:
return_value = await routine_func()

return_value = await execute_middlewares(func, routine_func, context.get('message_middleware', []), *(obj, message, topic))
return return_value

context['_aws_sns_sqs_subscribers'] = context.get('_aws_sns_sqs_subscribers', [])
Expand Down
37 changes: 3 additions & 34 deletions tomodachi/transport/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from aiohttp.streams import EofStream
from tomodachi.invoker import Invoker
from tomodachi.helpers.dict import merge_dicts
from tomodachi.helpers.middleware import execute_middlewares


class HttpException(Exception):
Expand Down Expand Up @@ -246,23 +247,7 @@ async def routine_func(*a: Any, **kw: Any) -> Union[str, bytes, Dict, List, Tupl
if pre_handler_func:
await pre_handler_func(obj, request)

middlewares = context.get('http_middleware', []) # type: List[Callable]
if middlewares:
async def middleware_bubble(idx: int = 0, *ma: Any, **mkw: Any) -> Any:
@functools.wraps(func)
async def _func(*a: Any, **kw: Any) -> Any:
return await middleware_bubble(idx + 1, *a, **kw)

if middlewares and len(middlewares) <= idx + 1:
_func = routine_func

middleware = middlewares[idx] # type: Callable
return await middleware(_func, obj, request, *ma, **mkw)

return_value = await middleware_bubble()
else:
return_value = await routine_func()

return_value = await execute_middlewares(func, routine_func, context.get('http_middleware', []), *(obj, request))
response = await resolve_response(return_value, request=request, context=context, default_content_type=default_content_type, default_charset=default_charset)
return response

Expand Down Expand Up @@ -340,23 +325,7 @@ async def routine_func(*a: Any, **kw: Any) -> Union[str, bytes, Dict, List, Tupl
return_value = (await routine) if isinstance(routine, Awaitable) else routine # type: Union[str, bytes, Dict, List, Tuple, web.Response, Response]
return return_value

middlewares = context.get('http_middleware', []) # type: List[Callable]
if int(status_code) in (404,) and middlewares:
async def middleware_bubble(idx: int = 0, *ma: Any, **mkw: Any) -> Any:
@functools.wraps(func)
async def _func(*a: Any, **kw: Any) -> Any:
return await middleware_bubble(idx + 1, *a, **kw)

if middlewares and len(middlewares) <= idx + 1:
_func = routine_func

middleware = middlewares[idx] # type: Callable
return await middleware(_func, obj, request, *ma, **mkw)

return_value = await middleware_bubble()
else:
return_value = await routine_func()

return_value = await execute_middlewares(func, routine_func, context.get('http_middleware', []), *(obj, request))
response = await resolve_response(return_value, request=request, context=context, status_code=status_code, default_content_type=default_content_type, default_charset=default_charset)
return response

Expand Down

0 comments on commit cb9d155

Please sign in to comment.