Skip to content

Commit

Permalink
Close websockets on service exit + handle errors on close and receive…
Browse files Browse the repository at this point in the history
… methods
  • Loading branch information
kalaspuff committed Mar 2, 2018
1 parent 1c2307c commit d182f2c
Showing 1 changed file with 64 additions and 4 deletions.
68 changes: 64 additions & 4 deletions tomodachi/transport/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,9 @@ async def _func(obj: Any, request: web.Request) -> None:
websocket = web.WebSocketResponse()
await websocket.prepare(request)

context['_http_open_websockets'] = context.get('_http_open_websockets', [])
context['_http_open_websockets'].append(websocket)

request.is_websocket = True
request.websocket_uuid = str(uuid.uuid4())

Expand All @@ -378,8 +381,36 @@ async def _func(obj: Any, request: web.Request) -> None:
for k, v in result.groupdict().items():
kwargs[k] = v

routine = func(*(obj, websocket,), **kwargs)
callback_functions = (await routine) if isinstance(routine, Awaitable) else routine # type: Optional[Union[Tuple, Callable]]
try:
routine = func(*(obj, websocket,), **kwargs)
callback_functions = (await routine) if isinstance(routine, Awaitable) else routine # type: Optional[Union[Tuple, Callable]]
except Exception as e:
if not context.get('log_level') or context.get('log_level') in ['DEBUG']:
traceback.print_exception(e.__class__, e, e.__traceback__)
try:
await websocket.close()
except:
pass

try:
context['_http_open_websockets'].remove(websocket)
except:
pass

logging.getLogger('transport.http').info('[{}] {} {} "{} {}{}" {} "{}" {}'.format(
RequestHandler.colorize_status('websocket', 500),
request.request_ip,
'"{}"'.format(request.auth.login.replace('"', '')) if request.auth and getattr(request.auth, 'login', None) else '-',
RequestHandler.colorize_status('ERROR', 500),
request.path,
'?{}'.format(request.query_string) if request.query_string else '',
request.websocket_uuid,
request.headers.get('User-Agent', '').replace('"', ''),
'-'
))

return

_receive_func = None
_close_func = None

Expand All @@ -395,10 +426,33 @@ async def _func(obj: Any, request: web.Request) -> None:
async for message in websocket:
if message.type == WSMsgType.TEXT:
if _receive_func:
await _receive_func(message.data)
try:
await _receive_func(message.data)
except Exception as e:
if not context.get('log_level') or context.get('log_level') in ['DEBUG']:
traceback.print_exception(e.__class__, e, e.__traceback__)
elif message.type == WSMsgType.ERROR:
e = ws.exception()
if not context.get('log_level') or context.get('log_level') in ['DEBUG']:
traceback.print_exception(e.__class__, e, e.__traceback__)
elif message.type == WSMsgType.CLOSED:
break # noqa
finally:
if _close_func:
await _close_func()
try:
await _close_func()
except Exception as e:
if not context.get('log_level') or context.get('log_level') in ['DEBUG']:
traceback.print_exception(e.__class__, e, e.__traceback__)
try:
await websocket.close()
except:
pass

try:
context['_http_open_websockets'].remove(websocket)
except:
pass

return await cls.request_handler(cls, obj, context, _func, 'GET', url)

Expand Down Expand Up @@ -550,6 +604,12 @@ async def func() -> web.Response:
async def stop_service(*args: Any, **kwargs: Any) -> None:
if stop_method:
await stop_method(*args, **kwargs)
open_websockets = context.get('_http_open_websockets', [])[:]
for websocket in open_websockets:
try:
await websocket.close()
except:
pass
server.close()
await app.shutdown()
if logger_handler:
Expand Down

0 comments on commit d182f2c

Please sign in to comment.