Skip to content

Commit

Permalink
Merge pull request #966 from kalaspuff/feature/websocket-request-argu…
Browse files Browse the repository at this point in the history
…ment

Added optional third argument to websocket handler to pass on request object for headers parsing, etc.
  • Loading branch information
kalaspuff authored Aug 22, 2019
2 parents 334b442 + 7bd7046 commit 55a16bb
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 1 deletion.
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ HTTP endpoints:
Sets up an **HTTP endpoint for static content** available as ``GET`` / ``HEAD`` from the ``path`` on disk on the base regexp ``url``.

``@tomodachi.websocket(url)``
Sets up a **websocket endpoint** on the regexp ``url``. The invoked function is called upon websocket connection and should return a two value tuple containing callables for a function receiving frames (first callable) and a function called on websocket close (second callable).
Sets up a **websocket endpoint** on the regexp ``url``. The invoked function is called upon websocket connection and should return a two value tuple containing callables for a function receiving frames (first callable) and a function called on websocket close (second callable). The passed arguments to the function beside the class object is first the ``websocket`` response connection which can be used to send frames to the client, and optionally also the ``request`` object.

``@tomodachi.http_error(status_code)``
A function which will be called if the **HTTP request would result in a 4XX** ``status_code``. You may use this for example to set up a custom handler on "404 Not Found" or "403 Forbidden" responses.
Expand Down
5 changes: 5 additions & 0 deletions tests/services/http_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class HttpService(tomodachi.Service):
function_triggered = False
websocket_connected = False
websocket_received_data = None
websocket_header = None

@http('GET', r'/test/?')
async def test(self, request: web.Request) -> str:
Expand Down Expand Up @@ -166,6 +167,10 @@ async def test_404(self, request: web.Request) -> str:
async def websocket_simple(self, websocket: web.WebSocketResponse) -> None:
self.websocket_connected = True

@websocket(r'/websocket-header')
async def websocket_with_header(self, websocket: web.WebSocketResponse, request: web.Request) -> None:
self.websocket_header = request.headers.get('User-Agent')

@websocket(r'/websocket-data')
async def websocket_data(self, websocket: web.WebSocketResponse) -> Callable:
async def _receive(data: Union[str, bytes]) -> None:
Expand Down
8 changes: 8 additions & 0 deletions tests/test_http_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,14 @@ async def _async(loop: Any) -> None:
await ws.close()
assert instance.websocket_connected is True

assert instance.websocket_header is None
async with aiohttp.ClientSession(loop=loop) as client:
async with client.ws_connect('http://127.0.0.1:{}/websocket-header'.format(port)) as ws:
await ws.close()
assert instance.websocket_header is not None
assert 'Python' in instance.websocket_header
assert 'aiohttp' in instance.websocket_header

async with aiohttp.ClientSession(loop=loop) as client:
async with client.ws_connect('http://127.0.0.1:{}/websocket-data'.format(port)) as ws:
data = '9e2546ef-7fe1-4f94-a3fc-5dc85a771a17'
Expand Down
6 changes: 6 additions & 0 deletions tomodachi/transport/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,12 @@ async def _func(obj: Any, request: web.Request, *a: Any, **kw: Any) -> None:
for k, v in result.groupdict().items():
kwargs[k] = v

if len(values.args) - (len(values.defaults) if values.defaults else 0) >= 3:
# If the function takes a third required argument the value will be filled with the request object
a = a + (request,)
if 'request' in values.args and (len(values.args) - (len(values.defaults) if values.defaults else 0) < 3 or values.args[2] != 'request'):
kwargs['request'] = request

try:
routine = func(*(obj, websocket, *a), **merge_dicts(kwargs, kw))
callback_functions = (await routine) if isinstance(routine, Awaitable) else routine # type: Optional[Union[Tuple, Callable]]
Expand Down

0 comments on commit 55a16bb

Please sign in to comment.