Skip to content

Commit

Permalink
Merge pull request #314 from kalaspuff/websockets
Browse files Browse the repository at this point in the history
Shortcutting for using websockets as a decorator in the service class
  • Loading branch information
kalaspuff authored Jan 27, 2018
2 parents 53ab046 + 1d549b8 commit 8da253f
Show file tree
Hide file tree
Showing 6 changed files with 136 additions and 22 deletions.
7 changes: 4 additions & 3 deletions README.rst
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
tomodachi - a lightweight microservices framework with asyncio
``tomodachi`` - a lightweight microservices framework with asyncio
==============================================================
.. image:: https://travis-ci.org/kalaspuff/tomodachi.svg?branch=master
:target: https://travis-ci.org/kalaspuff/tomodachi
Expand All @@ -10,7 +10,8 @@ tomodachi - a lightweight microservices framework with asyncio
:target: https://pypi.python.org/pypi/tomodachi

Python 3 microservice framework using asyncio (async / await) with HTTP,
RabbitMQ / AMQP and AWS SNS+SQS support for event bus based communication.
websockets, RabbitMQ / AMQP and AWS SNS+SQS support for event bus based
communication.

Tomodachi is a tiny framework designed to build fast microservices listening on
HTTP or communicating over event driven message buses like RabbitMQ, AMQP,
Expand Down Expand Up @@ -94,7 +95,7 @@ Run the service 😎
< Server: tomodachi
< Content-Length: 9
< Date: Mon, 02 Oct 2017 13:38:02 GMT
id = 1234
id = 1234
Requirements 👍
Expand Down
19 changes: 18 additions & 1 deletion examples/http_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
import os
import asyncio
import tomodachi
from typing import Tuple, Callable
from aiohttp import web
from tomodachi.discovery.dummy_registry import DummyRegistry
from tomodachi.protocol.json_base import JsonBase
from tomodachi.transport.http import http, http_error, http_static, Response
from tomodachi.transport.http import http, http_error, http_static, websocket, Response


@tomodachi.service
Expand Down Expand Up @@ -43,6 +44,22 @@ async def static_files(self) -> None:
# This function is actually never called by accessing the /static/ URL:s.
pass

@websocket(r'/websocket/?')
async def websocket_connection(self, websocket: web.WebSocketResponse) -> Tuple[Callable, Callable]:
# Called when a websocket client is connected
self.logger.info('websocket client connected')

async def _receive(data) -> None:
# Called when the websocket receives data
self.logger.info('websocket data received: {}'.format(data))
await websocket.send_str('response')

async def _close() -> None:
# Called when the websocket is closed by the other end
self.logger.info('websocket closed')

return _receive, _close

@http_error(status_code=404)
async def error_404(self, request: web.Request) -> str:
return 'error 404'
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def read(f: str) -> str:
setup(name='tomodachi',
version=tomodachi.__version__,
description=('Python 3 microservice library / framework using asyncio with HTTP, '
'RabbitMQ / AMQP and AWS SNS+SQS support.'),
'websockets, RabbitMQ / AMQP and AWS SNS+SQS support.'),
long_description='\n\n'.join((read('README.rst'), read('CHANGES.rst'))),
classifiers=classifiers,
author='Carl Oscar Aaro',
Expand All @@ -55,7 +55,7 @@ def read(f: str) -> str:
},
install_requires=install_requires,
keywords=('tomodachi, microservice, microservices, framework, library, asyncio, '
'aws, sns, sqs, amqp, rabbitmq, http, easy, fast, python 3'),
'aws, sns, sqs, amqp, rabbitmq, http, websockets, easy, fast, python 3'),
zip_safe=False,
packages=find_packages(),
platforms='any',
Expand Down
17 changes: 15 additions & 2 deletions tests/services/http_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import os
import signal
import tomodachi
from typing import Any, Dict, Tuple # noqa
from typing import Any, Dict, Tuple, Callable # noqa
from aiohttp import web
from tomodachi.transport.http import http, http_error, http_static, Response
from tomodachi.transport.http import http, http_error, http_static, websocket, Response
from tomodachi.discovery.dummy_registry import DummyRegistry


