From d182f2c55bdef4625078ffa31ec57ee95fea7134 Mon Sep 17 00:00:00 2001 From: Carl Oscar Aaro Date: Fri, 2 Mar 2018 13:57:22 +0100 Subject: [PATCH] Close websockets on service exit + handle errors on close and receive methods --- tomodachi/transport/http.py | 68 ++++++++++++++++++++++++++++++++++--- 1 file changed, 64 insertions(+), 4 deletions(-) diff --git a/tomodachi/transport/http.py b/tomodachi/transport/http.py index 1c7f9dc43..3717ec250 100644 --- a/tomodachi/transport/http.py +++ b/tomodachi/transport/http.py @@ -357,6 +357,9 @@ async def _func(obj: Any, request: web.Request) -> None: websocket = web.WebSocketResponse() await websocket.prepare(request) + context['_http_open_websockets'] = context.get('_http_open_websockets', []) + context['_http_open_websockets'].append(websocket) + request.is_websocket = True request.websocket_uuid = str(uuid.uuid4()) @@ -378,8 +381,36 @@ async def _func(obj: Any, request: web.Request) -> None: for k, v in result.groupdict().items(): kwargs[k] = v - routine = func(*(obj, websocket,), **kwargs) - callback_functions = (await routine) if isinstance(routine, Awaitable) else routine # type: Optional[Union[Tuple, Callable]] + try: + routine = func(*(obj, websocket,), **kwargs) + callback_functions = (await routine) if isinstance(routine, Awaitable) else routine # type: Optional[Union[Tuple, Callable]] + except Exception as e: + if not context.get('log_level') or context.get('log_level') in ['DEBUG']: + traceback.print_exception(e.__class__, e, e.__traceback__) + try: + await websocket.close() + except: + pass + + try: + context['_http_open_websockets'].remove(websocket) + except: + pass + + logging.getLogger('transport.http').info('[{}] {} {} "{} {}{}" {} "{}" {}'.format( + RequestHandler.colorize_status('websocket', 500), + request.request_ip, + '"{}"'.format(request.auth.login.replace('"', '')) if request.auth and getattr(request.auth, 'login', None) else '-', + RequestHandler.colorize_status('ERROR', 500), + request.path, + '?{}'.format(request.query_string) if request.query_string else '', + request.websocket_uuid, + request.headers.get('User-Agent', '').replace('"', ''), + '-' + )) + + return + _receive_func = None _close_func = None @@ -395,10 +426,33 @@ async def _func(obj: Any, request: web.Request) -> None: async for message in websocket: if message.type == WSMsgType.TEXT: if _receive_func: - await _receive_func(message.data) + try: + await _receive_func(message.data) + except Exception as e: + if not context.get('log_level') or context.get('log_level') in ['DEBUG']: + traceback.print_exception(e.__class__, e, e.__traceback__) + elif message.type == WSMsgType.ERROR: + e = ws.exception() + if not context.get('log_level') or context.get('log_level') in ['DEBUG']: + traceback.print_exception(e.__class__, e, e.__traceback__) + elif message.type == WSMsgType.CLOSED: + break # noqa finally: if _close_func: - await _close_func() + try: + await _close_func() + except Exception as e: + if not context.get('log_level') or context.get('log_level') in ['DEBUG']: + traceback.print_exception(e.__class__, e, e.__traceback__) + try: + await websocket.close() + except: + pass + + try: + context['_http_open_websockets'].remove(websocket) + except: + pass return await cls.request_handler(cls, obj, context, _func, 'GET', url) @@ -550,6 +604,12 @@ async def func() -> web.Response: async def stop_service(*args: Any, **kwargs: Any) -> None: if stop_method: await stop_method(*args, **kwargs) + open_websockets = context.get('_http_open_websockets', [])[:] + for websocket in open_websockets: + try: + await websocket.close() + except: + pass server.close() await app.shutdown() if logger_handler: