From 1d549b8ccb5e797b9bcecd09cc24d558a4c799a8 Mon Sep 17 00:00:00 2001 From: Carl Oscar Aaro Date: Fri, 26 Jan 2018 18:58:13 +0100 Subject: [PATCH] Shortcutting for using websockets as a decorator in the service class --- README.rst | 7 +-- examples/http_service.py | 19 ++++++- setup.py | 4 +- tests/services/http_service.py | 17 +++++- tests/test_http_service.py | 14 +++++ tomodachi/transport/http.py | 97 +++++++++++++++++++++++++++++----- 6 files changed, 136 insertions(+), 22 deletions(-) diff --git a/README.rst b/README.rst index 52df2a81a..7ff832c00 100644 --- a/README.rst +++ b/README.rst @@ -1,4 +1,4 @@ -tomodachi - a lightweight microservices framework with asyncio +``tomodachi`` - a lightweight microservices framework with asyncio ============================================================== .. image:: https://travis-ci.org/kalaspuff/tomodachi.svg?branch=master :target: https://travis-ci.org/kalaspuff/tomodachi @@ -10,7 +10,8 @@ tomodachi - a lightweight microservices framework with asyncio :target: https://pypi.python.org/pypi/tomodachi Python 3 microservice framework using asyncio (async / await) with HTTP, -RabbitMQ / AMQP and AWS SNS+SQS support for event bus based communication. +websockets, RabbitMQ / AMQP and AWS SNS+SQS support for event bus based +communication. Tomodachi is a tiny framework designed to build fast microservices listening on HTTP or communicating over event driven message buses like RabbitMQ, AMQP, @@ -94,7 +95,7 @@ Run the service 😎 < Server: tomodachi < Content-Length: 9 < Date: Mon, 02 Oct 2017 13:38:02 GMT - id = 1234 + id = 1234 Requirements 👍 diff --git a/examples/http_service.py b/examples/http_service.py index 0fdae0a8e..7dc33d903 100644 --- a/examples/http_service.py +++ b/examples/http_service.py @@ -2,10 +2,11 @@ import os import asyncio import tomodachi +from typing import Tuple, Callable from aiohttp import web from tomodachi.discovery.dummy_registry import DummyRegistry from tomodachi.protocol.json_base import JsonBase -from tomodachi.transport.http import http, http_error, http_static, Response +from tomodachi.transport.http import http, http_error, http_static, websocket, Response @tomodachi.service @@ -43,6 +44,22 @@ async def static_files(self) -> None: # This function is actually never called by accessing the /static/ URL:s. pass + @websocket(r'/websocket/?') + async def websocket_connection(self, websocket: web.WebSocketResponse) -> Tuple[Callable, Callable]: + # Called when a websocket client is connected + self.logger.info('websocket client connected') + + async def _receive(data) -> None: + # Called when the websocket receives data + self.logger.info('websocket data received: {}'.format(data)) + await websocket.send_str('response') + + async def _close() -> None: + # Called when the websocket is closed by the other end + self.logger.info('websocket closed') + + return _receive, _close + @http_error(status_code=404) async def error_404(self, request: web.Request) -> str: return 'error 404' diff --git a/setup.py b/setup.py index e14988c3c..6221ffc3c 100644 --- a/setup.py +++ b/setup.py @@ -40,7 +40,7 @@ def read(f: str) -> str: setup(name='tomodachi', version=tomodachi.__version__, description=('Python 3 microservice library / framework using asyncio with HTTP, ' - 'RabbitMQ / AMQP and AWS SNS+SQS support.'), + 'websockets, RabbitMQ / AMQP and AWS SNS+SQS support.'), long_description='\n\n'.join((read('README.rst'), read('CHANGES.rst'))), classifiers=classifiers, author='Carl Oscar Aaro', @@ -55,7 +55,7 @@ def read(f: str) -> str: }, install_requires=install_requires, keywords=('tomodachi, microservice, microservices, framework, library, asyncio, ' - 'aws, sns, sqs, amqp, rabbitmq, http, easy, fast, python 3'), + 'aws, sns, sqs, amqp, rabbitmq, http, websockets, easy, fast, python 3'), zip_safe=False, packages=find_packages(), platforms='any', diff --git a/tests/services/http_service.py b/tests/services/http_service.py index 935b55fc9..96bb611cc 100644 --- a/tests/services/http_service.py +++ b/tests/services/http_service.py @@ -2,9 +2,9 @@ import os import signal import tomodachi -from typing import Any, Dict, Tuple # noqa +from typing import Any, Dict, Tuple, Callable # noqa from aiohttp import web -from tomodachi.transport.http import http, http_error, http_static, Response +from tomodachi.transport.http import http, http_error, http_static, websocket, Response from tomodachi.discovery.dummy_registry import DummyRegistry @@ -22,6 +22,8 @@ class HttpService(object): uuid = None closer = asyncio.Future() # type: Any slow_request = False + websocket_connected = False + websocket_received_data = None @http('GET', r'/test/?') async def test(self, request: web.Request) -> str: @@ -137,6 +139,17 @@ async def static_files_filename_existing(self) -> None: async def test_404(self, request: web.Request) -> str: return 'test 404' + @websocket(r'/websocket-simple') + async def websocket_simple(self, websocket: web.WebSocketResponse) -> None: + self.websocket_connected = True + + @websocket(r'/websocket-data') + async def websocket_data(self, websocket: web.WebSocketResponse) -> Callable: + async def _receive(data): + self.websocket_received_data = data + + return _receive + async def _started_service(self) -> None: async def _async() -> None: async def sleep_and_kill() -> None: diff --git a/tests/test_http_service.py b/tests/test_http_service.py index a57795d78..a17d5f956 100644 --- a/tests/test_http_service.py +++ b/tests/test_http_service.py @@ -256,6 +256,20 @@ async def _async(loop: Any) -> None: assert response is not None assert response.status == 404 + assert instance.websocket_connected is False + async with aiohttp.ClientSession(loop=loop) as client: + async with client.ws_connect('http://127.0.0.1:{}/websocket-simple'.format(port)) as ws: + await ws.close() + assert instance.websocket_connected is True + + async with aiohttp.ClientSession(loop=loop) as client: + async with client.ws_connect('http://127.0.0.1:{}/websocket-data'.format(port)) as ws: + data = '9e2546ef-7fe1-4f94-a3fc-5dc85a771a17' + assert instance.websocket_received_data != data + await ws.send_str(data) + await ws.close() + assert instance.websocket_received_data == data + loop.run_until_complete(_async(loop)) instance.stop_service() loop.run_until_complete(future) diff --git a/tomodachi/transport/http.py b/tomodachi/transport/http.py index 5f0cad8f8..ccda25d83 100644 --- a/tomodachi/transport/http.py +++ b/tomodachi/transport/http.py @@ -7,6 +7,7 @@ import os import pathlib import inspect +import uuid from logging.handlers import WatchedFileHandler from typing import Any, Dict, List, Tuple, Union, Optional, Callable, SupportsInt # noqa try: @@ -14,7 +15,7 @@ except ImportError: from collections.abc import Awaitable from multidict import CIMultiDict, CIMultiDictProxy -from aiohttp import web, web_server, web_protocol, web_urldispatcher, hdrs +from aiohttp import web, web_server, web_protocol, web_urldispatcher, hdrs, WSMsgType from aiohttp.web_fileresponse import FileResponse from aiohttp.http import HttpVersion from aiohttp.helpers import BasicAuth @@ -318,6 +319,58 @@ async def handler(request: web.Request) -> web.Response: start_func = cls.start_server(obj, context) return (await start_func) if start_func else None + async def websocket_handler(cls: Any, obj: Any, context: Dict, func: Any, url: str) -> Any: + pattern = r'^{}$'.format(re.sub(r'\$$', '', re.sub(r'^\^?(.*)$', r'\1', url))) + compiled_pattern = re.compile(pattern) + + async def _func(obj: Any, request: web.Request) -> None: + websocket = web.WebSocketResponse() + await websocket.prepare(request) + + request.is_websocket = True + request.websocket_uuid = str(uuid.uuid4()) + + logging.getLogger('transport.http').info('[websocket] {} {} "OPEN {}{}" {} "{}" {}'.format( + request.request_ip, + '"{}"'.format(request.auth.login.replace('"', '')) if request.auth and getattr(request.auth, 'login', None) else '-', + request.path, + '?{}'.format(request.query_string) if request.query_string else '', + request.websocket_uuid, + request.headers.get('User-Agent', '').replace('"', ''), + '-' + )) + + result = compiled_pattern.match(request.path) + values = inspect.getfullargspec(func) + kwargs = {k: values.defaults[i] for i, k in enumerate(values.args[len(values.args) - len(values.defaults):])} if values.defaults else {} + if result: + for k, v in result.groupdict().items(): + kwargs[k] = v + + routine = func(*(obj, websocket,), **kwargs) + callback_functions = (await routine) if isinstance(routine, Awaitable) else routine # type: Optional[Union[Tuple, Callable]] + _receive_func = None + _close_func = None + + if callback_functions and isinstance(callback_functions, tuple): + try: + _receive_func, _close_func = callback_functions + except ValueError: + _receive_func, = callback_functions + elif callback_functions: + _receive_func = callback_functions + + try: + async for message in websocket: + if message.type == WSMsgType.TEXT: + if _receive_func: + await _receive_func(message.data) + finally: + if _close_func: + await _close_func() + + return await cls.request_handler(cls, obj, context, _func, 'GET', url) + async def start_server(obj: Any, context: Dict) -> Optional[Callable]: if context.get('_http_server_started'): return None @@ -365,6 +418,7 @@ async def func() -> web.Response: request.request_ip = request_ip request.auth = None + request.is_websocket = False if request.headers.get('Authorization'): try: request.auth = BasicAuth.decode(request.headers.get('Authorization')) @@ -407,19 +461,31 @@ async def func() -> web.Response: version_string = None if isinstance(request.version, HttpVersion): version_string = 'HTTP/{}.{}'.format(request.version.major, request.version.minor) - logging.getLogger('transport.http').info('[http] [{}] {} {} "{} {}{}{}" {} {} "{}" {}'.format( - response.status if response else 500, - request.request_ip, - '"{}"'.format(request.auth.login.replace('"', '')) if request.auth and getattr(request.auth, 'login', None) else '-', - request.method, - request.path, - '?{}'.format(request.query_string) if request.query_string else '', - ' {}'.format(version_string) if version_string else '', - response.content_length if response and response.content_length is not None else '-', - request.content_length if request.content_length is not None else '-', - request.headers.get('User-Agent', '').replace('"', ''), - '{0:.5f}s'.format(round(request_time, 5)) - )) + + if not request.is_websocket: + logging.getLogger('transport.http').info('[http] [{}] {} {} "{} {}{}{}" {} {} "{}" {}'.format( + response.status if response else 500, + request.request_ip, + '"{}"'.format(request.auth.login.replace('"', '')) if request.auth and getattr(request.auth, 'login', None) else '-', + request.method, + request.path, + '?{}'.format(request.query_string) if request.query_string else '', + ' {}'.format(version_string) if version_string else '', + response.content_length if response and response.content_length is not None else '-', + request.content_length if request.content_length is not None else '-', + request.headers.get('User-Agent', '').replace('"', ''), + '{0:.5f}s'.format(round(request_time, 5)) + )) + else: + logging.getLogger('transport.http').info('[websocket] {} {} "CLOSE {}{}" {} "{}" {}'.format( + request.request_ip, + '"{}"'.format(request.auth.login.replace('"', '')) if request.auth and getattr(request.auth, 'login', None) else '-', + request.path, + '?{}'.format(request.query_string) if request.query_string else '', + request.websocket_uuid, + request.headers.get('User-Agent', '').replace('"', ''), + '{0:.5f}s'.format(round(request_time, 5)) + )) return response @@ -477,3 +543,6 @@ async def stop_service(*args: Any, **kwargs: Any) -> None: http = HttpTransport.decorator(HttpTransport.request_handler) http_error = HttpTransport.decorator(HttpTransport.error_handler) http_static = HttpTransport.decorator(HttpTransport.static_request_handler) + +websocket = HttpTransport.decorator(HttpTransport.websocket_handler) +ws = HttpTransport.decorator(HttpTransport.websocket_handler)