Expand All @@ -22,6 +22,8 @@ class HttpService(object):
uuid = None
closer = asyncio.Future() # type: Any
slow_request = False
websocket_connected = False
websocket_received_data = None

@http('GET', r'/test/?')
async def test(self, request: web.Request) -> str:
Expand Down Expand Up @@ -137,6 +139,17 @@ async def static_files_filename_existing(self) -> None:
async def test_404(self, request: web.Request) -> str:
return 'test 404'

@websocket(r'/websocket-simple')
async def websocket_simple(self, websocket: web.WebSocketResponse) -> None:
self.websocket_connected = True

@websocket(r'/websocket-data')
async def websocket_data(self, websocket: web.WebSocketResponse) -> Callable:
async def _receive(data):
self.websocket_received_data = data

return _receive

async def _started_service(self) -> None:
async def _async() -> None:
async def sleep_and_kill() -> None:
Expand Down
14 changes: 14 additions & 0 deletions tests/test_http_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,20 @@ async def _async(loop: Any) -> None:
assert response is not None
assert response.status == 404

assert instance.websocket_connected is False
async with aiohttp.ClientSession(loop=loop) as client:
async with client.ws_connect('http://127.0.0.1:{}/websocket-simple'.format(port)) as ws:
await ws.close()
assert instance.websocket_connected is True

async with aiohttp.ClientSession(loop=loop) as client:
async with client.ws_connect('http://127.0.0.1:{}/websocket-data'.format(port)) as ws:
data = '9e2546ef-7fe1-4f94-a3fc-5dc85a771a17'
assert instance.websocket_received_data != data
await ws.send_str(data)
await ws.close()
assert instance.websocket_received_data == data

loop.run_until_complete(_async(loop))
instance.stop_service()
loop.run_until_complete(future)
Expand Down
97 changes: 83 additions & 14 deletions tomodachi/transport/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@
import os
import pathlib
import inspect
import uuid
from logging.handlers import WatchedFileHandler
from typing import Any, Dict, List, Tuple, Union, Optional, Callable, SupportsInt # noqa
try:
from typing import Awaitable
except ImportError:
from collections.abc import Awaitable
from multidict import CIMultiDict, CIMultiDictProxy
from aiohttp import web, web_server, web_protocol, web_urldispatcher, hdrs
from aiohttp import web, web_server, web_protocol, web_urldispatcher, hdrs, WSMsgType
from aiohttp.web_fileresponse import FileResponse
from aiohttp.http import HttpVersion
from aiohttp.helpers import BasicAuth
Expand Down Expand Up @@ -318,6 +319,58 @@ async def handler(request: web.Request) -> web.Response:
start_func = cls.start_server(obj, context)
return (await start_func) if start_func else None

async def websocket_handler(cls: Any, obj: Any, context: Dict, func: Any, url: str) -> Any:
pattern = r'^{}$'.format(re.sub(r'\$$', '', re.sub(r'^\^?(.*)$', r'\1', url)))
compiled_pattern = re.compile(pattern)

async def _func(obj: Any, request: web.Request) -> None:
websocket = web.WebSocketResponse()
await websocket.prepare(request)

request.is_websocket = True
request.websocket_uuid = str(uuid.uuid4())

logging.getLogger('transport.http').info('[websocket] {} {} "OPEN {}{}" {} "{}" {}'.format(
request.request_ip,
'"{}"'.format(request.auth.login.replace('"', '')) if request.auth and getattr(request.auth, 'login', None) else '-',
request.path,
'?{}'.format(request.query_string) if request.query_string else '',
request.websocket_uuid,
request.headers.get('User-Agent', '').replace('"', ''),
'-'
))

result = compiled_pattern.match(request.path)
values = inspect.getfullargspec(func)
kwargs = {k: values.defaults[i] for i, k in enumerate(values.args[len(values.args) - len(values.defaults):])} if values.defaults else {}
if result:
for k, v in result.groupdict().items():
kwargs[k] = v

routine = func(*(obj, websocket,), **kwargs)
callback_functions = (await routine) if isinstance(routine, Awaitable) else routine # type: Optional[Union[Tuple, Callable]]
_receive_func = None
_close_func = None

if callback_functions and isinstance(callback_functions, tuple):
try:
_receive_func, _close_func = callback_functions
except ValueError:
_receive_func, = callback_functions
elif callback_functions:
_receive_func = callback_functions

try:
async for message in websocket:
if message.type == WSMsgType.TEXT:
if _receive_func:
await _receive_func(message.data)
finally:
if _close_func:
await _close_func()

return await cls.request_handler(cls, obj, context, _func, 'GET', url)

async def start_server(obj: Any, context: Dict) -> Optional[Callable]:
if context.get('_http_server_started'):
return None
Expand Down Expand Up @@ -365,6 +418,7 @@ async def func() -> web.Response:
request.request_ip = request_ip

request.auth = None
request.is_websocket = False
if request.headers.get('Authorization'):
try:
request.auth = BasicAuth.decode(request.headers.get('Authorization'))
Expand Down Expand Up @@ -407,19 +461,31 @@ async def func() -> web.Response:
version_string = None
if isinstance(request.version, HttpVersion):
version_string = 'HTTP/{}.{}'.format(request.version.major, request.version.minor)
logging.getLogger('transport.http').info('[http] [{}] {} {} "{} {}{}{}" {} {} "{}" {}'.format(
response.status if response else 500,
request.request_ip,
'"{}"'.format(request.auth.login.replace('"', '')) if request.auth and getattr(request.auth, 'login', None) else '-',
request.method,
request.path,
'?{}'.format(request.query_string) if request.query_string else '',
' {}'.format(version_string) if version_string else '',
response.content_length if response and response.content_length is not None else '-',
request.content_length if request.content_length is not None else '-',
request.headers.get('User-Agent', '').replace('"', ''),
'{0:.5f}s'.format(round(request_time, 5))
))

if not request.is_websocket:
logging.getLogger('transport.http').info('[http] [{}] {} {} "{} {}{}{}" {} {} "{}" {}'.format(
response.status if response else 500,
request.request_ip,
'"{}"'.format(request.auth.login.replace('"', '')) if request.auth and getattr(request.auth, 'login', None) else '-',
request.method,
request.path,
'?{}'.format(request.query_string) if request.query_string else '',
' {}'.format(version_string) if version_string else '',
response.content_length if response and response.content_length is not None else '-',
request.content_length if request.content_length is not None else '-',
request.headers.get('User-Agent', '').replace('"', ''),
'{0:.5f}s'.format(round(request_time, 5))
))
else:
logging.getLogger('transport.http').info('[websocket] {} {} "CLOSE {}{}" {} "{}" {}'.format(
request.request_ip,
'"{}"'.format(request.auth.login.replace('"', '')) if request.auth and getattr(request.auth, 'login', None) else '-',
request.path,
'?{}'.format(request.query_string) if request.query_string else '',
request.websocket_uuid,
request.headers.get('User-Agent', '').replace('"', ''),
'{0:.5f}s'.format(round(request_time, 5))
))

return response

Expand Down Expand Up @@ -477,3 +543,6 @@ async def stop_service(*args: Any, **kwargs: Any) -> None:
http = HttpTransport.decorator(HttpTransport.request_handler)
http_error = HttpTransport.decorator(HttpTransport.error_handler)
http_static = HttpTransport.decorator(HttpTransport.static_request_handler)

websocket = HttpTransport.decorator(HttpTransport.websocket_handler)
ws = HttpTransport.decorator(HttpTransport.websocket_handler)

0 comments on commit 8da253f

Please sign in to comment.